"""
Project a batch of waveform polarizations onto the Hanford,
Livingston, and Virgo interferometers to compute the observed
gravitational wave strain.
"""
import torch
from ml4gw.distributions import Cosine
from ml4gw.gw import get_ifo_geometry, compute_observed_strain
from torch.distributions import Uniform
# Distributions for the sky location and polarization angle. Each
# `.sample((num_waveforms,))` call below draws one independent value
# per waveform in the batch.
# NOTE(review): presumably `Cosine()` defaults to a cos-weighted density
# on [-pi/2, pi/2], i.e. isotropic in declination -- confirm in ml4gw.
dec = Cosine()
# Polarization angle drawn over an interval of length pi; presumably
# because the detector response depends on psi only through 2 * psi.
psi = Uniform(0, torch.pi)
# Azimuthal sky angle over a full turn.
# NOTE(review): confirm which reference frame ml4gw expects for `phi`.
phi = Uniform(-torch.pi, torch.pi)

# Get the interferometer geometry (detector tensors and vertex
# positions) for Hanford, Livingston, and Virgo.
ifos = ["H1", "L1", "V1"]
tensors, vertices = get_ifo_geometry(*ifos)

# This example assumes the following names are already defined:
#   hp, hc         -- plus and cross polarizations of the gravitational
#                     wave, computed by some method, e.g. with the
#                     `TimeDomainCBCWaveformGenerator` from the
#                     `ml4gw.waveforms` module.
#   num_waveforms  -- number of waveforms in the batch; presumably must
#                     equal the leading (batch) dimension of hp and hc.
#   sample_rate    -- sample rate at which hp and hc were generated.
# The result holds the strain projected onto each detector in `ifos`.
waveforms = compute_observed_strain(
    dec=dec.sample((num_waveforms,)),
    psi=psi.sample((num_waveforms,)),
    phi=phi.sample((num_waveforms,)),
    detector_tensors=tensors,
    detector_vertices=vertices,
    sample_rate=sample_rate,
    cross=hc,
    plus=hp,
)