torchbox.diagnose package
Submodules
torchbox.diagnose.plotgradflow module
- torchbox.diagnose.plotgradflow.plot_gradflow_v1(named_parameters)
- torchbox.diagnose.plotgradflow.plot_gradflow_v2(named_parameters)
Plots the gradients flowing through the layers of the network during training. Useful for spotting vanishing or exploding gradients.
Usage: call this function in the Trainer class right after loss.backward(), as “plot_gradflow_v2(self.model.named_parameters())”, to visualize the gradient flow.
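Below is a minimal sketch of what a gradient-flow plot like this can compute, assuming PyTorch-style (name, parameter) pairs such as those returned by torch.nn.Module.named_parameters() and matplotlib for drawing. The helper names gradient_stats and plot_gradient_flow are hypothetical and are not part of torchbox. The real plot_gradflow_v1 / plot_gradflow_v2 may differ, for example in which parameters they include or how they scale the axes.

```python
from typing import Any, Iterable, List, Tuple


def gradient_stats(
    named_parameters: Iterable[Tuple[str, Any]],
) -> Tuple[List[str], List[float], List[float]]:
    """Hypothetical helper: mean and max absolute gradient per trainable weight.

    Expects (name, param) pairs where param has .requires_grad and .grad,
    e.g. model.named_parameters() in PyTorch.
    """
    layers: List[str] = []
    ave_grads: List[float] = []
    max_grads: List[float] = []
    for name, param in named_parameters:
        # Design choice of this sketch: skip frozen parameters and biases.
        if not param.requires_grad or "bias" in name:
            continue
        # Parameters unused in the last backward pass have no gradient yet.
        if param.grad is None:
            continue
        g = abs(param.grad)  # elementwise |grad|; works on torch.Tensor
        layers.append(name)
        ave_grads.append(float(g.mean()))
        max_grads.append(float(g.max()))
    return layers, ave_grads, max_grads


def plot_gradient_flow(
    named_parameters: Iterable[Tuple[str, Any]],
    out_path: str = "gradient_flow.png",
) -> None:
    """Hypothetical helper: save a bar chart of per-layer gradient magnitudes."""
    # Imported lazily so gradient_stats() works without matplotlib installed.
    import matplotlib.pyplot as plt

    layers, ave_grads, max_grads = gradient_stats(named_parameters)
    x = list(range(len(layers)))
    fig, ax = plt.subplots(figsize=(max(6.0, 0.5 * len(layers)), 4.0))
    # Max |grad| drawn lightly behind mean |grad| for each layer.
    ax.bar(x, max_grads, alpha=0.3, color="c", label="max |grad|")
    ax.bar(x, ave_grads, alpha=0.6, color="b", label="mean |grad|")
    ax.set_xticks(x)
    ax.set_xticklabels(layers, rotation=90)
    ax.set_xlabel("Layer")
    ax.set_ylabel("Gradient magnitude")
    ax.set_title("Gradient flow")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path)
    plt.close(fig)
```

In a training loop, the call belongs after loss.backward() and before optimizer.zero_grad() clears the gradients, e.g. plot_gradient_flow(model.named_parameters(), "step_100.png"). Mean bars that shrink toward zero in the early layers suggest vanishing gradients, while very large max bars suggest exploding ones.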