ml4gw.nn.streaming.snapshotter
Classes

| Snapshotter | Model for converting streaming state updates into a batch of overlapping snapshots of a multichannel timeseries. |
- class ml4gw.nn.streaming.snapshotter.Snapshotter(num_channels, snapshot_size, stride_size, batch_size, channels_per_snapshot=None)
Bases: Module
Model for converting streaming state updates into a batch of overlapping snapshots of a multichannel timeseries. Can support multiple timeseries in a single state update via the channels_per_snapshot kwarg. Specifically, maps tensors of shape (num_channels, batch_size * stride_size) to a tensor of shape (batch_size, num_channels, snapshot_size). If channels_per_snapshot is specified, it will return len(channels_per_snapshot) tensors of this shape, with the channel dimension replaced by the corresponding value of channels_per_snapshot. The last tensor returned at call time will be the current state that can be passed to the next forward call. (A construction sketch follows the parameter list below.)
- Parameters:
  - num_channels (int) -- Number of channels in the timeseries. If channels_per_snapshot is not None, this should be equal to sum(channels_per_snapshot).
  - snapshot_size (int) -- The size of the output snapshot windows, in number of samples.
  - stride_size (int) -- The number of samples between each output snapshot.
  - batch_size (int) -- The number of snapshots to produce at each update. The last dimension of the input tensor should have size batch_size * stride_size.
  - channels_per_snapshot (Optional[Sequence[int]]) -- How to split up the channels in the timeseries into different tensors. If left as None, all the channels will be returned in a single tensor. Otherwise, the channels will be split into len(channels_per_snapshot) tensors, with each tensor's channel dimension equal to the corresponding value in channels_per_snapshot. Therefore, if specified, these values should sum to num_channels.
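A minimal construction sketch, using illustrative sizes rather than library defaults: with channels_per_snapshot=[2, 1], each 3-channel state update will be split into a 2-channel and a 1-channel batch of snapshots.

    from ml4gw.nn.streaming.snapshotter import Snapshotter

    # illustrative sizes, not defaults
    snapshotter = Snapshotter(
        num_channels=3,                # must equal sum(channels_per_snapshot)
        snapshot_size=2048,            # samples in each output window
        stride_size=256,               # samples between consecutive windows
        batch_size=8,                  # windows produced per update
        channels_per_snapshot=[2, 1],  # split 3 channels into a 2- and a 1-channel output
    )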
- forward(update, snapshot=None)
- Return type: Tuple[Tensor, ...]
- Parameters:
  - update (Float[Tensor, 'channel time1'])
  - snapshot (Float[Tensor, 'channel time2'] | None)
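Continuing the sketch above, a single call maps an update of shape (num_channels, batch_size * stride_size) to one snapshot batch per entry in channels_per_snapshot, with the updated state as the final element of the returned tuple. The shape checks follow from the documented mapping; the state's own shape is left unchecked here.

    import torch

    # update covers batch_size * stride_size new samples per channel
    update = torch.randn(3, 8 * 256)

    # seed the state explicitly; the state is always returned last
    state = snapshotter.get_initial_state()
    x1, x2, state = snapshotter(update, state)

    assert x1.shape == (8, 2, 2048)  # (batch_size, 2, snapshot_size)
    assert x2.shape == (8, 1, 2048)  # (batch_size, 1, snapshot_size)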
- get_initial_state()
- Return type: Float[Tensor, 'channel time']
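Putting the two methods together, a typical streaming loop seeds the state with get_initial_state and threads the returned state through successive forward calls. The generator of random updates below is just a stand-in for a real data source.

    # hypothetical stream of updates, each of shape
    # (num_channels, batch_size * stride_size)
    stream = (torch.randn(3, 8 * 256) for _ in range(10))

    state = snapshotter.get_initial_state()
    for update in stream:
        # all leading outputs are snapshot batches; the state comes last
        *snapshots, state = snapshotter(update, state)
        for x in snapshots:
            ...  # feed each (batch_size, c, snapshot_size) batch downstream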