ml4gw.nn.streaming.online_average

Classes

OnlineAverager(update_size, batch_size, ...)

Module for performing stateful online averaging of batches of overlapping timeseries.

class ml4gw.nn.streaming.online_average.OnlineAverager(update_size, batch_size, num_updates, num_channels, offset=None)

Bases: Module

Module for performing stateful online averaging of batches of overlapping timeseries. At present, the first num_updates predictions produced by this model will underestimate the true average.

Parameters:
  • update_size (int) -- The number of samples separating the timestamps of subsequent inputs.

  • batch_size (int) -- The number of batched inputs to expect at inference time.

  • num_updates (int) -- The number of steps over which to average predictions before returning them.

  • num_channels (int) -- The expected channel dimension of the input passed to the module at inference time.

  • offset (Optional[int]) -- Number of samples to throw away from the front edge of the kernel when averaging.
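The class description measures the module's output against the "true average" of overlapping predictions. For reference, the sketch below computes that average offline (non-streaming) in plain Python. `offline_average` is a hypothetical helper and not part of ml4gw. It assumes a single channel, and it assumes that window `i` starts `i * update_size` samples after window 0. Edge samples are divided by the number of windows that actually cover them. Away from the edges, each sample is covered by the same number of windows.

```python
from typing import List


def offline_average(windows: List[List[float]], update_size: int) -> List[float]:
    """Average overlapping single-channel windows (hypothetical reference helper).

    Window i is assumed to start i * update_size samples after window 0.
    """
    window_len = len(windows[0])
    total_len = (len(windows) - 1) * update_size + window_len
    sums = [0.0] * total_len
    counts = [0] * total_len
    for i, window in enumerate(windows):
        start = i * update_size
        for j, value in enumerate(window):
            sums[start + j] += value  # accumulate every prediction for this sample
            counts[start + j] += 1  # track how many windows cover this sample
    # Divide by the actual coverage, so edge samples are not underestimated.
    return [s / c for s, c in zip(sums, counts)]


# Four windows of length 6, shifted by 2 samples, filled with 0, 1, 2 and 3.
windows = [[float(i)] * 6 for i in range(4)]
print(offline_average(windows, update_size=2))
# [0.0, 0.0, 0.5, 0.5, 1.0, 1.0, 2.0, 2.0, 2.5, 2.5, 3.0, 3.0]
```

In the interior of this example (for instance the samples averaging to 1.0 and 2.0), every sample is covered by three windows. Three is the role `num_updates` plays in the streaming module.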

forward(update, state=None)

Parameters:
  • update (Float[Tensor, 'batch channel time1'])

  • state (Float[Tensor, 'channel time2'] | None)

Return type:

Tuple[Float[Tensor, 'channel time3'], Float[Tensor, 'channel time4']]
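The stateful `forward(update, state)` pattern can be illustrated with a toy model. The sketch below defines a hypothetical, single-channel, pure-Python `ToyOnlineAverager`. It is not ml4gw's implementation. The sketch rests on these assumptions:

- Each element of `update` is a window of exactly `num_updates * update_size` samples.
- Consecutive windows are shifted by `update_size`.
- The state is a zero-filled buffer of running sums.
- Every emitted sample is divided by `num_updates`.
- `forward` returns `(output, new_state)`.

The sketch ignores `offset` and the channel dimension. Under these assumptions it reproduces the warm-up behavior the class description mentions: outputs from the first call fall short of the true average.

```python
from typing import List, Optional, Tuple


class ToyOnlineAverager:
    """Hypothetical single-channel sketch of streaming overlap averaging."""

    def __init__(self, update_size: int, batch_size: int, num_updates: int) -> None:
        self.update_size = update_size
        self.batch_size = batch_size
        self.num_updates = num_updates
        # Samples that may still receive contributions from future windows.
        self.state_size = (num_updates - 1) * update_size

    def get_initial_state(self) -> List[float]:
        # Zero running sums: the cause of the too-small early outputs.
        return [0.0] * self.state_size

    def forward(
        self, update: List[List[float]], state: Optional[List[float]] = None
    ) -> Tuple[List[float], List[float]]:
        if state is None:
            state = self.get_initial_state()
        if len(update) != self.batch_size:
            raise ValueError("expected batch_size windows")
        window_len = self.num_updates * self.update_size
        out_len = self.batch_size * self.update_size
        # Running sums: carried-over state followed by room for new samples.
        buf = list(state) + [0.0] * out_len
        for i, window in enumerate(update):
            if len(window) != window_len:
                raise ValueError("each window must have num_updates * update_size samples")
            start = i * self.update_size  # window i is shifted by i updates
            for j, value in enumerate(window):
                buf[start + j] += value
        # The first out_len samples get no further contributions, so emit them,
        # always dividing by the nominal overlap count num_updates.
        output = [s / self.num_updates for s in buf[:out_len]]
        # The remaining partial sums become the state for the next call.
        return output, buf[out_len:]


avg = ToyOnlineAverager(update_size=2, batch_size=2, num_updates=3)
batch = [[1.0] * 6, [1.0] * 6]  # two windows whose true average is 1.0
state = None
for _ in range(3):
    out, state = avg.forward(batch, state)
    print(out)  # first call: values below 1.0; later calls: [1.0, 1.0, 1.0, 1.0]
```

In this toy, the first call returns 1/3 and 2/3 where the true value is 1.0, because the zero initial state stands in for windows that were never observed. Once enough windows have been seen, every emitted sample is the mean of exactly `num_updates` predictions.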

get_initial_state()

Return type:

Float[Tensor, 'channel time']