ml4gw.nn.autoencoder.skip_connection

Classes

AddSkipConnect(*args, **kwargs)

ConcatSkipConnect([groups])

SkipConnection(*args, **kwargs)

class ml4gw.nn.autoencoder.skip_connection.AddSkipConnect(*args, **kwargs)

Bases: SkipConnection

forward(X, state)

Parameters:
  • X (Tensor)
  • state (Tensor)

Return type: Tensor
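This entry lists only the signature. The sketch below assumes that `AddSkipConnect` merges the decoder input `X` with the stored encoder `state` by elementwise addition. The class name suggests this, but this page does not state it. `SkipConnectionSketch` and `AddSkipSketch` are hypothetical names, not part of ml4gw.

```python
import torch
from torch import nn


class SkipConnectionSketch(nn.Module):
    """Hypothetical stand-in for the SkipConnection interface."""

    def forward(self, X: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        # Assumed default: pass the decoder input through unchanged.
        return X

    def get_out_channels(self, in_channels: int) -> int:
        # Assumed default: the merge does not change the channel count.
        return in_channels


class AddSkipSketch(SkipConnectionSketch):
    """Assumed additive merge: elementwise sum of X and the encoder state."""

    def forward(self, X: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        # X and state must have matching (or broadcastable) shapes.
        # The width is unchanged, so get_out_channels is inherited as-is.
        return X + state


skip = AddSkipSketch()
X = torch.ones(2, 4, 8)  # (batch, channels, time)
state = torch.full((2, 4, 8), 2.0)
out = skip(X, state)
print(out.shape, skip.get_out_channels(4))
```

Addition leaves the channel count unchanged, so an additive merge needs no `get_out_channels` override. That fits with `AddSkipConnect` not listing its own `get_out_channels` above.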

class ml4gw.nn.autoencoder.skip_connection.ConcatSkipConnect(groups=1)

Bases: SkipConnection

Parameters:
  • groups (int)

forward(X, state)

Parameters:
  • X (Tensor)
  • state (Tensor)

Return type: Tensor

get_out_channels(in_channels)

Parameters:
  • in_channels (int)

Return type: int
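`ConcatSkipConnect` takes a `groups` argument and overrides `get_out_channels`, which suggests its merged output is wider than its input. Here is a minimal sketch of one way such a merge could work. It assumes that `X` and `state` share the shape `(batch, channels, time)` and that they are concatenated along the channel axis group by group, so a grouped convolution downstream sees both sources inside each group. `ConcatSkipSketch` is hypothetical and is not the library's implementation.

```python
import torch
from torch import nn


class ConcatSkipSketch(nn.Module):
    """Hypothetical grouped channel concatenation.

    Assumes inputs of shape (batch, channels, time) whose channel
    counts are divisible by ``groups``.
    """

    def __init__(self, groups: int = 1) -> None:
        super().__init__()
        self.groups = groups

    def get_out_channels(self, in_channels: int) -> int:
        # Concatenating two equal-width tensors doubles the channel count.
        return 2 * in_channels

    def forward(self, X: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        batch, channels, length = X.shape
        g = self.groups
        # Split channels into g groups: (batch, g, channels // g, length).
        X = X.reshape(batch, g, channels // g, length)
        state = state.reshape(batch, g, state.shape[1] // g, length)
        # Concatenate within each group, then flatten groups back into channels.
        out = torch.cat([X, state], dim=2)
        return out.reshape(batch, -1, length)


skip = ConcatSkipSketch(groups=2)
X = torch.zeros(1, 4, 3)
state = torch.ones(1, 4, 3)
out = skip(X, state)
print(out.shape, out[0, :, 0].tolist())
```

With `groups=1` this reduces to a plain `torch.cat([X, state], dim=1)`. With `groups=2` and four channels in each input, the output channel order is X[0], X[1], state[0], state[1], X[2], X[3], state[2], state[3].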

class ml4gw.nn.autoencoder.skip_connection.SkipConnection(*args, **kwargs)

Bases: Module

forward(X, state)

Parameters:
  • X (Tensor)
  • state (Tensor)

Return type: Tensor

get_out_channels(in_channels)

Parameters:
  • in_channels (int)

Return type: int
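`SkipConnection` appears to define the interface shared by both subclasses. `forward(X, state)` merges two tensors, and `get_out_channels(in_channels)` reports the width of the merged result. The sketch below shows why that hook is useful: a decoder stage can size the convolution that follows the merge without knowing which merge it was given. It assumes the base class acts as an identity, returning `X` and keeping the channel count, but this page does not confirm that. All class names here are hypothetical.

```python
import torch
from torch import nn


class IdentitySkipSketch(nn.Module):
    """Hypothetical base-class stand-in: ignore state, keep the width."""

    def forward(self, X: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        return X

    def get_out_channels(self, in_channels: int) -> int:
        return in_channels


class DoublingSkipSketch(IdentitySkipSketch):
    """Hypothetical subclass that concatenates along the channel axis."""

    def forward(self, X: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        return torch.cat([X, state], dim=1)

    def get_out_channels(self, in_channels: int) -> int:
        return 2 * in_channels


class DecoderBlockSketch(nn.Module):
    """Hypothetical decoder stage that sizes its conv from the skip's output width."""

    def __init__(self, in_channels: int, out_channels: int, skip: nn.Module) -> None:
        super().__init__()
        self.skip = skip
        # The conv must accept however many channels the merge produces.
        self.conv = nn.Conv1d(
            skip.get_out_channels(in_channels), out_channels, kernel_size=3, padding=1
        )

    def forward(self, X: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
        return self.conv(self.skip(X, state))


X = torch.randn(2, 4, 16)
state = torch.randn(2, 4, 16)
for skip in (IdentitySkipSketch(), DoublingSkipSketch()):
    block = DecoderBlockSketch(4, 8, skip)
    print(type(skip).__name__, block(X, state).shape)
```

Because the decoder depends only on the two methods, either merge strategy can be swapped in without touching the decoder code.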