ml4gw.nn.norm

Classes

GroupNorm1D(num_channels[, num_groups, eps])

Custom GroupNorm implementation for (batch, channel, length) inputs that is faster than PyTorch's built-in torch.nn.GroupNorm at inference time.

GroupNorm1DGetter([groups])

Utility for creating a NormLayer callable that maps an integer number of channels to a torch Module.

GroupNorm2DGetter([groups])

Utility for creating a NormLayer callable that maps an integer number of channels to a torch Module.

class ml4gw.nn.norm.GroupNorm1D(num_channels, num_groups=None, eps=1e-05)

Bases: Module

Custom GroupNorm implementation for (batch, channel, length) inputs that is faster than PyTorch's built-in torch.nn.GroupNorm at inference time.
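For reference, the standard group normalization computed by torch.nn.GroupNorm is shown below. The input x has shape (batch B, channel C, length L), and its channels are split into G groups of C/G channels each. S_g is the set of channels in group g, g(c) is the group that contains channel c, γ_c and β_c are the learnable per-channel scale and shift, and ε corresponds to the eps parameter. We assume GroupNorm1D computes this same quantity and differs only in how it is implemented. This is an assumption about ml4gw, not something the page states.

```latex
\begin{aligned}
\mu_{b,g} &= \frac{G}{C L} \sum_{c \in S_g} \sum_{l=1}^{L} x_{b,c,l} \\
\sigma^2_{b,g} &= \frac{G}{C L} \sum_{c \in S_g} \sum_{l=1}^{L} \left( x_{b,c,l} - \mu_{b,g} \right)^2 \\
y_{b,c,l} &= \gamma_c \, \frac{x_{b,c,l} - \mu_{b,g(c)}}{\sqrt{\sigma^2_{b,g(c)} + \epsilon}} + \beta_c
\end{aligned}
```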

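Parameters:
  • num_channels (int) – Number of channels in the input.
  • num_groups (int | None) – Number of groups to divide the channels into.
  • eps (float) – Small constant added for numerical stability.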

forward(x)

Normalize x. The output has the same shape as the input.

Parameters:
  • x (Float[Tensor, 'batch channel length'])

Return type:
  Float[Tensor, 'batch channel length']
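The sketch below spells out in plain PyTorch what a group normalization layer computes on a (batch, channel, length) tensor. It then checks the result against torch.nn.functional.group_norm. This is not ml4gw's implementation: the helper group_norm_1d is hypothetical. The page also does not say which operations give GroupNorm1D its inference-time speedup.

```python
import torch
import torch.nn.functional as F


def group_norm_1d(
    x: torch.Tensor,
    num_groups: int,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float = 1e-5,
) -> torch.Tensor:
    """Hypothetical helper: group-normalize a (batch, channel, length) tensor."""
    batch, channels, length = x.shape
    if channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    # Fold each group's consecutive channels together with the length axis,
    # so statistics are taken over (channels // num_groups) * length values.
    grouped = x.reshape(batch, num_groups, -1)
    mean = grouped.mean(dim=-1, keepdim=True)
    # Biased variance (divide by N), matching torch.nn.GroupNorm.
    var = ((grouped - mean) ** 2).mean(dim=-1, keepdim=True)
    normed = (grouped - mean) / torch.sqrt(var + eps)
    # Restore the input shape and apply the per-channel scale and shift.
    normed = normed.reshape(batch, channels, length)
    return normed * weight.view(1, -1, 1) + bias.view(1, -1, 1)


torch.manual_seed(0)
x = torch.randn(4, 8, 100)  # (batch, channel, length)
weight = torch.rand(8)
bias = torch.rand(8)

out = group_norm_1d(x, num_groups=2, weight=weight, bias=bias)
reference = F.group_norm(x, 2, weight, bias, eps=1e-5)
print(out.shape)  # torch.Size([4, 8, 100])
print(torch.allclose(out, reference, atol=1e-5))  # True
```

A single reshape folds each group into one axis. As a result, the mean and variance need only one reduction per (sample, group) pair.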

class ml4gw.nn.norm.GroupNorm1DGetter(groups=None)

Bases: object

Utility for creating a NormLayer callable that maps an integer number of channels to a torch Module. This is useful for parameterizing the normalization layer from the command line with jsonargparse.

Parameters:
  • groups (int | None) – Number of groups for the normalization layers the getter creates.
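The getter pattern can be sketched as a small class. It stores groups at construction time, where a config file or command line can set it. It builds the normalization module only when called with a channel count. GroupNormGetterSketch is a hypothetical stand-in, not ml4gw's class, and it returns torch.nn.GroupNorm for simplicity. Two of its behaviors are assumptions, not documented behavior of GroupNorm1DGetter: treating groups=None as one group per channel, and capping groups at num_channels.

```python
from typing import Callable, Optional

import torch

# Assumed meaning of "NormLayer": any callable from a channel count to a Module.
NormLayer = Callable[[int], torch.nn.Module]


class GroupNormGetterSketch:
    """Hypothetical stand-in for a GroupNorm getter (not ml4gw's class)."""

    def __init__(self, groups: Optional[int] = None) -> None:
        # Only the group count is fixed here; the channel count comes later.
        self.groups = groups

    def __call__(self, num_channels: int) -> torch.nn.Module:
        # Assumption: None means one group per channel, and the group count
        # never exceeds the number of channels.
        if self.groups is None:
            num_groups = num_channels
        else:
            num_groups = min(self.groups, num_channels)
        return torch.nn.GroupNorm(num_groups, num_channels)


norm_layer: NormLayer = GroupNormGetterSketch(groups=4)
layer = norm_layer(16)  # the channel count is supplied by the model
x = torch.randn(2, 16, 32)  # (batch, channel, length)
print(layer(x).shape)  # torch.Size([2, 16, 32])
```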

class ml4gw.nn.norm.GroupNorm2DGetter(groups=None)

Bases: object

Utility for creating a NormLayer callable that maps an integer number of channels to a torch Module. This is useful for parameterizing the normalization layer from the command line with jsonargparse.

Parameters:
  • groups (int | None) – Number of groups for the normalization layers the getter creates.
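A NormLayer callable is useful because a model component often learns its channel count only while it is being built. The sketch below shows a hypothetical 2D convolution block that accepts such a callable. The block, the group_norm_2d function, and the choice of 4 groups are illustrative assumptions. They are not part of ml4gw, and group_norm_2d merely plays the role that a configured getter would.

```python
from typing import Callable

import torch

NormLayer = Callable[[int], torch.nn.Module]


class ConvBlock2D(torch.nn.Module):
    """Hypothetical block that receives its normalization as a NormLayer."""

    def __init__(self, in_channels: int, out_channels: int, norm_layer: NormLayer):
        super().__init__()
        # kernel_size=3 with padding=1 keeps height and width unchanged.
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        # Only the block knows out_channels, so it asks the callable to
        # build a matching normalization module.
        self.norm = norm_layer(out_channels)
        self.act = torch.nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.norm(self.conv(x)))


def group_norm_2d(num_channels: int) -> torch.nn.Module:
    # Stand-in for a configured getter: 4 groups (assumed), capped at num_channels.
    return torch.nn.GroupNorm(min(4, num_channels), num_channels)


block = ConvBlock2D(3, 16, norm_layer=group_norm_2d)
x = torch.randn(2, 3, 32, 32)  # (batch, channel, height, width)
y = block(x)
print(y.shape)  # torch.Size([2, 16, 32, 32])
```

Because the block depends only on the callable's signature, the normalization can be swapped from a config file without touching the block's code.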