ml4gw.transforms.decimator
Classes
Decimator -- Downsample (decimate) a timeseries according to a user-defined schedule.
- class ml4gw.transforms.decimator.Decimator(sample_rate=None, schedule=None, split=False)
Bases: Module
Downsample (decimate) a timeseries according to a user-defined schedule.
Note
This is a naive decimator that does not use any IIR/FIR filtering and selects every M-th sample according to the schedule.
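The stride-based selection described in the note can be illustrated with a minimal sketch (plain tensor slicing, not the module's actual code):

```python
import torch

# Naive decimation: keep every M-th sample, with no anti-aliasing
# filter applied beforehand (high-frequency content will alias).
x = torch.arange(16.0)  # toy timeseries of 16 samples
M = 4                   # decimation factor
x_dec = x[::M]          # keeps samples 0, 4, 8, 12
```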
The schedule specifies which segments of the input to keep and at what sampling rate. Each row of the schedule has the form:
[start_time, end_time, target_sample_rate]
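For a given schedule row, the decimation factor is `sample_rate / target_sample_rate` and the number of surviving samples is `(end_time - start_time) * target_sample_rate`. A short sketch using the schedule from the example below (the values here are illustrative, not part of the API):

```python
import torch

# Three segments of a 60 s signal sampled at 2048 Hz
sample_rate = 2048
schedule = torch.tensor([[0, 40, 256], [40, 58, 512], [58, 60, 2048]])

for start, end, target in schedule.tolist():
    factor = sample_rate // target   # keep every `factor`-th sample
    n_kept = (end - start) * target  # samples surviving decimation
    print(f"[{start}, {end}) s: factor {factor}, {n_kept} samples kept")
```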
- Parameters:
sample_rate (int) -- Sampling rate (Hz) of the input timeseries.
schedule (torch.Tensor) -- Tensor of shape (N, 3) defining start time, end time, and target sample rate for each segment.
split (bool, optional) --
If True, the module returns a list of decimated segments (one per schedule entry). Overlapping schedule segments are only allowed when split=True.
If False (default), the segments are concatenated into a single continuous output tensor.
- Shape:
- Input: (B, C, T) where
B = batch size
C = channels
T = number of timesteps (must equal schedule duration x sample_rate)
- Output:
If split=False -> (B, C, T'), where T' is the total number of decimated samples across all segments.
If split=True -> list of tensors, each with shape (B, C, T_i), corresponding to the decimated samples in each schedule segment.
- Returns:
The decimated timeseries, or a list of decimated segments if split=True.
- Return type:
torch.Tensor or List[torch.Tensor]
Example
>>> import torch
>>> from ml4gw.transforms.decimator import Decimator
>>> sample_rate = 2048
>>> X_duration = 60  # seconds
>>> X = torch.randn(1, 1, sample_rate * X_duration)
>>> schedule = torch.tensor(
...     [[0, 40, 256], [40, 58, 512], [58, 60, 2048]],
...     dtype=torch.int,
... )
>>> decimator = Decimator(sample_rate=sample_rate, schedule=schedule)
>>> X_dec = decimator(X)
>>> X_seg = decimator(X, split=True)
>>> print("Original shape:", X.shape)
Original shape: torch.Size([1, 1, 122880])
>>> print("Decimated shape:", X_dec.shape)
Decimated shape: torch.Size([1, 1, 23552])
>>> for i, seg in enumerate(X_seg):
...     print(f"Segment {i} shape:", seg.shape)
Segment 0 shape: torch.Size([1, 1, 10240])
Segment 1 shape: torch.Size([1, 1, 9216])
Segment 2 shape: torch.Size([1, 1, 4096])
>>> overlap_schedule = torch.tensor(
...     [[0, 40, 256], [32, 58, 512], [52, 60, 2048]],
...     dtype=torch.int,
... )
>>> decimator_ov = Decimator(
...     sample_rate=sample_rate,
...     schedule=overlap_schedule,
...     split=True,
... )
>>> X_overlap = decimator_ov(X)
>>> for i, seg in enumerate(X_overlap):
...     print(f"Overlapping segment {i} shape:", seg.shape)
Overlapping segment 0 shape: torch.Size([1, 1, 10240])
Overlapping segment 1 shape: torch.Size([1, 1, 13312])
Overlapping segment 2 shape: torch.Size([1, 1, 16384])
- build_variable_indices()
Compute the time indices to keep based on the schedule.
- Returns:
1D tensor of indices used to decimate the input.
- Return type:
torch.Tensor
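The index construction can be sketched as follows. This is an illustrative reimplementation, not the library's actual code: for each schedule row, it keeps every `factor`-th sample index within that segment and concatenates the results.

```python
import torch

def build_indices_sketch(sample_rate, schedule):
    """Illustrative index builder (hypothetical helper, not the
    library's implementation): for each [start, end, target] row,
    keep every (sample_rate // target)-th index in that segment."""
    pieces = []
    for start, end, target in schedule.tolist():
        factor = sample_rate // target
        pieces.append(
            torch.arange(start * sample_rate, end * sample_rate, factor)
        )
    return torch.cat(pieces)

idx = build_indices_sketch(
    2048, torch.tensor([[0, 40, 256], [40, 58, 512], [58, 60, 2048]])
)
X = torch.randn(1, 1, 2048 * 60)
X_dec = X[..., idx]  # fancy indexing applies the decimation schedule
```

The resulting index tensor selects 23552 of the 122880 input samples, matching the shapes in the example above.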
- forward(X)
Apply decimation to the input timeseries according to the schedule.
- Parameters:
X (torch.Tensor) -- Input tensor of shape (B, C, T), where T must equal schedule duration x sample_rate.
- Returns:
If split=False (default), returns a single decimated tensor of shape (B, C, T').
If split=True, returns a list of decimated segments, one per schedule entry.
- Return type:
torch.Tensor or list[torch.Tensor]
- split_by_schedule(X)
Split and decimate a timeseries into segments according to the schedule.
This method applies the decimation defined by each schedule row and returns a list of the resulting segments.
- Parameters:
X (torch.Tensor) -- Input timeseries of shape (B, C, T) before decimation.
- Returns:
List of decimated segments. Each segment has shape (B, C, T_i), where T_i is the length implied by the corresponding schedule row.
- Return type:
list of torch.Tensor