torchbox.module.misc package
Submodules
torchbox.module.misc.transform module
- class torchbox.module.misc.transform.Standardization(mean=None, std=None, axis=None, unbiased=False, retall=False)
Bases: torch.nn.modules.module.Module
\[\bar{X} = \frac{X-\mu}{\sigma}\]
- Parameters
X (Tensor) – data to be normalized
mean (list or None, optional) – mean value (the default is None, which means it is computed from the data)
std (list or None, optional) – standard deviation (the default is None, which means it is computed from the data)
axis (list or int, optional) – the axis or axes along which the mean and standard deviation are computed (the default is None, which means over all elements)
unbiased (bool, optional) – If unbiased is False, the standard deviation is calculated via the biased estimator. Otherwise, Bessel’s correction is used.
retall (bool, optional) – If True, also return the mean and std (the default is False, which means just return the standardized data)
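The formula and the axis, unbiased and retall parameters above can be sketched in plain PyTorch. This is only an illustration under stated assumptions, not torchbox's implementation: the function name standardize is hypothetical, the user-supplied mean and std arguments are left out, and keeping the reduced axes with keepdim=True is an assumption made so that the statistics broadcast against the input.

```python
import torch


def standardize(x, axis=None, unbiased=False, retall=False):
    """Hypothetical stand-in for (X - mu) / sigma; not torchbox's code."""
    if axis is None:
        # Assumption: axis=None means "reduce over every dimension".
        axis = tuple(range(x.dim()))
    # keepdim=True keeps the reduced axes as size 1 so they broadcast with x.
    mean = x.mean(dim=axis, keepdim=True)
    std = x.std(dim=axis, unbiased=unbiased, keepdim=True)
    y = (x - mean) / std
    return (y, mean, std) if retall else y


x = torch.randn(5, 2, 4, 3)
y, mean, std = standardize(x, axis=(2, 3), retall=True)
print(y.shape, mean.shape, std.shape)
```

With axis=(2, 3) each (sample, channel) slice gets its own mean and standard deviation, so every slice of y has approximately zero mean and unit standard deviation.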
Examples
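import torch as th
import torchbox as tb
from torchbox.module.misc.transform import Standardization

tb.setseed(seed=2020, target='torch')
x = th.randn(5, 2, 4, 3)

# Per-sample, per-channel statistics: compare with InstanceNorm2d
f = Standardization(axis=(2, 3), unbiased=False, retall=True)
y, meanv, stdv = f(x)
print(y[0], y.shape)
g = th.nn.InstanceNorm2d(2)
z = g(x)
print(z[0], z.shape)

# Per-channel statistics over the whole batch: compare with BatchNorm2d
f = Standardization(axis=(0, 2, 3), unbiased=False, retall=True)
y, meanv, stdv = f(x)
print(y[0], y.shape)
g = th.nn.BatchNorm2d(2)
z = g(x)
print(z[0], z.shape)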
The results are:
tensor([[[ 0.2761, -0.1161, -1.3316],
         [ 0.4918,  0.5450, -0.7350],
         [ 1.5699, -1.8567,  1.7366],
         [-0.1463, -0.1318, -0.3019]],

        [[-1.0576,  0.5794, -0.6489],
         [-0.3410, -1.6589,  0.2531],
         [ 1.2150,  0.7262,  0.3333],
         [-1.1270, -0.2132,  1.9397]]]) torch.Size([5, 2, 4, 3])

tensor([[[ 0.2761, -0.1161, -1.3316],
         [ 0.4918,  0.5450, -0.7350],
         [ 1.5699, -1.8567,  1.7366],
         [-0.1463, -0.1318, -0.3019]],

        [[-1.0576,  0.5794, -0.6489],
         [-0.3410, -1.6588,  0.2531],
         [ 1.2150,  0.7262,  0.3333],
         [-1.1270, -0.2132,  1.9397]]]) torch.Size([5, 2, 4, 3])

tensor([[[ 0.0498, -0.2576, -1.2101],
         [ 0.2188,  0.2605, -0.7426],
         [ 1.0637, -1.6216,  1.1943],
         [-0.2812, -0.2698, -0.4032]],

        [[-1.1965,  0.0760, -0.8788],
         [-0.6395, -1.6639, -0.1776],
         [ 0.5701,  0.1901, -0.1153],
         [-1.2505, -0.5402,  1.1335]]]) torch.Size([5, 2, 4, 3])

tensor([[[ 0.0498, -0.2576, -1.2101],
         [ 0.2188,  0.2605, -0.7426],
         [ 1.0637, -1.6216,  1.1943],
         [-0.2812, -0.2698, -0.4032]],

        [[-1.1965,  0.0760, -0.8788],
         [-0.6395, -1.6639, -0.1776],
         [ 0.5701,  0.1901, -0.1153],
         [-1.2505, -0.5401,  1.1335]]], grad_fn=<SelectBackward>) torch.Size([5, 2, 4, 3])
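The printed tensors agree with InstanceNorm2d and BatchNorm2d except for an occasional difference in the fourth decimal place (for example -1.6589 versus -1.6588). That is plausibly due to the eps term (1e-5 by default) that the PyTorch normalization layers add to the variance. The comparison can be checked numerically with torch.allclose rather than by eye. The sketch below uses only PyTorch and computes the standardization inline, so it is an assumption-based stand-in for the Standardization module rather than a call into torchbox; the tolerance atol=1e-3 is likewise an illustrative choice.

```python
import torch

torch.manual_seed(2020)
x = torch.randn(5, 2, 4, 3)

# Standardize over the spatial axes (per sample, per channel), biased estimator.
y_in = (x - x.mean(dim=(2, 3), keepdim=True)) / x.std(dim=(2, 3), unbiased=False, keepdim=True)
z_in = torch.nn.InstanceNorm2d(2)(x)

# Standardize over batch and spatial axes (per channel), biased estimator.
y_bn = (x - x.mean(dim=(0, 2, 3), keepdim=True)) / x.std(dim=(0, 2, 3), unbiased=False, keepdim=True)
# A freshly constructed BatchNorm2d is in training mode and uses batch statistics.
z_bn = torch.nn.BatchNorm2d(2)(x)

# The layers divide by sqrt(var + eps), so allow a small absolute tolerance.
match_in = torch.allclose(y_in, z_in, atol=1e-3)
match_bn = torch.allclose(y_bn, z_bn, atol=1e-3)
print(match_in, match_bn)
```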
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
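The difference between calling the module instance and calling forward directly can be seen with a forward hook. The Scale class below is a hypothetical toy module used only for illustration; the same behavior applies to any torch.nn.Module subclass.

```python
import torch


class Scale(torch.nn.Module):
    """Hypothetical toy module that doubles its input."""

    def forward(self, x):
        return 2 * x


calls = []
m = Scale()
# A forward hook receives (module, input, output) after each forward pass
# that goes through Module.__call__.
handle = m.register_forward_hook(lambda mod, inp, out: calls.append("hook"))

x = torch.ones(3)
m(x)          # goes through __call__, so the hook runs
m.forward(x)  # bypasses __call__, so the hook is skipped

print(len(calls))  # 1
handle.remove()
```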