ml4gw.nn.streaming.snapshotter

Classes

Snapshotter(num_channels, snapshot_size, ...)

Model for converting streaming state updates into a batch of overlapping snapshots of a multichannel timeseries.

class ml4gw.nn.streaming.snapshotter.Snapshotter(num_channels, snapshot_size, stride_size, batch_size, channels_per_snapshot=None)

Bases: Module

Model for converting streaming state updates into a batch of overlapping snapshots of a multichannel timeseries. Multiple timeseries can be carried in a single state update via the channels_per_snapshot kwarg.

Specifically, maps an update tensor of shape (num_channels, batch_size * stride_size) to a tensor of shape (batch_size, num_channels, snapshot_size). If channels_per_snapshot is specified, it instead returns len(channels_per_snapshot) tensors of this shape, each with the channel dimension replaced by the corresponding value of channels_per_snapshot. In either case, the last tensor returned at call time is the current state, which can be passed to the next forward call.
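The shape mapping above can be illustrated with a small pure-Python sketch that uses nested lists in place of tensors. This is not the ml4gw implementation: the helper name snapshot_step is hypothetical, and the sketch assumes that the carried state holds the trailing snapshot_size - stride_size samples (the amount needed so that batch_size * stride_size new samples yield exactly batch_size windows) and that stride_size < snapshot_size.

```python
def snapshot_step(update, state, snapshot_size, stride_size):
    """Hypothetical stand-in for one streaming step (not the ml4gw code).

    update: list of channels, each a list of batch_size * stride_size samples
    state:  list of channels, each a list of snapshot_size - stride_size samples
    """
    # Prepend the carried-over state to the new samples, channel by channel.
    full = [s + u for s, u in zip(state, update)]
    total = len(full[0])
    num_windows = (total - snapshot_size) // stride_size + 1

    # snapshots[b][c] is the b-th window of channel c, i.e. the layout
    # (batch_size, num_channels, snapshot_size).
    snapshots = [
        [chan[b * stride_size : b * stride_size + snapshot_size] for chan in full]
        for b in range(num_windows)
    ]

    # Keep the trailing samples as the state for the next call.
    state_size = snapshot_size - stride_size
    new_state = [chan[total - state_size :] for chan in full]
    return snapshots, new_state


num_channels, snapshot_size, stride_size, batch_size = 2, 4, 2, 3

# Assumed all-zero initial state of shape (num_channels, snapshot_size - stride_size).
state = [[0] * (snapshot_size - stride_size) for _ in range(num_channels)]

# One update of shape (num_channels, batch_size * stride_size) = (2, 6).
update = [[1, 2, 3, 4, 5, 6], [10, 20, 30, 40, 50, 60]]

snapshots, new_state = snapshot_step(update, state, snapshot_size, stride_size)
shape = (len(snapshots), len(snapshots[0]), len(snapshots[0][0]))
print(shape)  # (batch_size, num_channels, snapshot_size)
```

With these values, consecutive windows overlap by snapshot_size - stride_size = 2 samples, and new_state holds the last two samples of each channel.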

Parameters:
  • num_channels (int) -- Number of channels in the timeseries. If channels_per_snapshot is not None, this should be equal to sum(channels_per_snapshot).

  • snapshot_size (int) -- The size of each output snapshot window, in number of samples.

  • stride_size (int) -- The number of samples between the starts of consecutive output snapshots.

  • batch_size (int) -- The number of snapshots to produce at each update. The last dimension of the input tensor should have size batch_size * stride_size.

  • channels_per_snapshot (Optional[Sequence[int]]) -- How to split up the channels in the timeseries for different tensors. If left as None, all the channels will be returned in a single tensor. Otherwise, the channels will be split up into len(channels_per_snapshot) tensors, with each tensor's channel dimension being equal to the corresponding value in channels_per_snapshot. Therefore, if specified, these values should add up to num_channels.
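As a rough illustration of how channels_per_snapshot partitions the channel dimension, the following pure-Python sketch splits a (batch_size, num_channels, snapshot_size) nested list into one output per entry. The helper name split_channels is hypothetical and the sketch is not taken from ml4gw. It assumes channels are assigned in order and that a mismatched sum is rejected with a ValueError.

```python
def split_channels(snapshots, channels_per_snapshot):
    """Split the channel axis of snapshots[b][c][t] into consecutive groups."""
    num_channels = len(snapshots[0])
    if sum(channels_per_snapshot) != num_channels:
        raise ValueError(
            f"channels_per_snapshot sums to {sum(channels_per_snapshot)}, "
            f"expected num_channels={num_channels}"
        )

    outputs = []
    start = 0
    for n in channels_per_snapshot:
        # Take the next n channels from every window in the batch.
        outputs.append([window[start : start + n] for window in snapshots])
        start += n
    return tuple(outputs)


# Two snapshots, three channels, two samples each: shape (2, 3, 2).
snapshots = [
    [[0, 1], [10, 11], [20, 21]],
    [[1, 2], [11, 12], [21, 22]],
]

# The first output gets channels 0-1, the second gets channel 2.
first, second = split_channels(snapshots, [2, 1])
print(len(first[0]), len(second[0]))  # channels in each output
```

Each output keeps the batch and time dimensions unchanged; only the channel dimension differs.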

forward(update, snapshot=None)
Return type:

Tuple[Tensor, ...]

Parameters:
  • update (Float[Tensor, 'channel time1'])

  • snapshot (Float[Tensor, 'channel time2'] | None)

get_initial_state()
Return type:

Float[Tensor, 'channel time']
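To show why the state is threaded from one forward call to the next, the single-channel sketch below checks that streaming several updates with a carried state produces the same windows as slicing the full series at once. It is a pure-Python illustration, not the ml4gw implementation: the names windows and stream are hypothetical, and the all-zero initial state of length snapshot_size - stride_size is an assumption.

```python
def windows(series, snapshot_size, stride_size):
    """All full windows of a single-channel series (a list of samples)."""
    last_start = len(series) - snapshot_size
    return [
        series[i : i + snapshot_size]
        for i in range(0, last_start + 1, stride_size)
    ]


def stream(updates, snapshot_size, stride_size):
    """Process updates one at a time, carrying the trailing samples as state."""
    state_size = snapshot_size - stride_size
    state = [0] * state_size  # assumed zero initial state
    out = []
    for update in updates:
        full = state + update
        out.extend(windows(full, snapshot_size, stride_size))
        # The trailing samples become the state for the next update.
        state = full[len(full) - state_size :]
    return out


snapshot_size, stride_size, batch_size = 4, 1, 2

# Three updates of batch_size * stride_size = 2 samples each.
updates = [[1, 2], [3, 4], [5, 6]]

streamed = stream(updates, snapshot_size, stride_size)
offline = windows([0, 0, 0] + [1, 2, 3, 4, 5, 6], snapshot_size, stride_size)
print(streamed == offline)
```

Each update contributes batch_size windows, and no window is lost or duplicated at the boundary between updates.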