ml4gw.nn.streaming.online_average
Classes
- OnlineAverager: Module for performing stateful online averaging of batches of overlapping timeseries.
- class ml4gw.nn.streaming.online_average.OnlineAverager(update_size, batch_size, num_updates, num_channels, offset=None)
Bases: torch.nn.Module
Module for performing stateful online averaging of batches of overlapping timeseries. At present, the first num_updates predictions produced by this model will underestimate the true average.

- Parameters:
  - update_size (int) -- The number of samples separating the timestamps of subsequent inputs.
  - batch_size (int) -- The number of batched inputs to expect at inference time.
  - num_updates (int) -- The number of steps over which to average predictions before returning them.
  - num_channels (int) -- The expected channel dimension of the input passed to the module at inference time.
  - offset (Optional[int]) -- Number of samples to throw away from the front edge of the kernel when averaging.
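The overlap-add bookkeeping behind this kind of streaming average can be sketched in plain NumPy. This is a minimal, hypothetical re-implementation for illustration, not ml4gw's code: the function name `online_average_step` is made up, NumPy arrays stand in for torch tensors, `offset` is ignored, and it assumes that each input window is at least `num_updates * update_size` samples long, that only its most recent `num_updates * update_size` samples are averaged, and that successive windows in a batch advance by `update_size` samples. Under those assumptions every sample is covered by exactly `num_updates` windows, so dividing each window by `num_updates` and summing gives the mean, and `state` holds the partial sums for the `(num_updates - 1) * update_size` samples that later windows will still touch.

```python
import numpy as np


def online_average_step(update, state, update_size, num_updates):
    """Hypothetical sketch of one streaming-average step (not ml4gw's code).

    update: (batch, channel, time1) windows; window i+1 ends `update_size`
        samples after window i. Assumes time1 >= num_updates * update_size.
    state:  (channel, (num_updates - 1) * update_size) partial sums carried
        over from the previous call.
    Returns (output, new_state); output has shape (channel, batch * update_size).
    """
    batch, channels, _ = update.shape
    keep = num_updates * update_size
    state_size = (num_updates - 1) * update_size

    # Keep the most recent `keep` samples of each window and pre-divide them,
    # so that summing the num_updates windows covering a sample gives a mean.
    recent = update[:, :, -keep:] / num_updates

    # Buffer = carried-over partial sums followed by room for the new samples.
    buf = np.zeros((channels, state_size + batch * update_size))
    buf[:, :state_size] = state
    for i in range(batch):
        start = i * update_size
        buf[:, start:start + keep] += recent[i]  # overlap-add window i

    # The first batch * update_size samples have now received all of their
    # contributions; the tail is still waiting on future windows.
    output = buf[:, :batch * update_size]
    new_state = buf[:, batch * update_size:]
    return output, new_state


# Usage: stream windows cut from a known ramp through three calls.
update_size, num_updates, window, batch = 2, 3, 8, 4
signal = np.arange(100.0)[None, :]  # (channel=1, time)
state = np.zeros((1, (num_updates - 1) * update_size))  # zero initial state
chunks = []
for k in range(3):
    ends = [window + (k * batch + i) * update_size for i in range(batch)]
    update = np.stack([signal[:, e - window:e] for e in ends])  # (4, 1, 8)
    out, state = online_average_step(update, state, update_size, num_updates)
    chunks.append(out)
stream = np.concatenate(chunks, axis=1)  # covers signal times 2..25
```

In this sketch each call returns `batch * update_size` finished samples, and the returned stream trails the end of the newest window by the `(num_updates - 1) * update_size` samples still held in `state`; that trailing span is the latency paid for averaging over `num_updates` overlapping predictions.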
- forward(update, state=None)
  - Parameters:
    - update (Float[Tensor, 'batch channel time1'])
    - state (Float[Tensor, 'channel time2'] | None)
  - Return type:
    Tuple[Float[Tensor, 'channel time3'], Float[Tensor, 'channel time4']]
- get_initial_state()
  - Return type:
    Float[Tensor, 'channel time']
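The warm-up bias mentioned in the class description can be made concrete with a short calculation. Assuming (this is not stated by the API) that the initial state is all zeros and that each of the N = num_updates predictions covering a sample is divided by N before being accumulated, a sample that has so far been covered by only m predictions \hat{y}_1(t), ..., \hat{y}_m(t) holds

```latex
\bar{y}_m(t) = \frac{1}{N} \sum_{k=1}^{m} \hat{y}_k(t)
             = \frac{m}{N}\, y(t),
\qquad 1 \le m \le N,
```

where the second equality holds when every prediction agrees, \hat{y}_k(t) = y(t). The average is therefore shrunk toward zero by a factor m/N until a sample has been covered by all N predictions, and is exact only once m = N.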