Source code for cosmoglobe.h5.chain

from __future__ import annotations

import textwrap
from pathlib import Path
from typing import Any, Generator, Optional, Sequence

import h5py
import healpy as hp
import numpy as np

from cosmoglobe.h5 import PARAMETER_GROUP_NAME, ChainVersion
from cosmoglobe.h5._alms import unpack_alms as unpack_alms_
from cosmoglobe.h5._alms import unpack_alms_from_chain
from cosmoglobe.h5._decorators import unpack_alms, validate_key, validate_samples
from cosmoglobe.h5._exceptions import ChainFormatError, ChainKeyError, ChainSampleError
from cosmoglobe.sky.components._labels import SkyComponentLabel


class Chain:
    """An interface for Cosmoglobe chainfiles.

    This class aims to provide a convenient interface for working with
    Cosmoglobe chain files.
    """

    def __init__(self, path: str | Path, burn_in: Optional[int] = None) -> None:
        """Validate and initialize the Chain object.

        Parameters
        ----------
        path
            Path to the chainfile.
        burn_in
            Burn in sample. The first `burn_in` samples (indices
            ``0 .. burn_in - 1``) are discarded; the sample at index `burn_in`
            is the first one kept. Must satisfy ``0 <= burn_in < nsamples``.

        Raises
        ------
        FileNotFoundError
            If `path` does not point to a file.
        ChainFormatError
            If the file is not a HDF5 file or contains no samples.
        ChainSampleError
            If `burn_in` is out of range.
        """
        if not (path := Path(path)).is_file():
            raise FileNotFoundError(f"{path.name} was not found")

        # Open the file once; failing to open it means it is not valid HDF5.
        try:
            file = h5py.File(path, "r")
        except OSError as error:
            raise ChainFormatError(f"{path.name} is not a HDF5 file") from error

        with file:
            samples = list(file.keys())

            # New-style chains carry an extra, non-sample parameter group.
            if PARAMETER_GROUP_NAME in samples:
                samples.remove(PARAMETER_GROUP_NAME)
                version = ChainVersion.NEW
            else:
                version = ChainVersion.OLD

            # Checked after removing the parameter group so that a file holding
            # only parameters is rejected here instead of failing on samples[0].
            if not samples:
                raise ChainFormatError("chain has no samples")

            component_labels = {label.value for label in SkyComponentLabel}
            components = [
                group for group in file[samples[0]].keys() if group in component_labels
            ]

            parameters: dict[str, dict[str, Any]] = {}
            if version is ChainVersion.NEW:
                for component, group in file[PARAMETER_GROUP_NAME].items():
                    parameters[component] = {}
                    for key, value in group.items():
                        # np.bytes_ replaces the np.string_ alias removed in NumPy 2.0.
                        if np.issubdtype(value.dtype, np.bytes_):
                            value = value.asstr()
                        parameters[component][key] = value[()]

        if burn_in is not None:
            nsamples = len(samples)
            if not 0 <= burn_in < nsamples:
                raise ChainSampleError(f"{burn_in=} out of range with {nsamples=}")
            samples = samples[burn_in:]

        self._samples = samples
        self._components = components
        self._parameters = parameters
        self._path = path
        self._version = version

    @property
    def samples(self) -> list[str]:
        """List of all samples in the chain."""
        return self._samples

    @property
    def nsamples(self) -> int:
        """Number of samples in the chain."""
        return len(self.samples)

    @property
    def components(self) -> list[str]:
        """List of the sky components in the chain."""
        return self._components

    @property
    def parameters(self) -> dict[str, dict[str, Any]]:
        """Dictionary of the parameters in the parameter group in the chain."""
        return self._parameters

    @property
    def path(self) -> Path:
        """Path to the chainfile."""
        return self._path

    @property
    def version(self) -> ChainVersion:
        """Chain version."""
        return self._version

    @property
    def tree(self) -> None:
        """Prints group and dataset structure of the chainfile.

        Only the first retained sample and the parameter group are printed,
        each item indented by its depth in the file.
        """
        shown_roots = (self.samples[0], PARAMETER_GROUP_NAME)

        def print_item(name: str, _: Any) -> None:
            parts = name.split("/")
            if parts[0] not in shown_roots:
                return
            print(name.count("/") * " " + parts[-1])

        with h5py.File(self.path, "r") as file:
            file.visititems(print_item)

    @staticmethod
    def _accumulation_dtype(dtype: np.dtype) -> np.dtype:
        """Returns a dtype (at least float64) for summing samples of `dtype`.

        Summing in the raw dtype can overflow integers and loses precision for
        float32 data.
        """
        return np.result_type(dtype, np.float64)

    def _alms_to_map(self, alms: np.ndarray, key: str) -> np.ndarray:
        """Unpacks Commander-format alms for `key` and transforms them to a map.

        The map uses the component's ``nside`` and ``fwhm`` from the chain's
        parameter group. ``pol`` is enabled for amplitude (``amp*``) keys.
        """
        component, *_, quantity = key.split("/")
        params = self.parameters[component]
        # NOTE(review): healpy expects `fwhm` in radians; confirm the unit
        # stored in the chain parameter group.
        return hp.alm2map(
            unpack_alms_from_chain(self, alms, key),
            nside=params["nside"],
            fwhm=params["fwhm"],
            pol=quantity.startswith("amp"),
            pixwin=True,
        )

    @validate_key
    @validate_samples
    @unpack_alms
    def get(
        self,
        key: str,
        *,
        samples: Optional[range | int | Sequence[int]] = None,
        unpack: bool = True,
    ) -> Any:
        """Returns the value of a key for all samples.

        Parameters
        ----------
        key
            The path to an item that has been sampled in the chain, e.g
            'dust/amp_alm'.
        samples
            An int or a range of samples for which to return the value. If
            None, all samples in the chain are used.
        unpack
            If True, alms are unpacked from Commander format to healpy format.
            Default is True

        Returns
        -------
        values
            The value of the key for each sample. This is a stacked array for
            several samples, or the bare value for a single sample.
        """
        with h5py.File(self.path, "r") as file:
            values = [file[f"{sample}/{key}"][()] for sample in samples]

        return np.asarray(values) if len(values) != 1 else values[0]

    @validate_key
    @validate_samples
    @unpack_alms
    def mean(
        self,
        key: str,
        *,
        samples: Optional[range | int | Sequence[int]] = None,
        unpack: bool = True,
    ) -> Any:
        """Returns the mean of a key over all samples.

        Parameters
        ----------
        key
            The path to an item that has been sampled in the chain, e.g
            'dust/amp_alm'.
        samples
            An int or a range of samples to average over. If None, all
            samples in the chain are used.
        unpack
            If True, alms are unpacked from Commander format to healpy format.
            Default is True

        Returns
        -------
        value
            The averaged value of the key over all samples, cast back to the
            dtype of a single sample.
        """
        with h5py.File(self.path, "r") as file:
            first = file[f"{samples[0]}/{key}"][()]
            dtype = first.dtype.type
            # Sum in a wide dtype to avoid integer overflow / float32 round-off.
            total = np.array(first, dtype=self._accumulation_dtype(first.dtype))
            for sample in samples[1:]:
                total += file[f"{sample}/{key}"][()]

        return dtype(total / len(samples))  # Converting back to original dtype

    @validate_key
    @validate_samples
    def stddev(
        self,
        key: str,
        *,
        alm2map: bool = False,
        samples: Optional[range | int | Sequence[int]] = None,
    ) -> Any:
        """Returns the standard deviation of a key over all samples.

        The population standard deviation ``sqrt(sum((x - mean)**2) / N)`` is
        computed in two passes over the chain: first the mean, then the sum of
        squared deviations.

        Parameters
        ----------
        key
            The path to an item that has been sampled in the chain, e.g
            'dust/amp_alm'.
        alm2map
            If True, `key` must refer to alms. The mean and every sample are
            unpacked and transformed to a HEALPix map before the deviations
            are taken, so the result is a per-pixel standard deviation.
        samples
            An int or a range of samples to compute the deviation over. If
            None, all samples in the chain are used.

        Returns
        -------
        value
            The standard deviation of the key, cast to the dtype of a single
            (map-transformed, if `alm2map`) sample.

        Raises
        ------
        ValueError
            If `alm2map` is True but `key` does not refer to alms.
        """
        if alm2map and "alm" not in key:
            raise ValueError(f"alm2map=True requires an alm key, got {key!r}")

        with h5py.File(self.path, "r") as file:
            # First pass: the mean. It is kept in the wide dtype; truncating it
            # to an integer dtype would bias the deviations.
            first = file[f"{samples[0]}/{key}"][()]
            mu = np.array(first, dtype=self._accumulation_dtype(first.dtype))
            for sample in samples[1:]:
                mu += file[f"{sample}/{key}"][()]
            mu = mu / len(samples)
            if alm2map:
                # alm2map is linear, so the map of the mean alms equals the
                # mean of the maps, and this needs only one extra transform.
                mu = self._alms_to_map(mu, key)

            # Second pass: sum of squared deviations from the mean.
            out_dtype = None
            squared_sum = 0.0
            for sample in samples:
                x = file[f"{sample}/{key}"][()]
                if alm2map:
                    x = self._alms_to_map(x, key)
                if out_dtype is None:
                    out_dtype = x.dtype.type
                x = np.asarray(x, dtype=self._accumulation_dtype(x.dtype))
                squared_sum = squared_sum + (x - mu) ** 2

        return out_dtype(np.sqrt(squared_sum / len(samples)))

    @validate_key
    @validate_samples
    def load(
        self,
        key: str,
        *,
        samples: Optional[range | int | Sequence[int]] = None,
        unpack: bool = True,
    ) -> Generator:
        """Returns a generator to be used in a for loop.

        NOTE to devs: The unpack_alms decorator wont work on this function due
        to it not processing the returned data until it is iterated over, so
        unpacking is done inline. The file stays open until the generator is
        exhausted or closed.

        Parameters
        ----------
        key
            The path to an item that has been sampled in the chain, e.g
            'dust/amp_alm'.
        samples
            An int or a range of samples to iterate over. If None, all samples
            in the chain are used.
        unpack
            If True, alms are unpacked from Commander format to healpy format.
            Default is True

        Returns
        -------
        A generator that can be looped over to yield each sampled value.
        """
        with h5py.File(self.path, "r") as file:
            for sample in samples:
                value = file[f"{sample}/{key}"][()]
                if unpack and "alm" in key:
                    value = unpack_alms_from_chain(self, value, key)
                yield value

    @validate_key
    @unpack_alms
    def __getitem__(self, key: str, *, unpack: bool = True) -> Any:
        """Returns the value of a key from the chain.

        Parameters
        ----------
        key
            The *full* path to an item in the chain.
        unpack
            If True, alms are unpacked from Commander format to healpy format.
            Default is True

        Returns
        -------
        The member names for a group, the (string-decoded) value for a
        dataset, or None for any other HDF5 object.
        """
        with h5py.File(self.path, "r") as file:
            item = file[key]
            if isinstance(item, h5py.Group):
                return list(item.keys())
            elif isinstance(item, h5py.Dataset):
                # np.bytes_ replaces the np.string_ alias removed in NumPy 2.0.
                if np.issubdtype(item.dtype, np.bytes_):
                    return item.asstr()[()]
                return item[()]

    def _format_samples(self, samples: int | Sequence[int]) -> list[str] | str:
        """Converts sample indices to the zero-padded sample names of the chain.

        A single int gives a single name. Any other iterable of ints (list,
        tuple, range) gives a list of names.
        """
        width = len(self.samples[0])
        if isinstance(samples, (int, np.integer)):
            return f"{samples:0{width}d}"
        return [f"{sample:0{width}d}" for sample in samples]

    def __str__(self) -> str:
        """Representation of the chain."""
        COL_LEN = 40
        CHAIN_META = {
            "Num Samples": self.nsamples,
            "Num Components": len(self.components),
            "Size": f"{self.path.stat().st_size / (1024 * 1024 * 1024):.3f}"
            + " GB",
        }

        def center(string: str, fill=" ") -> str:
            white_space_len = (COL_LEN - len(string)) // 2
            white_space = fill * white_space_len
            output = f"{white_space}{string}{white_space}"
            if len(output) < COL_LEN:
                output += fill
            return output

        if ".astropy/cache" in str(self.path):
            name = "cached chainfile"
        else:
            name = self.path.name

        # Centre each wrapped line separately: textwrap.fill would embed
        # newlines and break centering. [""] keeps the blank line for an empty
        # component list.
        component_lines = textwrap.wrap(
            " ".join(self.components), width=COL_LEN
        ) or [""]

        main_repr = "\n"
        main_repr += "-" * COL_LEN + "\n"
        main_repr += center(name) + "\n"
        main_repr += "-" * COL_LEN + "\n"
        main_repr += "\n"
        for key, value in CHAIN_META.items():
            main_repr += f"{key:<{COL_LEN//2 - 1}}{'='}{value:>{COL_LEN//2}}\n"
        main_repr += "\n"
        main_repr += center(" Components ", fill="-") + "\n"
        main_repr += "\n".join(center(line) for line in component_lines) + "\n"
        main_repr += "\n"
        main_repr += "-" * COL_LEN + "\n"
        return main_repr

    @validate_samples
    def copy(
        self,
        samples: int | Sequence[int] | range = -1,
        new_name: Optional[str] = None,
    ) -> None:
        """Creates a copy of the chain with a single or multiple samples.

        The copied samples are renumbered from zero in the new file. The
        parameter group is copied too when the chain has one (old-format
        chains do not).

        Parameters
        ----------
        samples
            Samples to copy. Can be an int, a list of ints or a range.
        new_name
            Name of the new file. Defaults to "{self.path.stem}_copy.h5"
            (relative to the working directory). The file must not exist.
        """
        if new_name is None:
            new_name = self.path.stem + "_copy.h5"

        with h5py.File(new_name, "x") as new_file, h5py.File(self.path, "r") as file:
            for index, sample in enumerate(samples):
                file.copy(
                    source=file[sample],
                    dest=new_file,
                    name=self._format_samples(index),
                )
            if PARAMETER_GROUP_NAME in file:
                parameter_group = file[PARAMETER_GROUP_NAME]
                file.copy(
                    source=parameter_group,
                    dest=new_file,
                    name=parameter_group.name,
                )

    def combine(
        self,
        other_chain: Chain,
        group_list: Sequence[str],
        new_name: Optional[str] = None,
    ) -> None:
        """Creates a new chainfile that combines specific groups from two chains.

        The new file starts as a copy of the last sample of this chain. `copy`
        renumbers that sample to the first sample name. Every group in
        `group_list` is then replaced by the same group from the first sample
        of `other_chain`. Matching entries in the parameter group are also
        replaced when `other_chain` has them.

        Parameters
        ----------
        other_chain
            Chain providing the groups that overwrite those of this chain.
        group_list
            Groups (relative to a sample, e.g. 'dust') to take from
            `other_chain`.
        new_name
            Name of the new file. Defaults to "{self.path.stem}_combined.h5".

        Raises
        ------
        ChainKeyError
            If a group in `group_list` does not exist in `other_chain`.
        """
        other_sample = other_chain.samples[0]
        # Validate up front, without reading any dataset contents.
        with h5py.File(other_chain.path, "r") as other_file:
            for group in group_list:
                if f"{other_sample}/{group}" not in other_file:
                    raise ChainKeyError(
                        f"group {group} does not exist in `other chain`."
                    )

        if new_name is None:
            new_name = self.path.stem + "_combined.h5"
        self.copy(samples=-1, new_name=new_name)

        # `copy` renumbers the copied sample, so the destination is the first
        # formatted sample name, not the other chain's sample name.
        new_sample = self._format_samples(0)

        with h5py.File(new_name, "r+") as new_file, h5py.File(
            other_chain.path, "r"
        ) as other_file:
            for group in group_list:
                target = f"{new_sample}/{group}"
                if target in new_file:
                    del new_file[target]
                other_file.copy(
                    source=other_file[f"{other_sample}/{group}"],
                    dest=new_file,
                    name=target,
                )

                parameter_key = f"{PARAMETER_GROUP_NAME}/{group}"
                if parameter_key in other_file:
                    # Only delete what actually exists in the new file.
                    if parameter_key in new_file:
                        del new_file[parameter_key]
                    other_file.copy(
                        source=other_file[parameter_key],
                        dest=new_file,
                        name=parameter_key,
                    )
