Source code for buoy.models.amplfi

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import torch
from jsonargparse import ArgumentParser
from ml4gw.transforms import ChannelWiseScaler, SpectralDensity, Whiten

from buoy.models.base import BuoyModel
from buoy.utils.data import get_local_or_hf, slice_amplfi_data
from buoy.utils.pe import postprocess_samples, run_amplfi

if TYPE_CHECKING:
    from amplfi.train.architectures.flows import FlowArchitecture
    from amplfi.train.prior import AmplfiPrior
    from amplfi.utils.result import AmplfiResult

# HuggingFace Hub repository that hosts the AMPLFI checkpoints and YAML configs
# (used as the default remote source by get_local_or_hf below).
REPO_ID = "ML4GW/amplfi"


@dataclass
class AmplfiConfig:
    """
    Hyperparameters parsed from an AMPLFI training YAML config.

    Instances are built by jsonargparse in ``Amplfi.__init__`` and the
    fields become attributes of the ``Amplfi`` model via inheritance.
    """

    # Normalizing-flow architecture (instantiated by jsonargparse;
    # its num_params is linked to len(inference_params) at parse time)
    architecture: "FlowArchitecture"
    # Prior used to post-process posterior samples
    parameter_sampler: "AmplfiPrior"
    # Strain sampling rate; multiplied by the *_length fields to get
    # sample counts, so those lengths are in seconds and this is in Hz
    sample_rate: float
    # Length of the analysis kernel fed to the flow (seconds)
    kernel_length: float
    # Names of the parameters the flow infers, in order
    inference_params: list[str]
    # Position of the event within the analysis kernel
    # (presumably seconds from the kernel start — confirm upstream)
    event_position: float
    # Length of data used to estimate the PSD (seconds)
    psd_length: float
    # Whitening filter duration (seconds)
    fduration: float
    # FFT length for spectral density estimation (seconds)
    fftlength: float
    # High-pass corner frequency for whitening (Hz)
    highpass: float
    # Optional low-pass corner frequency for whitening (Hz)
    lowpass: float | None = None
