ml4gw.nn.autoencoder.base

Classes

Autoencoder([skip_connection])

Base autoencoder class that defines some of the basic methods and functionality.

class ml4gw.nn.autoencoder.base.Autoencoder(skip_connection=None)

Bases: Module

Base autoencoder class that defines some of the basic methods and functionality. Autoencoders are defined here as a set of sequential blocks, each with an encode method, which acts on the input data to the autoencoder, and a decode method, which acts on the encoded vector generated by the encode method. forward simply runs these two steps one after the other. Although it isn't explicitly enforced, a good rule of thumb is that the output of a block's decode method should have the same shape as the input of its encode method.
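The block structure described above can be sketched in PyTorch as follows. This is a minimal illustration under stated assumptions, not the ml4gw implementation: the block class ConvBlock, the wrapper SketchAutoencoder, storing the blocks in an nn.ModuleList, and decoding them in reverse order are all hypothetical choices, and each method takes a single tensor rather than the variadic *X used by the real API.

```python
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Hypothetical block: encode halves the sequence length, decode restores it."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.down = nn.Conv1d(in_channels, out_channels, kernel_size=2, stride=2)
        self.up = nn.ConvTranspose1d(out_channels, in_channels, kernel_size=2, stride=2)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        # Output has the same shape as the input to encode (the rule of thumb above)
        return self.up(x)


class SketchAutoencoder(nn.Module):
    """Hypothetical stand-in for the pattern; not ml4gw's Autoencoder."""

    def __init__(self, blocks: list) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Run every block's encode step in order
        for block in self.blocks:
            x = block.encode(x)
        return x

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        # Undo the encoding by running the decode steps in reverse block order
        for block in self.blocks[::-1]:
            x = block.decode(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward is just encode followed by decode
        return self.decode(self.encode(x))


model = SketchAutoencoder([ConvBlock(2, 8), ConvBlock(8, 16)])
x = torch.randn(4, 2, 64)  # (batch, channels, length)
print(model.encode(x).shape)  # torch.Size([4, 16, 16])
print(model(x).shape)         # torch.Size([4, 2, 64])
```

Because every block's decode output matches its encode input, the full round trip returns a tensor with the same shape as the original input.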

Accepts a skip_connection argument that defines how to combine information from the input of a block's encode method with the output of its decode method. See skip_connections.py for more information on what these classes are expected to contain and how they operate.
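One way a skip connection could fit into this scheme is sketched below. The names AddSkip, DenseBlock and SkipAutoencoder are hypothetical, as is the additive combination rule; the interface ml4gw actually expects is defined in skip_connections.py. The sketch saves the input to each block's encode step as a "state" (mirroring the return_states/states pairing in the encode and decode signatures) and, while decoding, combines each block's decode output with its matching state. Addition only works here because the decode output has the same shape as the encode input.

```python
from typing import Optional

import torch
from torch import nn


class AddSkip(nn.Module):
    """Hypothetical skip connection: sums a block's decode output with its encode input."""

    def forward(self, decoded: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        return decoded + state


class DenseBlock(nn.Module):
    """Hypothetical fully connected block with matching encode/decode shapes."""

    def __init__(self, dim: int, hidden: int) -> None:
        super().__init__()
        self.enc = nn.Linear(dim, hidden)
        self.dec = nn.Linear(hidden, dim)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.enc(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return self.dec(x)


class SkipAutoencoder(nn.Module):
    """Hypothetical autoencoder that threads encode inputs through to decode."""

    def __init__(self, blocks: list, skip_connection: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.skip_connection = skip_connection

    def encode(self, x: torch.Tensor, return_states: bool = False):
        states = []
        for block in self.blocks:
            states.append(x)  # remember the input to this block's encode step
            x = block.encode(x)
        return (x, states) if return_states else x

    def decode(self, x: torch.Tensor, states: Optional[list] = None) -> torch.Tensor:
        if self.skip_connection is not None and states is None:
            raise ValueError("states are required when a skip connection is set")
        for i, block in enumerate(self.blocks[::-1]):
            x = block.decode(x)
            if self.skip_connection is not None:
                # Blocks run in reverse, so the matching state is counted from the end
                x = self.skip_connection(x, states[-1 - i])
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.skip_connection is None:
            return self.decode(self.encode(x))
        encoded, states = self.encode(x, return_states=True)
        return self.decode(encoded, states=states)


model = SkipAutoencoder([DenseBlock(32, 16), DenseBlock(16, 8)], skip_connection=AddSkip())
x = torch.randn(4, 32)
encoded, states = model.encode(x, return_states=True)
print(encoded.shape, [s.shape for s in states])
print(model(x).shape)  # torch.Size([4, 32])
```

A concatenating skip connection would be an equally valid design, but it would change the channel count seen by later layers, so the blocks would have to be built with that in mind.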

Parameters:

skip_connection (SkipConnection | None)

decode(*X, states=None)
Return type:

Tensor

Parameters:

states (Sequence[Tensor] | None)

encode(*X, return_states=False)
Return type:

Union[Tensor, Tuple[Tensor, Sequence]]

Parameters:
  • X (Tensor)

  • return_states (bool)

forward(*X)
Return type:

Tensor

Parameters:

X (Tensor)