ml4gw.nn.ssm.s4d
Classes
| S4D | Single S4D layer operating on (B, d_model, L) sequences. |
| S4DKernel | Build the convolution kernel for one S4D layer. |
| S4Model | Full S4D sequence model for regression / classification. |
- class ml4gw.nn.ssm.s4d.S4D(d_model, d_state=64, dropout=0.0, transposed=True, dt_min=0.001, dt_max=0.1)
Bases: Module
Single S4D layer operating on (B, d_model, L) sequences, or (B, L, d_model) when transposed=False.
- Parameters:
  - d_model (int) -- Model dimension (number of channels).
  - d_state (int) -- State size per model dimension.
  - dropout (float) -- Dropout probability applied to the S4D layer output.
  - transposed (bool) -- If True, input and output are (batch, channels, length). If False, they are (batch, length, channels).
  - dt_min (float) -- Minimum timestep for the S4D kernel.
  - dt_max (float) -- Maximum timestep for the S4D kernel.
- forward(u, **kwargs)
- Parameters:
  - u (Tensor) -- (B, d_model, L) if transposed else (B, L, d_model).
- Return type:
  Tensor
- Returns:
  Output tensor in the same layout as u: (B, d_model, L) if transposed else (B, L, d_model).
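The shape conventions above can be illustrated with a minimal sketch of the core computation, assuming the layer follows the reference S4D design: each channel's input is convolved with that channel's length-L kernel through an FFT, and a per-channel skip term \(D u\) is added. The function name `fft_causal_conv` and the skip weights `D` are hypothetical, and the activation, dropout, and output projection a full layer would apply are omitted.

```python
import torch


def fft_causal_conv(u, K, D, transposed=True):
    """Hypothetical sketch: convolve each channel of u with its own kernel.

    u: (B, H, L) if transposed else (B, L, H)
    K: (H, L) real kernel, one row per channel
    D: (H,) per-channel skip weights (an assumption of this sketch)
    """
    if not transposed:
        u = u.transpose(-1, -2)                     # work in (B, H, L)
    L = u.shape[-1]
    # Zero-pad to 2L so the FFT's circular convolution equals a linear, causal one
    u_f = torch.fft.rfft(u, n=2 * L)                # (B, H, L + 1)
    k_f = torch.fft.rfft(K, n=2 * L)                # (H, L + 1), broadcasts over batch
    y = torch.fft.irfft(u_f * k_f, n=2 * L)[..., :L]
    y = y + u * D.unsqueeze(-1)                     # skip connection D u
    if not transposed:
        y = y.transpose(-1, -2)                     # back to (B, L, H)
    return y


# Self-check against a direct O(L^2) causal convolution
torch.manual_seed(0)
B, H, L = 2, 3, 16
u = torch.randn(B, H, L, dtype=torch.float64)
K = torch.randn(H, L, dtype=torch.float64)
D = torch.randn(H, dtype=torch.float64)
y = fft_causal_conv(u, K, D)

ref = torch.zeros_like(u)
for k in range(L):
    # y_k = sum_{j<=k} K_{k-j} u_j + D u_k
    ref[..., k] = (K[:, : k + 1].flip(-1) * u[..., : k + 1]).sum(-1) + D * u[..., k]
assert torch.allclose(y, ref)
```

Zero-padding both signals to 2L is the key design choice: the linear convolution of two length-L sequences has length 2L - 1, so it fits inside a 2L-point circular convolution without wraparound, and keeping the first L outputs gives a causal result in O(L log L) time instead of O(L^2).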
- class ml4gw.nn.ssm.s4d.S4DKernel(d_model, N=64, dt_min=0.001, dt_max=0.1)
Bases: Module
Build the convolution kernel for one S4D layer.
Each channel is an independent diagonal linear time-invariant state-space model, defined in continuous time by
\begin{align*} x'(t) &= A x(t) + B u(t) \\ y(t) &= C x(t) \end{align*}
with \(A\) diagonal. Linear time invariance means the output equals the input convolved with a single kernel, so the whole sequence is produced in one convolution instead of being stepped through position by position. Discretizing with timestep \(dt\) gives the kernel this module returns:
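\[K_l = 2 \mathrm{Re} \left( \sum_n C_n \frac{e^{dt A_n} - 1}{A_n} \left(e^{dt A_n}\right)^l \right), \qquad l = 0, \dots, L-1\]
where \(n\) runs over the \(N/2\) conjugate-pair states. Each state's conjugate partner contributes the complex conjugate of its term, hence the factor \(2 \mathrm{Re}\), and \(B\) enters the kernel only through the products \(C_n B_n\), so it is absorbed into \(C\).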
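The closed-form kernel can be checked numerically against the recurrence it replaces. The sketch below is an illustration, not ml4gw's implementation: it assumes \(B\) is folded into \(C\), uses the S4D-Lin initialization \(A_n = -1/2 + i\pi n\) purely as an example, and the helpers `s4d_kernel`, `s4d_recurrence`, and `causal_conv` are hypothetical names. It builds \(K\) from the formula and confirms that convolving with it matches stepping the discretized system one sample at a time, which is the linear time invariance argument made above.

```python
import math
import torch


def s4d_kernel(A, C, dt, L):
    """Closed form K_l = 2 Re(sum_n C_n (e^{dt A_n} - 1) / A_n * e^{l dt A_n}).

    A, C: complex (H, N/2); dt: real (H,). Returns a real (H, L) tensor.
    """
    dtA = A * dt.unsqueeze(-1)                         # (H, N/2)
    l = torch.arange(L, dtype=dt.dtype)
    powers = torch.exp(dtA.unsqueeze(-1) * l)          # (H, N/2, L), e^{l dt A_n}
    coeff = C * (torch.exp(dtA) - 1) / A               # (H, N/2)
    return 2 * torch.einsum("hn,hnl->hl", coeff, powers).real


def s4d_recurrence(A, C, dt, u):
    """Step the same discretized system sample by sample; u is real (H, L)."""
    dtA = A * dt.unsqueeze(-1)
    A_bar = torch.exp(dtA)                  # discrete state transition
    B_bar = (torch.exp(dtA) - 1) / A        # zero-order-hold input gain (B folded into C)
    x = torch.zeros_like(A)                 # complex state, (H, N/2)
    ys = []
    for k in range(u.shape[-1]):
        x = A_bar * x + B_bar * u[:, k : k + 1]
        ys.append(2 * (C * x).sum(-1).real)  # each conjugate pair contributes 2 Re(.)
    return torch.stack(ys, dim=-1)           # (H, L)


def causal_conv(K, u):
    """Direct O(L^2) causal convolution y_k = sum_{j<=k} K_{k-j} u_j."""
    y = torch.zeros_like(u)
    for k in range(u.shape[-1]):
        y[:, k] = (K[:, : k + 1].flip(-1) * u[:, : k + 1]).sum(-1)
    return y


torch.manual_seed(0)
H, N, L = 3, 8, 50
n = torch.arange(N // 2, dtype=torch.float64).repeat(H, 1)   # (H, N/2)
A = torch.complex(torch.full_like(n, -0.5), math.pi * n)     # S4D-Lin init (assumption)
C = torch.randn(H, N // 2, dtype=torch.complex128)
# Log-uniform timestep per channel in [1e-3, 1e-1]
dt = torch.exp(torch.empty(H, dtype=torch.float64).uniform_(math.log(1e-3), math.log(1e-1)))
u = torch.randn(H, L, dtype=torch.float64)

K = s4d_kernel(A, C, dt, L)
assert torch.allclose(causal_conv(K, u), s4d_recurrence(A, C, dt, u))
```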
- Parameters:
  - d_model (int) -- Model dimension, i.e. the number of channels in CNN terms. One independent SSM is created per channel.
  - N (int) -- State size per model dimension. The states are stored as N / 2 conjugate pairs, so N must be even.
  - dt_min (float) -- Lower bound of the per-channel timestep, sampled log-uniformly in [dt_min, dt_max]. The timestep sets the timescale each SSM resolves.
  - dt_max (float) -- Upper bound of the per-channel timestep.
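A log-uniform draw spreads channels evenly across orders of magnitude of timescale instead of clustering them near dt_max. The sketch below shows one way to draw it; `sample_dt` is a hypothetical helper, not part of ml4gw. The memory estimate in the comments additionally assumes every \(A_n\) has real part \(-1/2\) (the S4D-Lin initialization), so that \(|e^{dt A_n}|^l = e^{-l\,dt/2}\) falls to \(1/e\) after \(l = 2/dt\) samples.

```python
import math
import torch


def sample_dt(d_model, dt_min=0.001, dt_max=0.1):
    """Hypothetical helper: one timestep per channel, log-uniform in [dt_min, dt_max]."""
    log_dt = torch.rand(d_model, dtype=torch.float64) * (
        math.log(dt_max) - math.log(dt_min)
    ) + math.log(dt_min)
    return log_dt.exp()


torch.manual_seed(0)
dt = sample_dt(256)
# Assuming Re(A_n) = -1/2, a state decays by e^{-dt/2} per sample and reaches 1/e
# after about 2/dt samples: ~2000 samples at dt_min = 0.001, ~20 at dt_max = 0.1.
memory = 2.0 / dt
```

Because \(\log dt\) is uniform, the median timestep is the geometric mean \(\sqrt{dt_{\min}\,dt_{\max}} = 0.01\) for the defaults, so short and long memories get equal coverage.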
- forward(L)
Returns the (d_model, L) convolution kernel.
- Parameters:
  - L (int) -- Kernel length, i.e. the number of taps \(K_0, \dots, K_{L-1}\).
- Return type:
  Tensor
- class ml4gw.nn.ssm.s4d.S4Model(d_input, d_output, d_model=256, d_state=64, n_layers=4, dropout=0.2, dt_min=0.001, dt_max=0.1)
Bases: Module
Full S4D sequence model for regression / classification.
Input: (B, d_input, L)
Output: (B, d_output)
- Parameters:
  - d_input (int) -- Input dimension (number of input features).
  - d_output (int) -- Output dimension (number of output classes/predictions).
  - d_model (int) -- Hidden model dimension (channels in S4D layers).
  - d_state (int) -- State size per model dimension in S4D layers.
  - n_layers (int) -- Number of stacked S4D layers.
  - dropout (float) -- Dropout probability in S4D layers.
  - dt_min (float) -- Minimum timestep for S4D kernels.
  - dt_max (float) -- Maximum timestep for S4D kernels.
Example
>>> import torch
>>> from ml4gw.nn.ssm.s4d import S4Model
>>> model = S4Model(
...     d_input=2,
...     d_output=1,
...     d_model=64,
...     d_state=64,
...     n_layers=4,
... )
>>> x = torch.randn(4, 2, 2048)  # x shape: (B, d_input, L)
>>> y = model(x)
>>> y.shape
torch.Size([4, 1])
- forward(x)
- Parameters:
  - x (Tensor) -- (B, d_input, L)
- Return type:
  Tensor
- Returns:
  (B, d_output) output tensor.
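The class documentation gives only the input and output shapes. One plausible path from (B, d_input, L) to (B, d_output), following a common layout for S4-style classifiers rather than anything confirmed for ml4gw, is: a pointwise encoder to d_model channels, n_layers S4D-style blocks with dropout, residual connections, and LayerNorm, mean pooling over the length axis, and a linear decoder. `TinyS4D` and `TinyS4Model` are hypothetical names, and the S4D-Lin initialization, GELU activation, and post-norm placement are assumptions of this sketch.

```python
import math
import torch
from torch import nn


class TinyS4D(nn.Module):
    """Hypothetical minimal S4D-style layer on (B, H, L) inputs."""

    def __init__(self, H, N=64, dt_min=0.001, dt_max=0.1):
        super().__init__()
        # Per-channel log-timestep, drawn log-uniformly in [dt_min, dt_max]
        self.log_dt = nn.Parameter(
            torch.rand(H) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        )
        # S4D-Lin initialization (assumption): A_n = -1/2 + i*pi*n over N/2 states
        self.log_neg_re_A = nn.Parameter(torch.full((H, N // 2), math.log(0.5)))
        self.im_A = nn.Parameter(math.pi * torch.arange(N // 2).float().repeat(H, 1))
        self.C = nn.Parameter(torch.randn(H, N // 2, dtype=torch.cfloat))
        self.D = nn.Parameter(torch.randn(H))  # per-channel skip weight

    def kernel(self, L):
        # Closed-form ZOH kernel K_l = 2 Re(sum_n C_n (e^{dt A_n} - 1)/A_n e^{l dt A_n})
        A = torch.complex(-self.log_neg_re_A.exp(), self.im_A)        # (H, N/2)
        dtA = A * self.log_dt.exp().unsqueeze(-1)
        l = torch.arange(L, dtype=self.log_dt.dtype, device=A.device)
        powers = torch.exp(dtA.unsqueeze(-1) * l)                      # (H, N/2, L)
        coeff = self.C * (dtA.exp() - 1) / A
        return 2 * torch.einsum("hn,hnl->hl", coeff, powers).real      # (H, L)

    def forward(self, u):                                              # u: (B, H, L)
        L = u.shape[-1]
        # FFT convolution, zero-padded to 2L so it stays causal
        k_f = torch.fft.rfft(self.kernel(L), n=2 * L)
        y = torch.fft.irfft(torch.fft.rfft(u, n=2 * L) * k_f, n=2 * L)[..., :L]
        return nn.functional.gelu(y + u * self.D.unsqueeze(-1))


class TinyS4Model(nn.Module):
    """Hypothetical S4Model-like stack mapping (B, d_input, L) to (B, d_output)."""

    def __init__(self, d_input, d_output, d_model=256, d_state=64, n_layers=4, dropout=0.2):
        super().__init__()
        self.encoder = nn.Conv1d(d_input, d_model, kernel_size=1)  # pointwise lift
        self.layers = nn.ModuleList([TinyS4D(d_model, d_state) for _ in range(n_layers)])
        self.norms = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):                      # x: (B, d_input, L)
        x = self.encoder(x)                    # (B, d_model, L)
        for layer, norm in zip(self.layers, self.norms):
            z = self.dropout(layer(x))
            # Residual add, then LayerNorm over channels (moved last for LayerNorm)
            x = norm((x + z).transpose(-1, -2)).transpose(-1, -2)
        return self.decoder(x.mean(dim=-1))    # mean-pool over L, then (B, d_output)


torch.manual_seed(0)
model = TinyS4Model(d_input=2, d_output=1, d_model=64, d_state=64, n_layers=4)
x = torch.randn(4, 2, 2048)
print(model(x).shape)  # torch.Size([4, 1])
```

Mean pooling over L is what lets the same weights handle any input length, since nothing after the S4D stack depends on L.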