torchbox.diagnose package

Submodules

torchbox.diagnose.plotgradflow module

torchbox.diagnose.plotgradflow.plot_gradflow_v1(named_parameters)
torchbox.diagnose.plotgradflow.plot_gradflow_v2(named_parameters)

Plots the gradients flowing through the different layers of the network during training. It can be used to check for possible vanishing or exploding gradient problems.

Usage: Call this function in the Trainer class after loss.backward(), as "plot_gradflow_v2(self.model.named_parameters())", to visualize the gradient flow.
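The sketch below shows what a gradient-flow plot computes. It is written directly against PyTorch and matplotlib rather than torchbox, so it is not the torchbox implementation. The helper names `gradient_stats` and `plot_gradient_flow` are hypothetical, and skipping bias tensors is an assumption of this sketch. After `loss.backward()`, it records the mean and maximum absolute gradient of every weight tensor and draws them as bars. Layers whose bars shrink towards zero hint at vanishing gradients, and very tall bars hint at exploding ones.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless; drop in a notebook
import matplotlib.pyplot as plt
import torch
import torch.nn as nn


def gradient_stats(named_parameters):
    """Collect the mean and max absolute gradient of each weight tensor.

    Hypothetical helper: biases and parameters without a gradient are
    skipped, which is an assumption of this sketch.
    """
    layers, mean_grads, max_grads = [], [], []
    for name, param in named_parameters:
        if param.requires_grad and param.grad is not None and "bias" not in name:
            grad = param.grad.detach().abs()
            layers.append(name)
            mean_grads.append(grad.mean().item())
            max_grads.append(grad.max().item())
    return layers, mean_grads, max_grads


def plot_gradient_flow(named_parameters, ax=None):
    """Hypothetical helper: bar chart of max (light) and mean (dark) |grad| per layer."""
    layers, mean_grads, max_grads = gradient_stats(named_parameters)
    if ax is None:
        _, ax = plt.subplots()
    positions = list(range(len(layers)))
    ax.bar(positions, max_grads, alpha=0.3, color="c", label="max |grad|")
    ax.bar(positions, mean_grads, alpha=0.8, color="b", label="mean |grad|")
    ax.set_xticks(positions)
    ax.set_xticklabels(layers, rotation=90)
    ax.set_xlabel("layer")
    ax.set_ylabel("absolute gradient")
    ax.set_title("Gradient flow")
    ax.legend()
    return layers, mean_grads, max_grads


if __name__ == "__main__":
    # Usage: one training step on a toy model, then plot after backward().
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    x, y = torch.randn(32, 8), torch.randn(32, 1)
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()  # gradients must exist before they can be plotted
    plot_gradient_flow(model.named_parameters())
    plt.savefig("gradflow.png")
```

Because `named_parameters()` returns a generator, the sketch consumes it once per call. Pass a fresh `model.named_parameters()` each time you plot.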

Module contents