Source code for cosmoglobe.sky.model

from __future__ import annotations

from typing import Dict, Iterator, List, Literal, Optional, Union

import healpy as hp
import numpy as np
from astropy.units import Quantity
from astropy.units.core import UnitsError

from cosmoglobe.h5.chain import Chain
from cosmoglobe.sky._base_components import (
    DiffuseComponent,
    LineComponent,
    PointSourceComponent,
    SkyComponent,
)
from cosmoglobe.sky._component_factory import get_components_from_chain
from cosmoglobe.sky._constants import (
    DEFAULT_BEAM_FWHM,
    DEFAULT_GAL_CUT,
    DEFAULT_OUTPUT_UNIT_STR,
)
from cosmoglobe.sky._exceptions import (
    ComponentError,
    ComponentNotFoundError,
    NsideError,
)
from cosmoglobe.sky._units import Unit
from cosmoglobe.sky.components._labels import SkyComponentLabel
from cosmoglobe.sky.cosmoglobe import cosmoglobe_registry


[docs]class SkyModel: r"""Sky model representing an initialized instance of the Cosmoglobe Sky Model. This class acts as a container for the various components making up the Cosmoglobe Sky Model and provides methods to simulate the sky. The primary use case of this class is to call its ``__call__`` method, which simulates the sky at a single frequency :math:`\nu`, or integrated over a bandpass :math:`\tau`. Examples -------- >>> import cosmoglobe >>> chain = cosmoglobe.get_test_chain() >>> model = cosmoglobe.SkyModel.from_chain(chain, nside=256) >>> print(model) SkyModel( version: BeyondPlanck nside: 256 components( (ame): SpinningDust(freq_peak) (cmb): CMB() (dust): ModifiedBlackbody(beta, T) (ff): LinearOpticallyThin(T_e) (radio): AGNPowerLaw(alpha) (synch): PowerLaw(beta) ) ) Simulating the full sky emission at some frequency, given a beam FWHM: >>> import astropy.units as u >>> model(50*u.GHz, fwhm=30*u.arcmin) [[ 2.25809786e+03 2.24380103e+03 2.25659060e+03 ... -2.34783682e+03 -2.30464421e+03 -2.30387946e+03] [-1.64627550e+00 2.93583564e-01 -1.06788937e+00 ... -1.64354585e+01 1.60621841e+01 -1.05506092e+01] [-4.15682825e+00 3.08881971e-01 -1.02012415e+00 ... 5.44745701e+00 -4.71776995e+00 4.39850830e+00]] uK_RJ """
[docs] def __init__( self, nside: int, components: Dict[str, SkyComponent], version: Optional[str] = None, ) -> None: """Initializes an instance of the Cosmoglobe Sky Model. Parameters ---------- nside Healpix resolution of the maps in sky model. components A list of pre-initialized `SkyComponent`to include in the model. version The version of the Cosmoglobe Model used in the sky model. """ self.nside = nside self.components = components self.version = version if not all( isinstance(component, SkyComponent) for component in components.values() ): raise ComponentError("all components must be subclasses of SkyComponent") if not all( self.nside == hp.get_nside(component.amp) for component in components.values() if not isinstance(component, PointSourceComponent) ): raise NsideError( "all diffuse maps in the sky model needs to be at a common nside" )
[docs] @classmethod def from_chain( cls, chain: Union[str, Chain], nside: int, components: Optional[List[str]] = None, model: str = "BeyondPlanck", samples: Union[range, int, Literal["all"]] = -1, burn_in: Optional[int] = None, ) -> SkyModel: """Initializes the SkyModel from a Cosmoglobe chain. Parameters ---------- chain Path to a Cosmoglobe chainfile or a Chain object. nside Model HEALPIX map resolution parameter. components List of components to include in the model. model String representing which sky model to use. Defaults to BeyondPlanck. samples The sample number for which to extract the model. If the input is 'all', then the model will an average of all samples in the chain. Defaults to the last sample in the chain. burn_in Burn in sample for which all previous samples are disregarded. Returns ------- sky_model Initialized sky model. """ cosmoglobe_model = cosmoglobe_registry.get_model(model) initialized_components = get_components_from_chain( chain=chain, nside=nside, components=components, model=cosmoglobe_model, samples=samples, burn_in=burn_in, ) return cls( nside=nside, components=initialized_components, version=cosmoglobe_model.version, )
[docs] @classmethod def from_hub(cls) -> SkyModel: """Initializes the SkyModel from a public cosmoglobe data release.""" raise NotImplementedError
@property def nside(self) -> int: """HEALPIX map resolution of the maps in the Sky Model.""" return self._nside @nside.setter def nside(self, value: int) -> None: """Validating the nside.""" try: if not hp.isnsideok(value, nest=True): raise NsideError("nside needs to be a power of 2") except (TypeError, ValueError): raise TypeError("nside must be an integer") self._nside = int(value)
[docs] def __call__( self, freqs: Quantity, weights: Optional[Quantity] = None, *, components: Optional[List[str]] = None, fwhm: Quantity = DEFAULT_BEAM_FWHM, output_unit: str = DEFAULT_OUTPUT_UNIT_STR, ) -> Quantity: r"""Simulates and returns the full sky model emission. Parameters ---------- freqs A frequency, or a list of frequencies corresponding to the bandpass weights. weights Bandpass weights corresponding to the frequencies in `freqs`. The units of this quantity must match the units of the related detector, i.e, they must be compatible with K_RJ, K_CMB, or Jy/sr. Even if the bandpass weights are already normalized and now have units of 1/Hz, redefine the quantities unites to be that of the detector. If `bandpass` is None and `freqs` is a single frequency, a delta peak is assumed. Defaults to None. components List of component labels. If None, all components in the sky model is included. Defaults to None. fwhm The full width half max parameter of the Gaussian. Defaults to 0.0, which indicates no smoothing of output maps. output_unit The output units of the emission. For instance 'uK_RJ', 'uK_CMB', or 'MJy/sr'. If Kelvin, Rayleigh-Jeans or CMB must be specified. Returns ------- emission The simulated emission given the Cosmoglobe Sky Model. """ if components is None: included_components = list(self.components.values()) else: if not all(component in self.components for component in components): raise ValueError("all component must be present in the model") included_components = [ component for label, component in self.components.items() if label in components ] diffuse_components: List[Union[DiffuseComponent, LineComponent]] = [] pointsource_components: List[PointSourceComponent] = [] for component_class in included_components: if isinstance(component_class, (DiffuseComponent, LineComponent)): diffuse_components.append(component_class) elif isinstance(component_class, PointSourceComponent): pointsource_components.append(component_class) emission_unit = Unit(output_unit) emission = Quantity(np.zeros((3, hp.nside2npix(self.nside))), unit=emission_unit) for diffuse_component in diffuse_components: component_emission = diffuse_component.simulate_emission( freqs=freqs, weights=weights, nside=self.nside, fwhm=fwhm, output_unit=emission_unit, ) for IQU, diffuse_emission in enumerate(component_emission): emission[IQU] += diffuse_emission # We smooth all diffuse components together in a single smoothing operation. if fwhm != DEFAULT_BEAM_FWHM: emission = Quantity( hp.smoothing(emission, fwhm=fwhm.to("rad").value), unit=emission.unit ) # Pointsource emissions are already smoothed during the stage where # each source is mapped to the HEALPIX map. for pointsource_component in pointsource_components: component_emission = pointsource_component.simulate_emission( freqs=freqs, weights=weights, nside=self.nside, fwhm=fwhm, output_unit=emission_unit, ) for IQU, pointsource_emission in enumerate(component_emission): emission[IQU] += pointsource_emission return emission
[docs] def remove_dipole( self, gal_cut: Quantity = DEFAULT_GAL_CUT, return_dipole: bool = False, ) -> Optional[Quantity]: """Removes the Solar dipole (and monopole), from the CMB sky component. Parameters ---------- gal_cut Ignores pixels in range [-gal_cut;+gal_cut] when fitting and subtracting the dipole. return_dipole If True, returns the dipole map. Returns _______ dipole The dipole of the CMB sky component. """ if not any( component.label == SkyComponentLabel.CMB for component in self.components.values() ): raise ComponentNotFoundError("CMB component not found in sky model.") try: gal_cut = gal_cut.to("deg") except UnitsError: raise UnitsError( "gal_cut must be an astropy quantity compatible with 'deg'" ) if return_dipole: dipole_subtracted_amp = Quantity( hp.remove_dipole( self["cmb"].amp[0], gal_cut=gal_cut.to("deg").value, ), unit=self["cmb"].amp.unit, ) dipole = self["cmb"].amp[0] - dipole_subtracted_amp self["cmb"].amp[0] = dipole_subtracted_amp return dipole hp.remove_dipole( self["cmb"].amp[0], gal_cut=gal_cut.to("deg").value, copy=False )
def __iter__(self) -> Iterator: """Returns an iterator over the model components.""" return iter(list(self.components.values())) def __getitem__(self, key: str) -> SkyComponent: """Returns a SkyComponent class.""" try: return self.components[key] except KeyError: raise ComponentNotFoundError(f"component {key} not found in sky model.") def __repr__(self) -> str: """Representation of the SkyModel and all enabled components.""" reprs = [] for label, component in self.components.items(): component_repr = repr(component) + "\n" reprs.append(f"({label}): {component_repr}") main_repr = "SkyModel(" if self.version is not None: main_repr += f"\n version: {self.version}" main_repr += f"\n nside: {self._nside}" main_repr += "\n components( " main_repr += "\n " + " ".join(reprs) main_repr += " )" main_repr += "\n)" return main_repr