ml4gw.transforms.decimator

Classes

Decimator([sample_rate, schedule])

Downsample (decimate) a timeseries according to a user-defined schedule.

class ml4gw.transforms.decimator.Decimator(sample_rate=None, schedule=None)

Bases: Module

Downsample (decimate) a timeseries according to a user-defined schedule.

Note

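This is a naive decimator. It applies no IIR/FIR anti-aliasing filter and simply keeps every M-th sample of each segment, where M = sample_rate / target_sample_rate. Frequency content above a segment's new Nyquist frequency (target_sample_rate / 2) will therefore alias into the decimated output.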
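The sketch below makes the aliasing caveat concrete. It uses plain PyTorch and does not call ml4gw. The helper `naive_decimate` is hypothetical. It assumes the input rate is an integer multiple of the target rate and keeps every M-th sample with a strided slice, which is the selection rule the note describes. A 300 Hz tone decimated from 2048 Hz to 256 Hz reappears at 44 Hz.

```python
import math

import torch


def naive_decimate(x: torch.Tensor, sample_rate: int, target_rate: int) -> torch.Tensor:
    """Hypothetical helper: keep every M-th sample along the last axis.

    M = sample_rate // target_rate. No low-pass filter is applied first, so any
    content above target_rate / 2 aliases into the output band.
    """
    if sample_rate % target_rate != 0:
        raise ValueError("sample_rate must be an integer multiple of target_rate")
    m = sample_rate // target_rate
    return x[..., ::m]


# One second of a 300 Hz tone sampled at 2048 Hz, shaped (B, C, T) = (1, 1, 2048)
sample_rate, target_rate = 2048, 256
t = torch.arange(sample_rate, dtype=torch.float64) / sample_rate
x = torch.sin(2 * math.pi * 300 * t).reshape(1, 1, -1)

y = naive_decimate(x, sample_rate, target_rate)
print(y.shape)  # torch.Size([1, 1, 256])

# 300 Hz is above the new Nyquist frequency (128 Hz), so it folds to |300 - 256| = 44 Hz
spectrum = torch.fft.rfft(y[0, 0]).abs()
print(int(spectrum.argmax()))  # 44
```

You can avoid this folding by low-pass filtering each segment below target_sample_rate / 2 before decimating. For example, `scipy.signal.decimate` applies an anti-aliasing filter before downsampling.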

The schedule specifies which segments of the input to keep and at what sampling rate. Each row of the schedule has the form:

[start_time, end_time, target_sample_rate]
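The sketch below reads a schedule row by row. `describe_schedule` is a hypothetical helper, not an ml4gw function. It performs two checks: each target rate must divide `sample_rate` evenly, and rows must be contiguous. Both checks are assumptions about what a well-formed schedule looks like, not documented requirements.

```python
from typing import Dict, List

import torch


def describe_schedule(schedule: torch.Tensor, sample_rate: int) -> List[Dict[str, int]]:
    """Hypothetical helper: interpret each [start_time, end_time, target_sample_rate] row.

    Assumed (not documented) constraints: every target rate divides sample_rate
    evenly, and each row starts where the previous one ended.
    """
    rows = []
    prev_end = None
    for i, (start, end, rate) in enumerate(schedule.tolist()):
        if sample_rate % rate != 0:
            raise ValueError(f"row {i}: {rate} Hz does not divide {sample_rate} Hz")
        if prev_end is not None and start != prev_end:
            raise ValueError(f"row {i}: gap or overlap at t={start} s")
        rows.append({
            "seconds": end - start,           # duration of the segment
            "factor": sample_rate // rate,    # keep every factor-th input sample
            "samples": (end - start) * rate,  # samples kept for this segment
        })
        prev_end = end
    return rows


schedule = torch.tensor([[0, 40, 256], [40, 58, 512], [58, 60, 2048]], dtype=torch.int)
for row in describe_schedule(schedule, sample_rate=2048):
    print(row)
# {'seconds': 40, 'factor': 8, 'samples': 10240}
# {'seconds': 18, 'factor': 4, 'samples': 9216}
# {'seconds': 2, 'factor': 1, 'samples': 4096}
```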

Parameters:
  • sample_rate (int) -- Sampling rate (Hz) of the input timeseries.

  • schedule (torch.Tensor) -- Tensor of shape (N, 3) giving the start time (s), end time (s), and target sample rate (Hz) of each of the N segments.

Shape:
  • Input: (B, C, T) where
    • B = batch size

    • C = channels

    • T = number of timesteps (must equal schedule duration × sample_rate)

  • Output:
    • If split=False → (B, C, T'), where T' is the total number of decimated samples across all segments.

    • If split=True → list of tensors, one per segment.
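The shape constraints reduce to simple arithmetic on the schedule. This sketch assumes that "schedule duration" means the last `end_time` minus the first `start_time`. The name `expected_lengths` is hypothetical. With the schedule from the example below, it reproduces the documented input and output lengths.

```python
import torch


def expected_lengths(schedule: torch.Tensor, sample_rate: int) -> tuple:
    """Hypothetical helper returning (T, T') for a schedule.

    Assumes the schedule duration is the last end_time minus the first start_time.
    """
    duration = int(schedule[-1, 1] - schedule[0, 0])
    t_in = duration * sample_rate  # required input length T
    # T' is the sum over rows of (end_time - start_time) * target_sample_rate
    t_out = int(((schedule[:, 1] - schedule[:, 0]) * schedule[:, 2]).sum())
    return t_in, t_out


schedule = torch.tensor([[0, 40, 256], [40, 58, 512], [58, 60, 2048]], dtype=torch.int)
t_in, t_out = expected_lengths(schedule, sample_rate=2048)
B, C = 4, 2
print((B, C, t_in), "->", (B, C, t_out))  # (4, 2, 122880) -> (4, 2, 23552)
```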

Returns:

The decimated timeseries, or list of decimated segments if split=True.

Return type:

torch.Tensor or List[torch.Tensor]


Example

>>> import torch
>>> from ml4gw.transforms.decimator import Decimator

>>> sample_rate = 2048
>>> X_duration = 60

>>> schedule = torch.tensor(
...     [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
...     dtype=torch.int,
... )

>>> decimator = Decimator(sample_rate=sample_rate,
...    schedule=schedule)
>>> X = torch.randn(1, 1, sample_rate * X_duration)
>>> X_dec = decimator(X)
>>> X_seg = decimator(X, split=True)

>>> print("Original shape:", X.shape)
Original shape: torch.Size([1, 1, 122880])
>>> print("Decimated shape:", X_dec.shape)
Decimated shape: torch.Size([1, 1, 23552])
>>> for i, seg in enumerate(X_seg):
...     print(f"Segment {i} shape:", seg.shape)
Segment 0 shape: torch.Size([1, 1, 10240])
Segment 1 shape: torch.Size([1, 1, 9216])
Segment 2 shape: torch.Size([1, 1, 4096])

build_variable_indices()

Compute the time indices to keep based on the schedule.

Returns:

1D tensor of indices used to decimate the input.

Return type:

torch.Tensor
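The sketch below shows one way such indices could be built. It assumes each row contributes `torch.arange(start, stop, M)`, with M = sample_rate // target_sample_rate, and that times are measured from the first row's `start_time`. `build_indices` is a hypothetical re-creation for illustration, not ml4gw's source.

```python
import torch


def build_indices(schedule: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Hypothetical re-creation of the index computation (not ml4gw's code).

    For each [start_time, end_time, target_sample_rate] row, take every M-th
    input index in [start, end), with M = sample_rate // target_sample_rate.
    Times are measured from the first row's start_time (an assumption).
    """
    t0 = int(schedule[0, 0])
    pieces = []
    for start, end, rate in schedule.tolist():
        first = (start - t0) * sample_rate  # first input index of this segment
        stop = (end - t0) * sample_rate     # one past the last input index
        step = sample_rate // rate          # decimation factor M
        pieces.append(torch.arange(first, stop, step, dtype=torch.long))
    return torch.cat(pieces)


schedule = torch.tensor([[0, 40, 256], [40, 58, 512], [58, 60, 2048]], dtype=torch.int)
idx = build_indices(schedule, sample_rate=2048)
print(idx.shape)         # torch.Size([23552])
print(idx[:3].tolist())  # [0, 8, 16]
print(idx[-1].item())    # 122879
```

The indices depend only on the schedule and `sample_rate`, so they can be computed once and reused for every batch.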

forward(X, split=False)

Apply decimation to the input timeseries.

Parameters:
  • X (torch.Tensor) -- Input tensor of shape (B, C, T), where T must equal schedule duration × sample_rate.

  • split (bool, optional) -- If True, return a list of segments instead of a single concatenated tensor. Default: False.

Returns:

Decimated timeseries, or list of decimated segments.

Return type:

torch.Tensor or List[torch.Tensor]
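The forward pass can be sketched as a gather along the time axis, followed by an optional split. `decimate` is a hypothetical function. The toy indices below are worked out by hand for a 4 Hz input with schedule [[0, 1, 2], [1, 2, 4]], using the same every-M-th-sample rule. `Tensor.index_select` leaves the batch and channel axes intact.

```python
import torch


def decimate(X: torch.Tensor, idx: torch.Tensor, seg_lens: list, split: bool = False):
    """Hypothetical forward pass: gather the kept time indices, optionally split.

    X has shape (B, C, T); idx is a 1D LongTensor of time indices to keep.
    """
    out = X.index_select(-1, idx)  # (B, C, T') with T' = len(idx)
    if split:
        # Return a list, matching the documented List[torch.Tensor] type
        return list(torch.split(out, seg_lens, dim=-1))
    return out


# Toy schedule at sample_rate = 4 Hz: [[0, 1, 2], [1, 2, 4]]
# row 0 keeps every 2nd sample of t in [0, 1) -> indices 0, 2
# row 1 keeps every sample of t in [1, 2)     -> indices 4, 5, 6, 7
idx = torch.tensor([0, 2, 4, 5, 6, 7])
seg_lens = [2, 4]

X = torch.arange(8.0).reshape(1, 1, 8)  # (B, C, T) = (1, 1, 8)
print(decimate(X, idx, seg_lens))
# tensor([[[0., 2., 4., 5., 6., 7.]]])
print([s.tolist() for s in decimate(X, idx, seg_lens, split=True)])
# [[[[0.0, 2.0]]], [[[4.0, 5.0, 6.0, 7.0]]]]
```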

split_by_schedule(X)

Split a decimated timeseries into segments according to the schedule.

Parameters:

X (torch.Tensor) -- Decimated input of shape (B, C, T').

Returns:

Tuple of decimated segments, one per schedule row. Segment i has shape (B, C, T_i), where T_i = (end_time - start_time) × target_sample_rate for row i.

Return type:

tuple of torch.Tensor
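The sketch below illustrates the splitting step. It assumes segment lengths T_i = (end_time - start_time) × target_sample_rate, which matches the example above. `split_segments` is a hypothetical name. It relies on `torch.split`, which returns a tuple of views, consistent with the documented tuple return type.

```python
import torch


def split_segments(X: torch.Tensor, schedule: torch.Tensor) -> tuple:
    """Hypothetical sketch: split a (B, C, T') tensor into per-row segments.

    T_i = (end_time - start_time) * target_sample_rate for row i.
    """
    seg_lens = ((schedule[:, 1] - schedule[:, 0]) * schedule[:, 2]).tolist()
    if sum(seg_lens) != X.shape[-1]:
        raise ValueError(f"expected {sum(seg_lens)} samples, got {X.shape[-1]}")
    return torch.split(X, seg_lens, dim=-1)  # tuple of views along the time axis


schedule = torch.tensor([[0, 40, 256], [40, 58, 512], [58, 60, 2048]], dtype=torch.int)
X_dec = torch.randn(1, 1, 23552)
segments = split_segments(X_dec, schedule)
print(type(segments).__name__, [s.shape[-1] for s in segments])
# tuple [10240, 9216, 4096]
```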