ml4gw.transforms.decimator
Classes
Decimator: Downsample (decimate) a timeseries according to a user-defined schedule.
- class ml4gw.transforms.decimator.Decimator(sample_rate=None, schedule=None)
Bases: torch.nn.Module
Downsample (decimate) a timeseries according to a user-defined schedule.
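Note
This is a naive decimator. It applies no anti-aliasing (IIR/FIR) filtering and simply keeps every M-th sample of each segment, where M = sample_rate / target_sample_rate for that segment. Any signal content above a segment's target Nyquist frequency (target_sample_rate / 2) therefore aliases into the output.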
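To illustrate why the lack of filtering matters, the plain-Python sketch below (no torch; `naive_decimate` is a hypothetical helper, not part of ml4gw) keeps every 8th sample of a 300 Hz tone sampled at 2048 Hz. Because 300 Hz lies above the 128 Hz Nyquist frequency of a 256 Hz output, the kept samples coincide with those of a 44 Hz tone (300 - 256 = 44).

```python
import math


def naive_decimate(x, step):
    """Hypothetical helper: keep every `step`-th sample, with no anti-aliasing filter."""
    return x[::step]


fs = 2048            # input sample rate (Hz)
target = 256         # output sample rate (Hz); its Nyquist frequency is 128 Hz
step = fs // target  # M = 8

# One second of a 300 Hz cosine, which is above the output Nyquist frequency
tone = [math.cos(2 * math.pi * 300 * n / fs) for n in range(fs)]
decimated = naive_decimate(tone, step)

# A 44 Hz cosine sampled at 256 Hz: the alias of the 300 Hz tone
alias = [math.cos(2 * math.pi * 44 * n / target) for n in range(target)]
max_err = max(abs(a - b) for a, b in zip(decimated, alias))
print(len(decimated), max_err < 1e-9)  # 256 True
```

In practice, applying a low-pass filter before decimating would suppress this; the Decimator described here does not do so.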
The schedule specifies which segments of the input to keep and at what sampling rate. Each row of the schedule has the form:
[start_time, end_time, target_sample_rate]
where start_time and end_time are in seconds and target_sample_rate is in Hz.
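As a worked example of how a single schedule row translates into input samples, the sketch below maps each row to its first input index, stop index, stride M, and number of kept samples. It assumes times are in seconds and that every target rate divides the input rate evenly; `segment_plan` is a hypothetical helper, not the library's own code.

```python
def segment_plan(sample_rate, schedule):
    """Hypothetical helper: map each [start, end, target_rate] row to sample indices.

    Assumes times are in seconds and that sample_rate is an integer multiple
    of every target_rate.
    """
    plan = []
    for start, end, target_rate in schedule:
        if sample_rate % target_rate:
            raise ValueError(f"{target_rate} Hz does not divide {sample_rate} Hz")
        step = sample_rate // target_rate       # keep every M-th sample
        first = int(start * sample_rate)        # first input index of the segment
        stop = int(end * sample_rate)           # one past the last input index
        n_kept = len(range(first, stop, step))  # samples retained for this row
        plan.append((first, stop, step, n_kept))
    return plan


schedule = [[0, 40, 256], [40, 58, 512], [58, 60, 2048]]
for row in segment_plan(2048, schedule):
    print(row)
# (0, 81920, 8, 10240)
# (81920, 118784, 4, 9216)
# (118784, 122880, 1, 4096)
```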
- Parameters:
sample_rate (int) -- Sampling rate (Hz) of the input timeseries.
schedule (torch.Tensor) -- Tensor of shape (N, 3) defining start time, end time, and target sample rate for each segment.
- Shape:
- Input: (B, C, T), where B = batch size, C = channels, and T = number of timesteps (must equal schedule duration × sample_rate).
- Output: if split=False, (B, C, T'), where T' is the total number of decimated samples across all segments; if split=True, a list of tensors, one per segment.
- Returns:
The decimated timeseries, or a list of decimated segments if split=True.
- Return type:
torch.Tensor or List[torch.Tensor]
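The input-length constraint and the output length T' can both be computed directly from the schedule. The sketch below assumes contiguous rows with times in seconds, so that the schedule duration is the last end time minus the first start time; `expected_lengths` and `check_input` are hypothetical helpers, not part of ml4gw.

```python
def expected_lengths(sample_rate, schedule):
    """Hypothetical helper: input length T and decimated length T' for a schedule.

    Assumes contiguous rows and times in seconds, so the schedule duration is
    the last end time minus the first start time.
    """
    duration = schedule[-1][1] - schedule[0][0]
    T = int(duration * sample_rate)
    T_out = sum(int((end - start) * rate) for start, end, rate in schedule)
    return T, T_out


def check_input(x_len, sample_rate, schedule):
    # Mirror the documented constraint: T == schedule duration * sample_rate
    T, _ = expected_lengths(sample_rate, schedule)
    if x_len != T:
        raise ValueError(f"expected {T} timesteps, got {x_len}")


schedule = [[0, 40, 256], [40, 58, 512], [58, 60, 2048]]
print(expected_lengths(2048, schedule))  # (122880, 23552)
check_input(122880, 2048, schedule)      # passes silently
```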
Example
>>> import torch
>>> from ml4gw.transforms.decimator import Decimator
>>> sample_rate = 2048
>>> X_duration = 60
>>> schedule = torch.tensor(
...     [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
...     dtype=torch.int,
... )
>>> decimator = Decimator(sample_rate=sample_rate,
...                       schedule=schedule)
>>> X = torch.randn(1, 1, sample_rate * X_duration)
>>> X_dec = decimator(X)
>>> X_seg = decimator(X, split=True)
>>> print("Original shape:", X.shape)
Original shape: torch.Size([1, 1, 122880])
>>> print("Decimated shape:", X_dec.shape)
Decimated shape: torch.Size([1, 1, 23552])
>>> for i, seg in enumerate(X_seg):
...     print(f"Segment {i} shape:", seg.shape)
Segment 0 shape: torch.Size([1, 1, 10240])
Segment 1 shape: torch.Size([1, 1, 9216])
Segment 2 shape: torch.Size([1, 1, 4096])
- build_variable_indices()
Compute the indices along the input's time axis that are kept, based on the schedule.
- Returns:
1D tensor of indices used to decimate the input.
- Return type:
torch.Tensor
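One way to build such an index list is to concatenate one strided range per schedule row. The plain-Python sketch below assumes times in seconds and target rates that divide sample_rate evenly; `build_indices` is hypothetical and may differ from the library's handling of segment boundaries. A torch version would typically create each range with `torch.arange(start, stop, step)` and join them with `torch.cat`.

```python
def build_indices(sample_rate, schedule):
    """Hypothetical sketch: concatenate per-segment strided index ranges.

    Assumes times are in seconds and each target rate divides sample_rate.
    """
    idx = []
    for start, end, target_rate in schedule:
        step = sample_rate // target_rate  # stride M for this segment
        idx.extend(range(int(start * sample_rate), int(end * sample_rate), step))
    return idx


idx = build_indices(2048, [[0, 40, 256], [40, 58, 512], [58, 60, 2048]])
print(len(idx))                   # 23552
print(idx[:3], idx[10240:10243])  # [0, 8, 16] [81920, 81924, 81928]
```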
- forward(X, split=False)
Apply decimation to the input timeseries.
- Parameters:
X (torch.Tensor) -- Input tensor of shape (B, C, T), where T must equal schedule duration × sample_rate.
split (bool, optional) -- If True, return a list of segments instead of a single concatenated tensor. Default: False.
- Returns:
Decimated timeseries, or list of decimated segments.
- Return type:
torch.Tensor or List[torch.Tensor]
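Given such an index list, decimation amounts to gathering the same time indices from every batch item and channel. The sketch below does this on nested Python lists shaped (B, C, T) with a small toy schedule; `decimate` is a hypothetical helper. For a torch.Tensor `X` and a 1D LongTensor `idx`, `X[..., idx]` performs the same gather, though this is not a claim about how `forward` is implemented.

```python
def decimate(X, idx):
    """Hypothetical sketch of decimation on nested lists shaped (B, C, T).

    Gathers the same time indices from every batch item and channel.
    """
    return [[[channel[i] for i in idx] for channel in item] for item in X]


# Toy setup: 8 Hz input, 4 s long; keep 2 s at 2 Hz, 1 s at 4 Hz, 1 s at 8 Hz
sample_rate = 8
schedule = [[0, 2, 2], [2, 3, 4], [3, 4, 8]]
idx = [i for s, e, r in schedule
       for i in range(s * sample_rate, e * sample_rate, sample_rate // r)]

X = [[list(range(32)), list(range(100, 132))]]  # B=1, C=2, T=32
Y = decimate(X, idx)
print(len(Y), len(Y[0]), len(Y[0][0]))  # 1 2 16
print(Y[0][0][:6])                      # [0, 4, 8, 12, 16, 18]
```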
- split_by_schedule(X)
Split a decimated timeseries into segments according to the schedule.
- Parameters:
X (torch.Tensor) -- Decimated input of shape (B, C, T').
- Returns:
Tuple of segments, one per schedule row. Each segment has shape (B, C, T_i), where T_i = (end_time - start_time) × target_sample_rate is the length implied by the corresponding schedule row.
- Return type:
tuple of torch.Tensor
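Splitting only requires the per-row lengths (end_time - start_time) × target_sample_rate. The plain-Python sketch below assumes times in seconds and works on a 1D sequence; `split_segments` is a hypothetical helper. For a torch.Tensor, `torch.split(x, lengths, dim=-1)` performs the same split along the time axis and likewise returns a tuple.

```python
def split_segments(x, schedule):
    """Hypothetical sketch: split a decimated 1D sequence into per-row segments.

    Each row [start, end, rate] contributes (end - start) * rate samples,
    with times assumed to be in seconds.
    """
    lengths = [int((end - start) * rate) for start, end, rate in schedule]
    if sum(lengths) != len(x):
        raise ValueError(f"schedule implies {sum(lengths)} samples, got {len(x)}")
    segments, offset = [], 0
    for n in lengths:
        segments.append(x[offset:offset + n])
        offset += n
    return tuple(segments)


schedule = [[0, 40, 256], [40, 58, 512], [58, 60, 2048]]
decimated = list(range(23552))  # stand-in for a decimated timeseries
segs = split_segments(decimated, schedule)
print([len(s) for s in segs])  # [10240, 9216, 4096]
```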