API Reference

Cosmology

jaxace.W0WaCDMCosmology dataclass

W0WaCDMCosmology(ln10As: float, ns: float, h: float, omega_b: float, omega_c: float, m_nu: float = 0.0, w0: float = -1.0, wa: float = 0.0)

D_f_z

D_f_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Linear growth factor and growth rate (D(z), f(z)).

D_z

D_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Linear growth factor D(z).

E_a

E_a(a: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Dimensionless Hubble parameter E(a) = H(a)/H0.

E_z

E_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Dimensionless Hubble parameter E(z) = H(z)/H0.

dA_z

dA_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Angular diameter distance in Mpc.

dL_z

dL_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Luminosity distance at redshift z in Mpc.

f_z

f_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Growth rate f(z) = d log D / d log a.

r_z

r_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Comoving distance in Mpc.

r̃_z

r̃_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Dimensionless comoving distance r̃(z).

Ωm_a

Ωm_a(a: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Matter density parameter Ωₘ(a) at scale factor a.

Ωtot_z

Ωtot_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Total density parameter at redshift z (always 1.0 for a flat universe).

ρc_z

ρc_z(z: Union[float, ndarray]) -> Union[float, jnp.ndarray]

Critical density at redshift z in M☉/Mpc³.

Background Functions

Hubble Functions

jaxace.E_z

E_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

Dimensionless Hubble parameter E(z) = H(z)/H0.

This is equivalent to E(a) with the transformation a = 1/(1+z).

Returns:

| Type | Description |
| --- | --- |
| Union[float, ndarray] | Hubble parameter E(z). NaN and Inf inputs propagate to the output. |

jaxace.E_a

E_a(a: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

Dimensionless Hubble parameter E(a) = H(a)/H0.

The normalized Hubble parameter is given by:

\[E(a) = \sqrt{\Omega_{\gamma,0} a^{-4} + \Omega_{\mathrm{cb},0} a^{-3} + \Omega_{\Lambda,0} \rho_{\mathrm{DE}}(a) + \Omega_{\nu}(a)}\]

where:

  • \(\Omega_{\gamma,0}\) is the photon density parameter today
  • \(\Omega_{\mathrm{cb},0}\) is the cold dark matter + baryon density parameter today
  • \(\Omega_{\Lambda,0}\) is the dark energy density parameter today (from flatness constraint)
  • \(\rho_{\mathrm{DE}}(a)\) is the normalized dark energy density
  • \(\Omega_{\nu}(a)\) is the massive neutrino contribution

Returns:

| Type | Description |
| --- | --- |
| Union[float, ndarray] | Hubble parameter E(a). NaN and Inf inputs propagate to the output, and invalid parameter combinations return NaN. |
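The expression for E(a) above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not jaxace's implementation: it assumes the CPL form w(a) = w0 + wa (1 - a) for the dark energy equation of state (the reference only names `w0` and `wa`), drops the massive neutrino term, and sets the photon density to zero by default. The names `rho_de`, `E_a_sketch`, `E_z_sketch`, and `Omega_gamma0` are hypothetical.

```python
import math

def rho_de(a, w0=-1.0, wa=0.0):
    """Normalized dark energy density, assuming CPL: w(a) = w0 + wa * (1 - a)."""
    return a ** (-3.0 * (1.0 + w0 + wa)) * math.exp(-3.0 * wa * (1.0 - a))

def E_a_sketch(a, Omega_cb0, w0=-1.0, wa=0.0, Omega_gamma0=0.0):
    """E(a) without the neutrino term; Omega_Lambda0 follows from flatness."""
    Omega_L0 = 1.0 - Omega_gamma0 - Omega_cb0
    return math.sqrt(
        Omega_gamma0 * a ** -4
        + Omega_cb0 * a ** -3
        + Omega_L0 * rho_de(a, w0, wa)
    )

def E_z_sketch(z, Omega_cb0, w0=-1.0, wa=0.0):
    """E(z) through the transformation a = 1 / (1 + z)."""
    return E_a_sketch(1.0 / (1.0 + z), Omega_cb0, w0, wa)

E_today = E_a_sketch(1.0, 0.3)   # normalization: E(a=1) = 1
E_at_z1 = E_z_sketch(1.0, 0.3)   # flat LCDM: sqrt(0.3 * 8 + 0.7)
```

With w0 = -1 and wa = 0 the dark energy density is constant, so the sketch reduces to flat ΛCDM.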

jaxace.dlogEdloga

dlogEdloga(a: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

Logarithmic derivative of the Hubble parameter.

\[\frac{\mathrm{d} \ln E}{\mathrm{d} \ln a} = \frac{a}{E} \frac{\mathrm{d}E}{\mathrm{d}a}\]

This quantity appears in the growth factor differential equation.

Returns:

| Type | Description |
| --- | --- |
| Union[float, ndarray] | Logarithmic derivative d(ln E)/d(ln a). |
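As a rough cross-check of the definition above, the logarithmic derivative can be approximated with a central finite difference in ln a. This sketch is an assumption-laden illustration (jaxace may well compute the derivative analytically or by automatic differentiation); `dlogE_dloga_sketch`, `E_matter_only`, and `E_flat_lcdm` are hypothetical helpers, and the two test cosmologies ignore radiation and neutrinos.

```python
import math

def dlogE_dloga_sketch(E, a, eps=1e-4):
    """Central finite difference of ln E in ln a, for any callable E(a)."""
    a_lo, a_hi = a * math.exp(-eps), a * math.exp(eps)
    return (math.log(E(a_hi)) - math.log(E(a_lo))) / (2.0 * eps)

def E_matter_only(a):
    """Matter-only universe: E(a) = a^(-3/2)."""
    return a ** -1.5

def E_flat_lcdm(a, Omega_cb0=0.3):
    """Flat LCDM without radiation or neutrinos."""
    return math.sqrt(Omega_cb0 * a ** -3 + (1.0 - Omega_cb0))

eds_slope = dlogE_dloga_sketch(E_matter_only, 0.5)   # analytic value: -3/2
lcdm_slope = dlogE_dloga_sketch(E_flat_lcdm, 0.5)    # analytic value: -1.5 * Omega_m(a)
```

For flat ΛCDM without radiation, the analytic result is -1.5 Ωₘ(a), which gives a convenient check on the numerical value.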

Matter Density

jaxace.Ωm_a

Ωm_a(a: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

Matter density parameter Ωₘ(a) at scale factor a.

\[\Omega_{\mathrm{m}}(a) = \frac{\Omega_{\mathrm{cb},0} a^{-3}}{E(a)^2}\]

where E(a) is the normalized Hubble parameter.

Returns:

| Type | Description |
| --- | --- |
| Union[float, ndarray] | Matter density parameter Ωₘ(a). |
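The ratio above is straightforward to evaluate once E(a) is known. The sketch below assumes flat ΛCDM with no radiation or neutrinos so that E(a)² has a closed form; `E2_flat_lcdm` and `Omega_m_a_sketch` are hypothetical names, not part of jaxace.

```python
def E2_flat_lcdm(a, Omega_cb0):
    """E(a)^2 for flat LCDM, ignoring radiation and neutrinos (assumption)."""
    return Omega_cb0 * a ** -3 + (1.0 - Omega_cb0)

def Omega_m_a_sketch(a, Omega_cb0):
    """Omega_m(a) = Omega_cb0 * a^-3 / E(a)^2."""
    return Omega_cb0 * a ** -3 / E2_flat_lcdm(a, Omega_cb0)

Om_today = Omega_m_a_sketch(1.0, 0.3)    # equals Omega_cb0 at a = 1
Om_half = Omega_m_a_sketch(0.5, 0.3)     # 2.4 / 3.1
Om_early = Omega_m_a_sketch(1e-3, 0.3)   # approaches 1 deep in matter domination
```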

Growth Functions

jaxace.D_z

D_z(z, Ωcb0, h, =0.0, w0=-1.0, wa=0.0)

Linear growth factor D(z).

The growth factor is normalized such that D(z=0) = 1. It satisfies the differential equation given in growth_solver.

Returns:

Type Description

Linear growth factor D(z). Returns NaN for NaN inputs and handles invalid parameters gracefully.

jaxace.f_z

f_z(z, Ωcb0, h, =0.0, w0=-1.0, wa=0.0)

Growth rate f(z) = d log D / d log a.

The growth rate is defined as:

\[f(z) = \frac{\mathrm{d} \ln D}{\mathrm{d} \ln a}\]

where D is the linear growth factor.

Returns:

Type Description

Growth rate f(z). Returns NaN for NaN inputs and handles invalid parameters gracefully.

jaxace.D_f_z

D_f_z(z, Ωcb0, h, =0.0, w0=-1.0, wa=0.0)
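To illustrate how D(z) and f(z) relate, the sketch below integrates the standard textbook linear growth equation in x = ln a, D'' + (2 + d ln E / d ln a) D' - 1.5 Ωₘ(a) D = 0, with a fixed-step RK4 scheme. The equation used by `growth_solver` is not shown in this reference, so treat this as an assumption rather than jaxace's method. It also assumes flat ΛCDM without radiation or neutrinos and matter-dominated initial conditions (D = a); `growth_D_f_sketch` is a hypothetical name.

```python
import math

def growth_D_f_sketch(z, Omega_m0, a_ini=1e-3, n_steps=2000):
    """Return (D(z) / D(0), f(z)) for flat LCDM; valid for a = 1/(1+z) > a_ini."""

    def rhs(x, D, dD):
        a = math.exp(x)
        E2 = Omega_m0 * a ** -3 + (1.0 - Omega_m0)
        Om_a = Omega_m0 * a ** -3 / E2
        dlnE_dlna = -1.5 * Om_a  # analytic for flat LCDM without radiation
        return dD, 1.5 * Om_a * D - (2.0 + dlnE_dlna) * dD

    def integrate(x_end):
        # Matter-dominated start: D = a, dD/dln(a) = a.
        x, D, dD = math.log(a_ini), a_ini, a_ini
        h = (x_end - x) / n_steps
        for _ in range(n_steps):
            k1 = rhs(x, D, dD)
            k2 = rhs(x + h / 2, D + h / 2 * k1[0], dD + h / 2 * k1[1])
            k3 = rhs(x + h / 2, D + h / 2 * k2[0], dD + h / 2 * k2[1])
            k4 = rhs(x + h, D + h * k3[0], dD + h * k3[1])
            D += h / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
            dD += h / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
            x += h
        return D, dD

    D_z, dD_z = integrate(math.log(1.0 / (1.0 + z)))
    D_0, _ = integrate(0.0)
    # f = d ln D / d ln a is independent of the normalization of D.
    return D_z / D_0, dD_z / D_z

D_eds, f_eds = growth_D_f_sketch(1.0, 1.0)    # matter-only: D = a, f = 1
D_lcdm, f_lcdm = growth_D_f_sketch(0.0, 0.3)  # D(0) = 1 by normalization
```

In the matter-only case the exact solution is D = a, which makes a simple sanity check; with dark energy, growth slows and f drops below 1 at late times.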

Distance Functions

jaxace.r_z

r_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

jaxace.dA_z

dA_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

jaxace.dL_z

dL_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

Luminosity distance at redshift z.

The luminosity distance is related to the comoving distance by:

\[d_L(z) = (1 + z)\, r(z)\]

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z | Union[float, ndarray] | Redshift | required |
| Ωcb0 | Union[float, ndarray] | Present-day matter density parameter (CDM + baryons) | required |
| h | Union[float, ndarray] | Dimensionless Hubble parameter (H0 = 100h km/s/Mpc) | required |
|  | Union[float, ndarray] | Sum of neutrino masses in eV | 0.0 |
| w0 | Union[float, ndarray] | Dark energy equation of state parameter | -1.0 |
| wa | Union[float, ndarray] | Dark energy equation of state evolution parameter | 0.0 |

Returns:

| Type | Description |
| --- | --- |
| Union[float, ndarray] | Luminosity distance in Mpc |
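The relation between the luminosity and comoving distances can be sketched numerically. The example below is an illustration only: it assumes flat ΛCDM without radiation or neutrinos, computes r(z) as (c / H0) times the integral of 1 / E(z') from 0 to z with a composite trapezoid rule, and uses the standard flat-space relation d_A = r / (1 + z) for the angular diameter distance. All function names are hypothetical and jaxace's integration scheme may differ.

```python
import math

C_KM_S = 299792.458  # speed of light in km/s

def r_z_sketch(z, Omega_cb0, h, n=2000):
    """Comoving distance in Mpc via a composite trapezoid rule."""
    def inv_E(zp):
        return 1.0 / math.sqrt(Omega_cb0 * (1.0 + zp) ** 3 + (1.0 - Omega_cb0))

    dz = z / n
    total = 0.5 * (inv_E(0.0) + inv_E(z))
    total += sum(inv_E(i * dz) for i in range(1, n))
    return C_KM_S / (100.0 * h) * total * dz

def dL_z_sketch(z, Omega_cb0, h):
    """Luminosity distance: dL = (1 + z) * r."""
    return (1.0 + z) * r_z_sketch(z, Omega_cb0, h)

def dA_z_sketch(z, Omega_cb0, h):
    """Angular diameter distance in a flat universe: dA = r / (1 + z)."""
    return r_z_sketch(z, Omega_cb0, h) / (1.0 + z)

r_eds = r_z_sketch(1.0, 1.0, 0.7)  # matter-only case has a closed form
r_eds_exact = 2.0 * C_KM_S / 70.0 * (1.0 - 1.0 / math.sqrt(2.0))
```

Whatever the cosmology, the two distances in this sketch satisfy dL / dA = (1 + z)².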

Density Functions

jaxace.ρc_z

ρc_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]
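For orientation, the critical density follows from ρc(z) = 3 H(z)² / (8πG). The sketch below converts it to M☉/Mpc³ using approximate physical constants; the exact constants used by jaxace are not given in this reference, and the flat ΛCDM form of E(z)² (no radiation or neutrinos) and the name `rho_crit_sketch` are assumptions.

```python
import math

G_SI = 6.67430e-11               # gravitational constant, m^3 kg^-1 s^-2
MPC_IN_M = 3.0856775814913673e22 # one megaparsec in metres
MSUN_IN_KG = 1.98847e30          # approximate solar mass in kg

def rho_crit_sketch(z, Omega_cb0, h):
    """Critical density in Msun / Mpc^3 for flat LCDM."""
    H0_si = 100.0 * h * 1.0e3 / MPC_IN_M  # H0 in s^-1
    E2 = Omega_cb0 * (1.0 + z) ** 3 + (1.0 - Omega_cb0)
    rho_si = 3.0 * H0_si ** 2 * E2 / (8.0 * math.pi * G_SI)  # kg / m^3
    return rho_si * MPC_IN_M ** 3 / MSUN_IN_KG

rho_today = rho_crit_sketch(0.0, 0.3, 1.0)  # roughly 2.775e11 h^2 Msun / Mpc^3
```

Because ρc scales with H², the value at redshift z is the present-day value multiplied by E(z)².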

jaxace.Ωtot_z

Ωtot_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], : Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]

Utility Functions

jaxace.a_z

a_z(z)

Neural Network Emulators

jaxace.init_emulator

init_emulator(nn_dict: Dict[str, Any], weight: ndarray, emulator_type: Type[FlaxEmulator] = FlaxEmulator, validate: bool = True, validate_weights: Optional[bool] = None) -> FlaxEmulator

Initialize an emulator from neural network dictionary and weights.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| nn_dict | Dict[str, Any] | Neural network specification dictionary | required |
| weight | ndarray | Flattened weight array | required |
| emulator_type | Type[FlaxEmulator] | Type of emulator (currently only FlaxEmulator) | FlaxEmulator |
| validate | bool | Whether to validate nn_dict structure | True |
| validate_weights | Optional[bool] | Whether to validate weight dimensions | None |

Returns:

| Type | Description |
| --- | --- |
| FlaxEmulator | Initialized FlaxEmulator instance |

jaxace.FlaxEmulator dataclass

FlaxEmulator(model: Module, parameters: Dict[str, Any], states: Optional[Dict[str, Any]] = None, description: Dict[str, Any] = None)

Bases: AbstractTrainedEmulator

Flax-based emulator with automatic JIT compilation.

Key features:

  1. Automatic JIT compilation on first use
  2. Automatic batch detection and vmap application
  3. Cached compiled functions for performance

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| model | Module | Flax model (nn.Module) |
| parameters | Dict[str, Any] | Model parameters dictionary |
| states | Optional[Dict[str, Any]] | Model states (usually empty for standard feedforward networks) |
| description | Dict[str, Any] | Emulator description dictionary |

run_emulator

run_emulator(input_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray

Run the emulator with automatic JIT compilation and batch detection.

This method automatically:

  1. Converts numpy arrays to JAX arrays
  2. Detects if input is a batch or single sample
  3. Applies JIT compilation
  4. Uses vmap for batch processing

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_data | Union[np.ndarray, jnp.ndarray] | Input array (single sample or batch). Shape: (n_features,) for a single sample or (n_samples, n_features) for a batch | required |

Returns:

| Type | Description |
| --- | --- |
| jnp.ndarray | Output array from the neural network |
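The batch detection described above amounts to dispatching on the number of input dimensions. The sketch below shows only that dispatch logic, using NumPy and an explicit loop in place of JAX; in JAX the batch branch would typically be expressed with `jax.vmap` and wrapped in `jax.jit`. Both `toy_forward` and `run_sketch` are hypothetical stand-ins, not jaxace code.

```python
import numpy as np

def toy_forward(x):
    """Stand-in for a trained network: one dense layer with fixed weights."""
    W = np.array([[1.0, 0.0, -1.0],
                  [0.5, 0.5, 0.5]])
    return np.tanh(W @ x)

def run_sketch(forward, input_data):
    """Apply `forward` to a single sample or to every row of a batch."""
    x = np.asarray(input_data, dtype=float)
    if x.ndim == 1:
        # Single sample of shape (n_features,)
        return forward(x)
    if x.ndim == 2:
        # Batch of shape (n_samples, n_features): map over the leading axis
        return np.stack([forward(row) for row in x])
    raise ValueError(f"expected a 1D or 2D input, got {x.ndim}D")

single = run_sketch(toy_forward, [1.0, 2.0, 3.0])   # shape (2,)
batch = run_sketch(toy_forward, np.ones((4, 3)))    # shape (4, 2)
```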

__call__

__call__(input_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray

Allow the emulator to be called directly as a function.

Utilities

jaxace.maximin

maximin(input_data: Union[np.ndarray, jnp.ndarray], minmax: Union[np.ndarray, jnp.ndarray]) -> Union[np.ndarray, jnp.ndarray]

Normalize input data using min-max scaling. Matches Julia's maximin function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_data | Union[np.ndarray, jnp.ndarray] | Input array to normalize (shape: (n_features,) or (n_features, n_samples)) | required |
| minmax | Union[np.ndarray, jnp.ndarray] | Array of shape (n_features, 2) where column 0 is min, column 1 is max | required |

Returns:

| Type | Description |
| --- | --- |
| Union[np.ndarray, jnp.ndarray] | Normalized array in range [0, 1] |
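Min-max scaling with the shapes listed above can be sketched in NumPy as follows. This is an illustrative reimplementation under the assumption that each feature is mapped as (x - min) / (max - min); `maximin_sketch` is a hypothetical name and the actual function may handle edge cases differently.

```python
import numpy as np

def maximin_sketch(input_data, minmax):
    """Scale each feature to [0, 1] using per-feature (min, max) pairs."""
    x = np.asarray(input_data, dtype=float)
    mm = np.asarray(minmax, dtype=float)
    lo = mm[:, 0]
    span = mm[:, 1] - mm[:, 0]
    if x.ndim == 2:
        # (n_features, n_samples): broadcast the per-feature values over samples
        lo, span = lo[:, None], span[:, None]
    return (x - lo) / span

minmax = np.array([[0.0, 10.0],
                   [-1.0, 1.0]])
single = maximin_sketch([5.0, 0.0], minmax)          # one sample, two features
batch = maximin_sketch([[0.0, 10.0],
                        [-1.0, 1.0]], minmax)        # two features, two samples
```

Note that features run along the first axis here, unlike the (n_samples, n_features) layout used by `run_emulator`.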

jaxace.inv_maximin

inv_maximin(output_data: Union[np.ndarray, jnp.ndarray], minmax: Union[np.ndarray, jnp.ndarray]) -> Union[np.ndarray, jnp.ndarray]

Denormalize output data from min-max scaling. Matches Julia's inv_maximin function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| output_data | Union[np.ndarray, jnp.ndarray] | Normalized array (shape: (n_features,) or (n_features, n_samples)) | required |
| minmax | Union[np.ndarray, jnp.ndarray] | Array of shape (n_features, 2) where column 0 is min, column 1 is max | required |

Returns:

| Type | Description |
| --- | --- |
| Union[np.ndarray, jnp.ndarray] | Denormalized array |
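The inverse transformation can be sketched the same way, assuming it undoes the mapping (x - min) / (max - min) feature by feature. `inv_maximin_sketch` is a hypothetical name used only for this illustration.

```python
import numpy as np

def inv_maximin_sketch(output_data, minmax):
    """Map values in [0, 1] back to each feature's original (min, max) range."""
    y = np.asarray(output_data, dtype=float)
    mm = np.asarray(minmax, dtype=float)
    lo = mm[:, 0]
    span = mm[:, 1] - mm[:, 0]
    if y.ndim == 2:
        # (n_features, n_samples): broadcast the per-feature values over samples
        lo, span = lo[:, None], span[:, None]
    return y * span + lo

minmax = np.array([[0.0, 10.0],
                   [-1.0, 1.0]])
original = np.array([[2.5, 7.5],
                     [-0.5, 0.5]])
# Forward scaling written out by hand, then inverted by the sketch.
normalized = (original - minmax[:, :1]) / (minmax[:, 1:] - minmax[:, :1])
restored = inv_maximin_sketch(normalized, minmax)
```

A round trip through the forward and inverse scalings should reproduce the original array up to floating-point rounding.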