def copy_chain(
    chain: str | Path | Chain,
    samples: int | Sequence[int] | range = -1,
    new_name: Optional[str] = None,
) -> None:
    """Creates a copy of the chain with a single or multiple samples.

    Parameters
    ----------
    chain
        Path to the chain file or the `Chain` object to copy.
    samples
        Samples to copy. Can be an int, a list of ints or a python range
        object.
    new_name
        Name of the chain copy. If None, it defaults to
        "{chain.path.stem}_copy.h5" in the current working directory.
    """
    if not isinstance(chain, Chain):
        chain = Chain(chain)

    chain.copy(samples=samples, new_name=new_name)


def combine_chains(
    chain: str | Path | Chain,
    other_chain: str | Path | Chain,
    group_list: Sequence[str],
    new_name: Optional[str] = None,
) -> None:
    """Creates a new chainfile that combines specific groups from two chains.

    The new file will contain all content from `chain`, except for the
    content within the groups in the `group_list`, which are taken from
    `other_chain` instead.

    Parameters
    ----------
    chain
        Path to a chain file or a `Chain` object. This is the base chain whose
        content you want to partially overwrite in the new combined file.
    other_chain
        Path to a chain file or a `Chain` object. This chain contains the
        groups that overwrite those in `chain`.
    group_list
        List of hdf5 groups that will be overwritten in the new chainfile.
    new_name
        Name of the combined file. If None, it defaults to
        "{chain.path.stem}_combined.h5" in the current working directory.
    """
    if not isinstance(chain, Chain):
        chain = Chain(chain)
    if not isinstance(other_chain, Chain):
        other_chain = Chain(other_chain)

    chain.combine(other_chain, group_list=group_list, new_name=new_name)