class Amplfi(AmplfiConfig, BuoyModel):
    """
    AMPLFI normalizing-flow model for rapid gravitational wave
    parameter estimation.

    Wraps a trained Lightning checkpoint and its associated
    preprocessing pipeline. Config attributes are always loaded; the
    flow weights are only loaded when ``load_weights=True``.
    """

    def __init__(
        self,
        model_weights: str | Path = "amplfi-hlv.ckpt",
        config: str | Path = "amplfi-hlv-config.yaml",
        device: str | None = None,
        revision: str | None = None,
        load_weights: bool = True,
        cache_dir: str | Path | None = None,
    ):
        """
        Args:
            model_weights:
                Path to Lightning checkpoint (`.ckpt`) or filename in
                the ``ML4GW/amplfi`` HuggingFace repository.
            config:
                Path to YAML config or filename in the ``ML4GW/amplfi``
                HuggingFace repository.
            device:
                Device to run inference on (``"cpu"`` or ``"cuda"``).
                If ``None``, uses CUDA when available, otherwise CPU.
            revision:
                HuggingFace repository revision (branch, tag, or commit
                hash). If ``None``, uses the default branch.
            load_weights:
                If ``True``, download and load the flow weights. Set to
                ``False`` to load only the config (e.g. to inspect
                hyperparameters without allocating GPU memory).
            cache_dir:
                Local directory for HuggingFace download cache. If
                ``None``, uses the default ``~/.cache/huggingface``.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Lazy %-formatting: the message is only built if DEBUG is enabled
        logging.debug("Using device: %s", self.device)

        # Resolve the config to a local path, downloading from the
        # HuggingFace repo if necessary
        config = get_local_or_hf(
            filename=config,
            repo_id=REPO_ID,
            descriptor="AMPLFI model config",
            revision=revision,
            cache_dir=cache_dir,
        )

        parser = ArgumentParser()
        parser.add_class_arguments(
            AmplfiConfig, fail_untyped=False, sub_configs=True
        )
        # The architecture's num_params is derived from the inference
        # parameter list rather than being specified independently
        parser.link_arguments(
            "inference_params",
            "architecture.init_args.num_params",
            compute_fn=len,
            apply_on="parse",
        )
        args = parser.parse_path(config)
        args = parser.instantiate_classes(args)
        # Populate the AmplfiConfig fields (sample_rate, kernel_length,
        # etc.) as attributes on this instance
        super().__init__(**vars(args))

        if load_weights:
            weights_path = get_local_or_hf(
                filename=model_weights,
                repo_id=REPO_ID,
                descriptor="AMPLFI model weights",
                revision=revision,
                cache_dir=cache_dir,
            )
            model, scaler = self.load_model(
                args.architecture,
                weights_path,
                len(args.inference_params),
            )
            self.model = model.to(self.device)
            self.scaler = scaler.to(self.device)

        self.configure_preprocessing()

    def load_model(
        self,
        model: "FlowArchitecture",
        model_weights: str,
        num_params: int,
    ) -> tuple["FlowArchitecture", ChannelWiseScaler]:
        """
        Load flow and scaler weights from a Lightning checkpoint.

        Args:
            model: Uninitialized flow architecture to load weights into.
            model_weights: Path to the Lightning ``.ckpt`` file.
            num_params: Number of inference parameters; sets the size
                of the ``ChannelWiseScaler``.

        Returns:
            The architecture (in eval mode) and the fitted scaler.
        """
        # NOTE: weights_only=False executes arbitrary pickled code on
        # load; only use checkpoints from trusted sources (here, the
        # ML4GW HuggingFace repo or a local file the user supplied)
        checkpoint = torch.load(
            model_weights, map_location="cpu", weights_only=False
        )
        state_dict = checkpoint["state_dict"]

        # Lightning prefixes submodule parameters with their attribute
        # names; strip the prefixes to load into the bare modules
        arch_weights = {
            k.removeprefix("model."): v
            for k, v in state_dict.items()
            if k.startswith("model.")
        }
        model.load_state_dict(arch_weights)
        model.eval()

        scaler_weights = {
            k.removeprefix("scaler."): v
            for k, v in state_dict.items()
            if k.startswith("scaler.")
        }
        scaler = ChannelWiseScaler(num_params)
        scaler.load_state_dict(scaler_weights)
        return model, scaler

    def configure_preprocessing(self) -> None:
        """Build the PSD estimator and whitener on the target device."""
        self.spectral_density = SpectralDensity(
            sample_rate=self.sample_rate,
            fftlength=self.fftlength,
            average="median",
        ).to(self.device)
        self.whitener = Whiten(
            fduration=self.fduration,
            sample_rate=self.sample_rate,
            highpass=self.highpass,
            lowpass=self.lowpass,
        ).to(self.device)

    @property
    def minimum_data_size(self) -> int:
        """Minimum data size (in samples) required for the model to run."""
        # Analysis kernel + PSD estimation segment + whitening filter
        # settle-in, all in seconds, converted to samples
        return int(
            (self.kernel_length + self.psd_length + self.fduration)
            * self.sample_rate
        )

    def __call__(
        self,
        data: torch.Tensor,
        t0: float,
        tc: float,
        samples_per_event: int,
    ) -> "AmplfiResult":
        """
        Draw posterior samples for an event in ``data``.

        Args:
            data: Strain tensor; the last dimension is time samples.
            t0: Start time of ``data`` (same time convention as ``tc``).
            tc: Event (coalescence) time; used to center the analysis
                window and passed through as the event time in
                post-processing.
            samples_per_event: Number of posterior samples to draw.

        Returns:
            Post-processed posterior samples as an ``AmplfiResult``.

        Raises:
            RuntimeError: If the model was constructed with
                ``load_weights=False``.
            ValueError: If ``data`` is shorter than
                ``minimum_data_size``.
        """
        if not hasattr(self, "model"):
            raise RuntimeError(
                "AMPLFI model weights were not loaded. "
                "Re-initialize with load_weights=True."
            )
        if data.shape[-1] < self.minimum_data_size:
            raise ValueError(
                f"Data size {data.shape[-1]} is less than the minimum "
                f"size of {self.minimum_data_size}"
            )

        # Split the strain into a PSD-estimation segment and the
        # analysis window around the event
        psd_data, window = slice_amplfi_data(
            data=data,
            sample_rate=self.sample_rate,
            t0=t0,
            tc=tc,
            kernel_length=self.kernel_length,
            event_position=self.event_position,
            psd_length=self.psd_length,
            fduration=self.fduration,
        )
        samples = run_amplfi(
            amplfi_strain=window,
            amplfi_psd_strain=psd_data,
            samples_per_event=samples_per_event,
            spectral_density=self.spectral_density,
            amplfi_whitener=self.whitener,
            amplfi=self.model,
            std_scaler=self.scaler,
            device=self.device,
        )
        # Move samples off the (possibly CUDA) device before
        # CPU-side post-processing
        samples = samples.cpu()
        result = postprocess_samples(
            samples=samples,
            event_time=tc,
            inference_params=self.inference_params,
            parameter_sampler=self.parameter_sampler,
        )
        return result