ml4gw.nn.streaming.snapshotter
Classes
Snapshotter: Model for converting streaming state updates into a batch of overlapping snapshots of a multichannel timeseries.
- class ml4gw.nn.streaming.snapshotter.Snapshotter(num_channels, snapshot_size, stride_size, batch_size, channels_per_snapshot=None)
Bases: Module
Model for converting streaming state updates into a batch of overlapping snapshots of a multichannel timeseries. Can support multiple timeseries in a single state update via the channels_per_snapshot kwarg.
Specifically, maps tensors of shape (num_channels, batch_size * stride_size) to a tensor of shape (batch_size, num_channels, snapshot_size). If channels_per_snapshot is specified, it will return len(channels_per_snapshot) tensors of this shape, with the channel dimension replaced by the corresponding value of channels_per_snapshot. The last tensor returned at call time will be the current state that can be passed to the next forward call.
- Parameters:
num_channels (int) -- Number of channels in the timeseries. If channels_per_snapshot is not None, this should be equal to sum(channels_per_snapshot).
snapshot_size (int) -- The size of the output snapshot windows, in number of samples.
stride_size (int) -- The number of samples between consecutive output snapshots.
batch_size (int) -- The number of snapshots to produce at each update. The last dimension of the input tensor should have size batch_size * stride_size.
channels_per_snapshot (Sequence[int] | None) -- How to split up the channels in the timeseries into different tensors. If left as None, all the channels will be returned in a single tensor. Otherwise, the channels will be split up into len(channels_per_snapshot) tensors, with each tensor's channel dimension being equal to the corresponding value in channels_per_snapshot. Therefore, if specified, these values should add up to num_channels.
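The shape bookkeeping described by these parameters can be sketched in plain Python. The helper name expected_shapes is hypothetical and not part of ml4gw; it only restates the documented input and output shapes. Raising ValueError on a mismatched channels_per_snapshot is an assumption about reasonable validation, not documented behavior.

```python
from typing import Optional, Sequence


def expected_shapes(
    num_channels: int,
    snapshot_size: int,
    stride_size: int,
    batch_size: int,
    channels_per_snapshot: Optional[Sequence[int]] = None,
):
    """Return (update_shape, snapshot_shapes) as described in the docs.

    Hypothetical helper for illustration only; not an ml4gw function.
    """
    # Each update carries batch_size * stride_size new samples per channel.
    update_shape = (num_channels, batch_size * stride_size)

    if channels_per_snapshot is None:
        # All channels are returned together in a single tensor.
        groups = [num_channels]
    else:
        # Documented constraint: the groups should add up to num_channels.
        # ASSUMPTION: we raise here; the library's own check is not shown.
        if sum(channels_per_snapshot) != num_channels:
            raise ValueError("channels_per_snapshot must sum to num_channels")
        groups = list(channels_per_snapshot)

    # One (batch_size, channels, snapshot_size) shape per channel group.
    snapshot_shapes = [(batch_size, c, snapshot_size) for c in groups]
    return update_shape, snapshot_shapes


update_shape, snapshot_shapes = expected_shapes(
    num_channels=3,
    snapshot_size=8,
    stride_size=2,
    batch_size=4,
    channels_per_snapshot=[2, 1],
)
print(update_shape)     # (3, 8)
print(snapshot_shapes)  # [(4, 2, 8), (4, 1, 8)]
```

Per the class description, the tuple returned at call time contains one more element than this list: the current state, which comes last.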
- forward(update, snapshot=None)
- Return type:
tuple[Tensor, ...]
- Parameters:
update (Float[Tensor, 'channel time1'])
snapshot (Float[Tensor, 'channel time2'] | None)
- get_initial_state()
- Return type:
Float[Tensor, 'channel time']
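As a rough illustration of how the state might be threaded through successive forward calls, here is a minimal pure-Python sketch. ToySnapshotter is a hypothetical stand-in that uses nested lists instead of tensors; it is not the ml4gw implementation. It assumes the carried state holds the trailing snapshot_size - stride_size samples and that the initial state is all zeros. Neither detail is stated above.

```python
class ToySnapshotter:
    """Pure-Python stand-in for illustration (hypothetical, not ml4gw)."""

    def __init__(self, num_channels, snapshot_size, stride_size, batch_size):
        self.num_channels = num_channels
        self.snapshot_size = snapshot_size
        self.stride_size = stride_size
        self.batch_size = batch_size
        # ASSUMPTION: the state keeps the samples that the first window
        # of the next update still needs.
        self.state_size = snapshot_size - stride_size

    def get_initial_state(self):
        # ASSUMPTION: the stream starts from an all-zero state.
        return [[0.0] * self.state_size for _ in range(self.num_channels)]

    def forward(self, update, snapshot=None):
        if snapshot is None:
            snapshot = self.get_initial_state()
        # Append the new samples to the carried state, channel by channel.
        full = [s + u for s, u in zip(snapshot, update)]
        # Slice out batch_size overlapping windows, stride_size apart,
        # giving a (batch_size, num_channels, snapshot_size) nested list.
        windows = []
        for i in range(self.batch_size):
            start = i * self.stride_size
            windows.append([ch[start:start + self.snapshot_size] for ch in full])
        # Keep the trailing samples as the state for the next call.
        new_state = [ch[len(ch) - self.state_size:] for ch in full]
        return windows, new_state


snap = ToySnapshotter(num_channels=1, snapshot_size=4, stride_size=2, batch_size=2)
state = None  # the first call falls back to get_initial_state()
outputs = []
# Each update has batch_size * stride_size = 4 samples per channel.
for update in ([[1.0, 2.0, 3.0, 4.0]], [[5.0, 6.0, 7.0, 8.0]]):
    windows, state = snap.forward(update, state)
    outputs.append(windows)

print(outputs[1])  # [[[3.0, 4.0, 5.0, 6.0]], [[5.0, 6.0, 7.0, 8.0]]]
print(state)       # [[7.0, 8.0]]
```

Passing the returned state back in lets the first window of the second call span the boundary between the two updates, so the snapshots stay continuous across calls.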