ml4gw.nn.ssm.s4d

Classes

S4D(d_model[, d_state, dropout, transposed, ...])

Single S4D layer operating on (B, d_model, L) sequences.

S4DKernel(d_model[, N, dt_min, dt_max])

Build the convolution kernel for one S4D layer.

S4Model(d_input, d_output[, d_model, ...])

Full S4D sequence model for regression / classification.

class ml4gw.nn.ssm.s4d.S4D(d_model, d_state=64, dropout=0.0, transposed=True, dt_min=0.001, dt_max=0.1)

Bases: Module

Single S4D layer operating on (B, d_model, L) sequences, or (B, L, d_model) sequences when transposed=False.

Parameters:
  • d_model (int) -- Model dimension (number of channels).

  • d_state (int) -- State size per model dimension.

  • dropout (float) -- Dropout probability applied to the S4D layer output.

  • transposed (bool) -- If True, input/output are (batch, channels, length). If False, input/output are (batch, length, channels).

  • dt_min (float) -- Minimum timestep for the S4D kernel.

  • dt_max (float) -- Maximum timestep for the S4D kernel.

forward(u, **kwargs)
Parameters:

u (Tensor) -- (B, d_model, L) if transposed else (B, L, d_model)

Return type:

Tensor

Returns:

(B, d_model, L) output tensor if transposed, else (B, L, d_model).
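An S4D layer convolves each channel of its input with that channel's row of the (d_model, L) kernel built by S4DKernel. The transposed flag only changes which axis holds the channels. The pure-Python sketch below illustrates just that causal, per-channel (depthwise) convolution, using nested lists as stand-ins for tensors. The function name is hypothetical, and the sketch leaves out dropout and anything else the layer may apply around the convolution, such as a skip term, activation, or output projection, none of which this page documents. It uses a direct O(L^2) sum for clarity. S4-style implementations typically compute the same convolution with FFTs.

```python
def causal_depthwise_conv(u, K, transposed=True):
    """Hypothetical helper: convolve every channel of u causally with its own kernel row.

    u: nested lists shaped (B, H, L) if transposed else (B, L, H)
    K: nested lists shaped (H, L), one kernel row per channel
    Returns nested lists in the same layout as u.
    """
    if not transposed:
        # (B, L, H) -> (B, H, L) so every inner list is one channel's time series
        u = [[list(col) for col in zip(*seq)] for seq in u]
    # y[b][h][t] = sum_{s=0}^{t} K[h][s] * u[b][h][t - s]  (only past and present inputs)
    y = [
        [
            [sum(K[h][s] * chan[t - s] for s in range(t + 1)) for t in range(len(chan))]
            for h, chan in enumerate(batch)
        ]
        for batch in u
    ]
    if not transposed:
        # back to (B, L, H) so the output layout matches the input layout
        y = [[list(row) for row in zip(*seq)] for seq in y]
    return y


# B=1, H=2 channels, L=3 time steps, channels-first layout
u = [[[1.0, 2.0, 3.0], [0.0, 1.0, 0.0]]]
K = [[1.0, 0.5, 0.25], [2.0, 0.0, -1.0]]
y = causal_depthwise_conv(u, K)
```

Because each output sample depends only on the current and earlier inputs of the same channel, the channels never mix inside this step.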

class ml4gw.nn.ssm.s4d.S4DKernel(d_model, N=64, dt_min=0.001, dt_max=0.1)

Bases: Module

Build the convolution kernel for one S4D layer.

Each channel is an independent diagonal linear time-invariant state-space model, defined in continuous time by

\begin{align*} x'(t) &= A x(t) + B u(t) \\ y(t) &= C x(t) \end{align*}

with \(A\) diagonal. Linear time invariance means the output equals the input convolved with a single kernel, so the whole sequence is produced in one convolution instead of being stepped through position by position. Discretizing with timestep \(dt\) gives the kernel this module returns:

\[K_l = 2 \mathrm{Re} \left( \sum_n C_n \frac{e^{dt A_n} - 1}{A_n} \left(e^{dt A_n}\right)^l \right), \qquad l = 0, \dots, L-1\]

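where \(n\) runs over the \(N/2\) stored conjugate-pair states. Only one member of each pair is stored. Its partner \((A_n^*, C_n^*)\) contributes the complex conjugate of the same term, so the sum over all \(N\) states equals \(2\,\mathrm{Re}\) of the half sum and the kernel is real. The factor \((e^{dt A_n} - 1)/A_n\) is the zero-order-hold discretization of the input, with \(B\) absorbed into \(C_n\).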
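The formula can be checked numerically against the statement that a linear time-invariant system is a convolution. The sketch below builds \(K_l\) from the \(N/2\) stored pairs, then separately steps the discretized recurrence one sample at a time over all \(N\) states, with each pair written out next to its conjugate. The two outputs agree, and the stepped output is real, which is where the factor \(2\,\mathrm{Re}\) comes from. The function names, the choice \(B = 1\), and the example values of \(A\) and \(C\) are illustrative assumptions, not the module's code or initialization.

```python
import cmath
import math


def s4d_kernel(A, C, dt, L):
    """K_l = 2 Re(sum_n C_n (e^{dt A_n} - 1) / A_n * (e^{dt A_n})^l) for l = 0..L-1.

    A and C hold one member of each conjugate pair (N/2 complex entries each).
    """
    kernel = []
    for l in range(L):
        total = 0j
        for A_n, C_n in zip(A, C):
            decay = cmath.exp(dt * A_n)
            total += C_n * (decay - 1) / A_n * decay**l
        kernel.append(2 * total.real)
    return kernel


def step_recurrence(A, C, dt, u):
    """Step x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k over all N states."""
    A_all = list(A) + [a.conjugate() for a in A]  # add the conjugate partners
    C_all = list(C) + [c.conjugate() for c in C]
    A_bar = [cmath.exp(dt * a) for a in A_all]
    B_bar = [(ab - 1) / a for ab, a in zip(A_bar, A_all)]  # zero-order hold, B = 1
    x = [0j] * len(A_all)
    y = []
    for u_k in u:
        x = [ab * x_n + bb * u_k for ab, x_n, bb in zip(A_bar, x, B_bar)]
        y.append(sum(c * x_n for c, x_n in zip(C_all, x)))
    return y


def causal_conv(K, u):
    """y_k = sum_{l=0}^{k} K_l u_{k-l}."""
    return [sum(K[l] * u[k - l] for l in range(k + 1)) for k in range(len(u))]


# Arbitrary example: N = 4 states (2 stored pairs), Re(A_n) < 0 so the states decay
N = 4
A = [complex(-0.5, math.pi * n) for n in range(N // 2)]
C = [0.7 - 0.2j, -0.1 + 0.4j]
dt = 0.05
u = [1.0, 0.0, -2.0, 0.5, 3.0, 1.5]

K = s4d_kernel(A, C, dt, len(u))
y_conv = causal_conv(K, u)          # whole sequence from one convolution
y_rec = step_recurrence(A, C, dt, u)  # same system stepped position by position
```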

Parameters:
  • d_model (int) -- Model dimension, i.e. the number of channels in CNN nomenclature. One independent SSM is created per channel.

  • N (int) -- State size per model dimension. The states are stored as N / 2 conjugate pairs, so N must be even.

  • dt_min (float) -- Lower bound of the per-channel timestep, sampled log-uniformly in [dt_min, dt_max]. The timestep sets the timescale each SSM resolves.

  • dt_max (float) -- Upper bound of the per-channel timestep.
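Sampling log-uniformly in [dt_min, dt_max] means drawing log(dt) uniformly between log(dt_min) and log(dt_max) and then exponentiating. The sketch below shows that, together with the even-N requirement for storing N / 2 conjugate pairs. Both helper names are hypothetical. This illustrates the stated distribution and is not the module's own initialization code.

```python
import math
import random


def sample_log_dt(d_model, dt_min=0.001, dt_max=0.1, rng=None):
    """Hypothetical helper: one timestep per channel, log-uniform in [dt_min, dt_max]."""
    rng = rng or random.Random()
    lo, hi = math.log(dt_min), math.log(dt_max)
    # uniform in log space, so every decade of dt is equally likely
    return [math.exp(rng.uniform(lo, hi)) for _ in range(d_model)]


def num_conjugate_pairs(N):
    """Hypothetical helper: N states are stored as N // 2 conjugate pairs."""
    if N % 2 != 0:
        raise ValueError(f"N must be even, got {N}")
    return N // 2


dts = sample_log_dt(1000, rng=random.Random(0))
```

With the defaults, the intervals 0.001 to 0.01 and 0.01 to 0.1 are equally likely. About half of the channels therefore get a timestep below the geometric mean sqrt(dt_min * dt_max) = 0.01, so the layer covers both short and long timescales.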

forward(L)

Parameters:

L (int) -- Kernel length (number of time steps).

Return type:

Tensor

Returns:

(d_model, L) convolution kernel.

class ml4gw.nn.ssm.s4d.S4Model(d_input, d_output, d_model=256, d_state=64, n_layers=4, dropout=0.2, dt_min=0.001, dt_max=0.1)

Bases: Module

Full S4D sequence model for regression / classification.

Input: (B, d_input, L)

Output: (B, d_output)

Parameters:
  • d_input (int) -- Input dimension (number of input features).

  • d_output (int) -- Output dimension (number of output classes/predictions).

  • d_model (int) -- Hidden model dimension (channels in S4D layers).

  • d_state (int) -- State size per model dimension in S4D layers.

  • n_layers (int) -- Number of stacked S4D layers.

  • dropout (float) -- Dropout probability in S4D layers.

  • dt_min (float) -- Minimum timestep for S4D kernels.

  • dt_max (float) -- Maximum timestep for S4D kernels.

Example

>>> import torch
>>> from ml4gw.nn.ssm.s4d import S4Model
>>> model = S4Model(
...     d_input=2,
...     d_output=1,
...     d_model=64,
...     d_state=64,
...     n_layers=4,
... )
>>> x = torch.randn(4, 2, 2048)  # x shape: (B, d_input, L)
>>> y = model(x)
>>> y.shape
torch.Size([4, 1])

forward(x)
Parameters:

x (Tensor) -- (B, d_input, L)

Return type:

Tensor

Returns:

(B, d_output)