API Reference
Cosmology
jaxace.W0WaCDMCosmology (dataclass)
W0WaCDMCosmology(ln10As: float, ns: float, h: float, omega_b: float, omega_c: float, m_nu: float = 0.0, w0: float = -1.0, wa: float = 0.0)
D_f_z
Linear growth factor and growth rate (D(z), f(z)).
E_a
Dimensionless Hubble parameter E(a) = H(a)/H0.
E_z
Dimensionless Hubble parameter E(z) = H(z)/H0.
dA_z
Angular diameter distance in Mpc.
dL_z
Luminosity distance at redshift z in Mpc.
f_z
Growth rate f(z) = d log D / d log a.
r̃_z
Dimensionless comoving distance r̃(z).
Ωm_a
Matter density parameter Ωₘ(a) at scale factor a.
Ωtot_z
Total density parameter at redshift z (always 1.0 for a flat universe).
ρc_z
Critical density at redshift z in M☉/Mpc³.
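The dataclass takes physical parameters, while the standalone functions below take \(\Omega_{\mathrm{cb},0}\) and h. The sketch shows the conversion under two assumptions not stated on this page: that omega_b and omega_c are physical densities \(\omega = \Omega h^2\), and that ln10As means \(\ln(10^{10} A_s)\). The helper to_background_params is hypothetical and not part of jaxace.

```python
import math

def to_background_params(ln10As, h, omega_b, omega_c):
    """Hypothetical helper (not part of jaxace).

    Assumes omega_b and omega_c are physical densities (Omega * h**2)
    and ln10As = ln(10**10 * A_s).
    """
    As = math.exp(ln10As) * 1e-10            # primordial scalar amplitude A_s
    Omega_cb0 = (omega_b + omega_c) / h**2   # CDM + baryon density parameter today
    return As, Omega_cb0

# Planck-like values, chosen only for illustration
As, Omega_cb0 = to_background_params(ln10As=3.044, h=0.67, omega_b=0.0224, omega_c=0.12)
print(f"A_s = {As:.4e}, Omega_cb0 = {Omega_cb0:.4f}")
```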
Background Functions
Hubble Functions
jaxace.E_z
jaxace.E_a
E_a(a: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], mν: Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]
Dimensionless Hubble parameter E(a) = H(a)/H0.
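The normalized Hubble parameter is given by:
\[
E(a) = \sqrt{\Omega_{\gamma,0}\,a^{-4} + \Omega_{\mathrm{cb},0}\,a^{-3} + \Omega_{\Lambda,0}\,\rho_{\mathrm{DE}}(a) + \Omega_{\nu}(a)}
\]
where:
- \(\Omega_{\gamma,0}\) is the photon density parameter today
- \(\Omega_{\mathrm{cb},0}\) is the cold dark matter + baryon density parameter today
- \(\Omega_{\Lambda,0}\) is the dark energy density parameter today, fixed by the flatness constraint so that \(E(1) = 1\)
- \(\rho_{\mathrm{DE}}(a)\) is the dark energy density normalized to \(\rho_{\mathrm{DE}}(1) = 1\)
- \(\Omega_{\nu}(a)\) is the massive neutrino contribution to \(E^2(a)\)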
Returns:
| Type | Description |
|---|---|
| Union[float, ndarray] | Hubble parameter E(a). NaN/Inf inputs propagate to the output, and invalid parameter combinations return NaN. |
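The formula above can be sketched in plain Python. This is a simplified stand-in, not jaxace's implementation: it assumes \(\Omega_{\gamma,0} = 2.469\times10^{-5}/h^2\) (a common value for a CMB temperature near 2.725 K), takes the standard CPL form \(\rho_{\mathrm{DE}}(a) = a^{-3(1+w_0+w_a)}\,e^{3w_a(a-1)}\) as the normalized dark energy density, and drops the neutrino term entirely, so its output can differ slightly from E_a.

```python
import math

def E_a_sketch(a, Omega_cb0, h, w0=-1.0, wa=0.0):
    """E(a) = H(a)/H0 for a flat w0waCDM background WITHOUT neutrinos.

    Assumptions (not taken from jaxace): Omega_gamma0 = 2.469e-5 / h**2 and
    the CPL form rho_DE(a) = a**(-3(1+w0+wa)) * exp(3 wa (a - 1)).
    """
    Omega_gamma0 = 2.469e-5 / h**2
    # flatness fixes the dark energy density so that E(1) = 1
    Omega_L0 = 1.0 - Omega_gamma0 - Omega_cb0
    rho_de = a ** (-3.0 * (1.0 + w0 + wa)) * math.exp(3.0 * wa * (a - 1.0))
    return math.sqrt(Omega_gamma0 * a**-4 + Omega_cb0 * a**-3 + Omega_L0 * rho_de)

def E_z_sketch(z, Omega_cb0, h, w0=-1.0, wa=0.0):
    # E(z) is E(a) evaluated at a = 1 / (1 + z)
    return E_a_sketch(1.0 / (1.0 + z), Omega_cb0, h, w0, wa)

print(E_a_sketch(1.0, 0.3, 0.67), E_z_sketch(1.0, 0.3, 0.67, w0=-0.9, wa=0.1))
```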
jaxace.dlogEdloga
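dlogEdloga has no docstring here; its name suggests it returns \(d\ln E/d\ln a\), which is an assumption. For a simplified background of photons, CDM + baryons and CPL dark energy with no neutrinos (also assumptions), the derivative is analytic: \(\frac{d\ln E}{d\ln a} = \frac{1}{2E^2}\left[-4\Omega_{\gamma,0}a^{-4} - 3\Omega_{\mathrm{cb},0}a^{-3} + \Omega_{\Lambda,0}\,\rho_{\mathrm{DE}}(a)\,\big(3w_a a - 3(1+w_0+w_a)\big)\right]\). The sketch checks it against a central finite difference; in JAX one could instead differentiate ln E with jax.grad.

```python
import math

def _rho_de(a, w0, wa):
    # CPL dark energy density normalized to 1 today (assumption)
    return a ** (-3.0 * (1.0 + w0 + wa)) * math.exp(3.0 * wa * (a - 1.0))

def _E2(a, Om, h, w0, wa):
    # E^2(a) for photons + CDM/baryons + CPL dark energy, neutrinos omitted
    Og = 2.469e-5 / h**2
    OL = 1.0 - Og - Om
    return Og * a**-4 + Om * a**-3 + OL * _rho_de(a, w0, wa)

def dlogE_dloga_sketch(a, Om, h, w0=-1.0, wa=0.0):
    """Analytic d ln E / d ln a = (d E^2 / d ln a) / (2 E^2)."""
    Og = 2.469e-5 / h**2
    OL = 1.0 - Og - Om
    dlog_rho = -3.0 * (1.0 + w0 + wa) + 3.0 * wa * a   # d ln rho_DE / d ln a
    dE2 = -4.0 * Og * a**-4 - 3.0 * Om * a**-3 + OL * _rho_de(a, w0, wa) * dlog_rho
    return 0.5 * dE2 / _E2(a, Om, h, w0, wa)

def dlogE_dloga_fd(a, Om, h, w0=-1.0, wa=0.0, eps=1e-5):
    """Central finite difference of ln E in ln a, used as a cross-check."""
    up = 0.5 * math.log(_E2(a * math.exp(eps), Om, h, w0, wa))
    down = 0.5 * math.log(_E2(a * math.exp(-eps), Om, h, w0, wa))
    return (up - down) / (2.0 * eps)

print(dlogE_dloga_sketch(0.5, 0.3, 0.67, w0=-0.9, wa=0.1),
      dlogE_dloga_fd(0.5, 0.3, 0.67, w0=-0.9, wa=0.1))
```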
Matter Density
jaxace.Ωm_a
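Ωm_a has no docstring in this section. The usual definition is \(\Omega_m(a) = \Omega_{\mathrm{cb},0}\,a^{-3}/E^2(a)\); whether jaxace also counts massive neutrinos as matter here is not stated, so treat this as an assumption. The sketch uses a flat background with only CDM + baryons and CPL dark energy (no radiation, no neutrinos), so \(\Omega_m(1) = \Omega_{\mathrm{cb},0}\) exactly.

```python
import math

def Omega_m_a_sketch(a, Omega_cb0, w0=-1.0, wa=0.0):
    """Omega_m(a) = Omega_cb0 a^-3 / E^2(a) for flat w0waCDM.

    Simplifying assumption: E^2 contains only CDM + baryons and CPL dark
    energy (no radiation, no neutrinos), so Omega_m(1) = Omega_cb0.
    """
    rho_de = a ** (-3.0 * (1.0 + w0 + wa)) * math.exp(3.0 * wa * (a - 1.0))
    matter = Omega_cb0 * a**-3
    E2 = matter + (1.0 - Omega_cb0) * rho_de
    return matter / E2

for z in (0.0, 1.0, 3.0):
    print(z, Omega_m_a_sketch(1.0 / (1.0 + z), 0.3))
```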
Growth Functions
jaxace.D_z
Linear growth factor D(z).
The growth factor is normalized such that D(z=0) = 1. It satisfies the differential equation given in growth_solver.
Returns:
| Type | Description |
|---|---|
|  | Linear growth factor D(z). Returns NaN for NaN inputs, handles invalid parameters gracefully. |
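The differential equation solved by growth_solver is not reproduced on this page. A common form, which we assume here, is the scale-independent linear growth equation in \(x = \ln a\): \(D'' + \left(2 + \frac{d\ln E}{d\ln a}\right) D' - \frac{3}{2}\Omega_m(a)\,D = 0\), started on the matter-era growing mode \(D = D' = a\). The sketch integrates it with a fixed-step RK4 scheme for a flat w0waCDM background without radiation or neutrinos, normalizes so that D(z=0) = 1, and returns \(f = D'/D\). It illustrates the technique and is not jaxace's solver.

```python
import math

def growth_D_f_sketch(z, Omega_cb0, w0=-1.0, wa=0.0, a_init=1e-3, n_steps=2000):
    """Return (D(z), f(z)) with D(z=0) = 1, for flat w0waCDM without
    radiation or neutrinos (a simplifying assumption). Requires 1/(1+z) > a_init."""

    def rho_de(a):  # CPL dark energy, normalized to 1 today
        return a ** (-3.0 * (1.0 + w0 + wa)) * math.exp(3.0 * wa * (a - 1.0))

    def E2(a):
        return Omega_cb0 * a**-3 + (1.0 - Omega_cb0) * rho_de(a)

    def dlogE_dloga(a):
        dlog_rho = -3.0 * (1.0 + w0 + wa) + 3.0 * wa * a
        dE2 = -3.0 * Omega_cb0 * a**-3 + (1.0 - Omega_cb0) * rho_de(a) * dlog_rho
        return 0.5 * dE2 / E2(a)

    def rhs(x, y):  # x = ln a, y = (D, dD/dln a)
        a = math.exp(x)
        D, dD = y
        Om_a = Omega_cb0 * a**-3 / E2(a)
        return (dD, -(2.0 + dlogE_dloga(a)) * dD + 1.5 * Om_a * D)

    def integrate(a_end):
        # classic RK4 in ln a, starting on the matter-era growing mode D = a
        x = math.log(a_init)
        step = (math.log(a_end) - x) / n_steps
        y = (a_init, a_init)
        for _ in range(n_steps):
            k1 = rhs(x, y)
            k2 = rhs(x + step / 2, (y[0] + step / 2 * k1[0], y[1] + step / 2 * k1[1]))
            k3 = rhs(x + step / 2, (y[0] + step / 2 * k2[0], y[1] + step / 2 * k2[1]))
            k4 = rhs(x + step, (y[0] + step * k3[0], y[1] + step * k3[1]))
            y = tuple(y[i] + step / 6 * (k1[i] + 2 * k2[i] + 2 * k3[i] + k4[i]) for i in range(2))
            x += step
        return y

    D0, _ = integrate(1.0)
    Dz, dDz = integrate(1.0 / (1.0 + z))
    return Dz / D0, dDz / Dz   # normalized D, and f = d ln D / d ln a

print(growth_D_f_sketch(0.0, 0.3), growth_D_f_sketch(1.0, 0.3))
```

In the Einstein-de Sitter limit (Ωcb0 = 1) the equation has the exact solution D = a with f = 1, which makes a quick check of the integrator.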
jaxace.f_z
Growth rate f(z) = d log D / d log a.
The growth rate is defined as:
\[
f(z) = \frac{d\ln D}{d\ln a}
\]
where D is the linear growth factor.
Returns:
| Type | Description |
|---|---|
|  | Growth rate f(z). Returns NaN for NaN inputs, handles invalid parameters gracefully. |
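As an independent cross-check, and not a description of how jaxace computes f_z, Linder's growth-index fit \(f(z) \approx \Omega_m(a)^{\gamma}\) with \(\gamma \approx 0.55\) is typically within about one percent of the exact growth rate for flat ΛCDM. A sketch, assuming a matter + Λ background only:

```python
def f_growth_index_sketch(z, Omega_m0, gamma=0.55):
    """Approximate f(z) ~ Omega_m(a)**gamma for flat LCDM (matter + Lambda only).

    gamma = 0.55 is the usual growth-index value for LCDM; this is a fit,
    not an exact solution of the growth equation.
    """
    a = 1.0 / (1.0 + z)
    matter = Omega_m0 * a**-3
    Omega_m_a = matter / (matter + 1.0 - Omega_m0)
    return Omega_m_a ** gamma

print([round(f_growth_index_sketch(z, 0.3), 4) for z in (0.0, 0.5, 1.0)])
```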
Distance Functions
jaxace.r_z
r_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], mν: Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]
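r_z has no docstring on this page. Assuming it returns the line-of-sight comoving distance in Mpc, consistent with \(d_L(z) = r(z)(1+z)\) under dL_z, the standard definition is \(r(z) = \frac{c}{H_0}\int_0^z \frac{dz'}{E(z')}\) with \(c/H_0 = 2997.92458/h\) Mpc. A minimal sketch with composite Simpson integration, for flat ΛCDM without radiation or neutrinos:

```python
import math

C_KM_S = 299792.458  # speed of light in km/s

def comoving_distance_sketch(z, Omega_m0, h, n=2000):
    """r(z) = (c/H0) * integral_0^z dz'/E(z') in Mpc.

    Assumes flat LCDM with E(z)^2 = Omega_m0 (1+z)^3 + 1 - Omega_m0.
    """
    if n % 2:
        n += 1  # composite Simpson's rule needs an even number of intervals

    def inv_E(zp):
        return 1.0 / math.sqrt(Omega_m0 * (1.0 + zp) ** 3 + 1.0 - Omega_m0)

    dz = z / n
    total = inv_E(0.0) + inv_E(z)
    for i in range(1, n):
        total += (4.0 if i % 2 else 2.0) * inv_E(i * dz)
    return C_KM_S / (100.0 * h) * total * dz / 3.0

print(comoving_distance_sketch(1.0, 0.3, 0.7))
```

For Ωm = 1 the integral has the closed form \(2(1 - 1/\sqrt{1+z})\), so at z = 3 the result is exactly c/H0, which makes a convenient test.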
jaxace.dA_z
dA_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], mν: Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]
jaxace.dL_z
dL_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], mν: Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]
Luminosity distance at redshift z.
The luminosity distance is related to the comoving distance \(r(z)\) by \(d_L(z) = r(z)\,(1 + z)\).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| z | Union[float, ndarray] | Redshift | required |
| Ωcb0 | Union[float, ndarray] | Present-day matter density parameter (CDM + baryons) | required |
| h | Union[float, ndarray] | Reduced Hubble constant (H0 = 100h km/s/Mpc) | required |
| mν | Union[float, ndarray] | Sum of neutrino masses in eV | 0.0 |
| w0 | Union[float, ndarray] | Present-day dark energy equation-of-state parameter, in w(a) = w0 + wa(1 - a) | -1.0 |
| wa | Union[float, ndarray] | Dark energy equation-of-state evolution parameter, in the same w(a) | 0.0 |
Returns:
| Type | Description |
|---|---|
| Union[float, ndarray] | Luminosity distance in Mpc |
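Given r(z), the relation above yields dL. For a flat universe the angular diameter distance is conventionally \(d_A(z) = r(z)/(1+z)\); we assume dA_z follows this, in which case \(d_L = (1+z)^2 d_A\). To stay short, the sketch uses the closed-form Einstein-de Sitter comoving distance (Ωm = 1); the function names are illustrative.

```python
import math

C_KM_S = 299792.458  # speed of light in km/s

def r_eds(z, h):
    """Comoving distance in Mpc for Omega_m = 1 (closed form)."""
    return C_KM_S / (100.0 * h) * 2.0 * (1.0 - 1.0 / math.sqrt(1.0 + z))

def dA_sketch(z, h):
    return r_eds(z, h) / (1.0 + z)   # angular diameter distance (flat universe)

def dL_sketch(z, h):
    return r_eds(z, h) * (1.0 + z)   # luminosity distance, dL = r (1 + z)

for z in (0.5, 1.0, 1.25, 1.5, 3.0):
    print(z, round(dA_sketch(z, 0.7), 1), round(dL_sketch(z, 0.7), 1))
```

In this model dA peaks near z = 1.25 and then decreases, so objects of fixed size appear larger at higher redshift.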
Density Functions
jaxace.ρc_z
ρc_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], mν: Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]
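ρc_z has no docstring here. The critical density is \(\rho_c(z) = 3H(z)^2/(8\pi G)\); in M☉/Mpc³ this is about \(2.775\times10^{11}\,h^2 E(z)^2\). The sketch computes it from the solar gravitational parameter GM☉ and the megaparsec in metres, for a flat ΛCDM E(z) without radiation or neutrinos. jaxace's exact constants and background model may differ slightly.

```python
import math

GM_SUN = 1.32712440018e20         # heliocentric gravitational constant, m^3 s^-2
MPC_IN_M = 3.0856775814913673e22  # one megaparsec in metres

def rho_crit_sketch(z, Omega_m0, h):
    """Critical density 3 H(z)^2 / (8 pi G) in M_sun / Mpc^3.

    H(z) = 100 h E(z) km/s/Mpc with E^2 = Omega_m0 (1+z)^3 + 1 - Omega_m0
    (flat LCDM, no radiation or neutrinos: a simplifying assumption).
    """
    E = math.sqrt(Omega_m0 * (1.0 + z) ** 3 + 1.0 - Omega_m0)
    H_si = 100.0 * h * 1e3 / MPC_IN_M * E   # H(z) in s^-1
    # 3 H^2 / (8 pi G M_sun) is solar masses per m^3; scale by (1 Mpc)^3 in m^3
    return 3.0 * H_si**2 / (8.0 * math.pi * GM_SUN) * MPC_IN_M**3

print(f"{rho_crit_sketch(0.0, 0.3, 0.7):.4e}")
```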
jaxace.Ωtot_z
Ωtot_z(z: Union[float, ndarray], Ωcb0: Union[float, ndarray], h: Union[float, ndarray], mν: Union[float, ndarray] = 0.0, w0: Union[float, ndarray] = -1.0, wa: Union[float, ndarray] = 0.0) -> Union[float, jnp.ndarray]
Utility Functions
Neural Network Emulators
jaxace.init_emulator
init_emulator(nn_dict: Dict[str, Any], weight: ndarray, emulator_type: Type[FlaxEmulator] = FlaxEmulator, validate: bool = True, validate_weights: Optional[bool] = None) -> FlaxEmulator
Initialize an emulator from a neural network specification dictionary and a flattened weight array.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| nn_dict | Dict[str, Any] | Neural network specification dictionary | required |
| weight | ndarray | Flattened weight array | required |
| emulator_type | Type[FlaxEmulator] | Type of emulator (currently only FlaxEmulator) | FlaxEmulator |
| validate | bool | Whether to validate the nn_dict structure | True |
| validate_weights | Optional[bool] | Whether to validate weight dimensions | None |
Returns:
| Type | Description |
|---|---|
| FlaxEmulator | Initialized FlaxEmulator instance |
jaxace.FlaxEmulator (dataclass)
FlaxEmulator(model: Module, parameters: Dict[str, Any], states: Optional[Dict[str, Any]] = None, description: Dict[str, Any] = None)
Bases: AbstractTrainedEmulator
Flax-based emulator with automatic JIT compilation.
Key features:
1. Automatic JIT compilation on first use
2. Automatic batch detection and vmap application
3. Cached compiled functions for performance
Attributes:
| Name | Type | Description |
|---|---|---|
| model | Module | Flax model (nn.Module) |
| parameters | Dict[str, Any] | Model parameters dictionary |
| states | Optional[Dict[str, Any]] | Model states (usually empty for standard feedforward networks) |
| description | Dict[str, Any] | Emulator description dictionary |
run_emulator
Run the emulator with automatic JIT compilation and batch detection.
This method automatically:
1. Converts NumPy arrays to JAX arrays
2. Detects whether the input is a single sample or a batch
3. Applies JIT compilation
4. Uses vmap for batch processing
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_data | Union[np.ndarray, jnp.ndarray] | Input array: a single sample of shape (n_features,) or a batch of shape (n_samples, n_features) | required |
Returns:
| Type | Description |
|---|---|
| ndarray | Output array from the neural network |
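The batch detection described above can be sketched as a dispatch on the number of input dimensions. The documentation says jaxace does this with JIT compilation and vmap (in JAX the batch branch would be jax.vmap(apply_fn)(x)); to keep the example runnable without JAX, a NumPy loop stands in for vmap. run_emulator_sketch and toy_net are hypothetical, and toy_net is one fixed linear layer with tanh, not a trained Flax model.

```python
import numpy as np

def run_emulator_sketch(apply_fn, input_data):
    """Hypothetical dispatch mirroring the documented behaviour.

    (n_features,) -> single sample; (n_samples, n_features) -> batch.
    """
    x = np.asarray(input_data, dtype=float)
    if x.ndim == 1:
        return apply_fn(x)
    if x.ndim == 2:
        # stand-in for jax.vmap(apply_fn)(x): map apply_fn over the rows
        return np.stack([apply_fn(row) for row in x])
    raise ValueError(f"expected 1-D or 2-D input, got shape {x.shape}")

# hypothetical toy network: one fixed linear layer (3 outputs, 2 features) + tanh
W = np.array([[1.0, -1.0], [0.5, 2.0], [0.0, 1.0]])

def toy_net(x):
    return np.tanh(W @ x)

single = run_emulator_sketch(toy_net, [0.1, 0.2])
batch = run_emulator_sketch(toy_net, [[0.1, 0.2], [1.0, -1.0], [0.0, 0.0], [2.0, 3.0]])
print(single.shape, batch.shape)
```

The point of the design is that callers never reshape by hand: a 1-D input yields one output vector, and a 2-D input yields one output row per sample.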
__call__
Allow the emulator to be called directly as a function.
Utilities
jaxace.maximin
maximin(input_data: Union[ndarray, ndarray], minmax: Union[ndarray, ndarray]) -> Union[np.ndarray, jnp.ndarray]
Normalize input data using min-max scaling. Matches Julia's maximin function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_data | Union[np.ndarray, jnp.ndarray] | Input array to normalize, of shape (n_features,) or (n_features, n_samples) | required |
| minmax | Union[np.ndarray, jnp.ndarray] | Array of shape (n_features, 2) where column 0 is min, column 1 is max | required |
Returns:
| Type | Description |
|---|---|
| Union[np.ndarray, jnp.ndarray] | Normalized array; values lie in [0, 1] when the inputs are within the min/max bounds |
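Min-max scaling maps each feature to \(\hat x = (x - x_{\min})/(x_{\max} - x_{\min})\). A NumPy sketch, assuming this standard formula and the shape conventions in the table above (features along axis 0, minmax columns holding min and max); maximin_sketch is illustrative, not jaxace's code.

```python
import numpy as np

def maximin_sketch(input_data, minmax):
    """Min-max scale each feature to [0, 1].

    input_data: (n_features,) or (n_features, n_samples); minmax: (n_features, 2).
    """
    x = np.asarray(input_data, dtype=float)
    mm = np.asarray(minmax, dtype=float)
    lo, hi = mm[:, 0], mm[:, 1]
    if x.ndim == 2:
        # features run along axis 0, so broadcast the bounds as column vectors
        lo, hi = lo[:, None], hi[:, None]
    return (x - lo) / (hi - lo)

minmax = [[0.0, 10.0], [-1.0, 1.0]]
print(maximin_sketch([5.0, 0.0], minmax))
print(maximin_sketch([[0.0, 10.0], [-1.0, 1.0]], minmax))
```

Note that the feature axis here is axis 0, whereas run_emulator expects batches shaped (n_samples, n_features).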
jaxace.inv_maximin
inv_maximin(output_data: Union[ndarray, ndarray], minmax: Union[ndarray, ndarray]) -> Union[np.ndarray, jnp.ndarray]
Denormalize output data from min-max scaling. Matches Julia's inv_maximin function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| output_data | Union[np.ndarray, jnp.ndarray] | Normalized array, of shape (n_features,) or (n_features, n_samples) | required |
| minmax | Union[np.ndarray, jnp.ndarray] | Array of shape (n_features, 2) where column 0 is min, column 1 is max | required |
Returns:
| Type | Description |
|---|---|
| Union[np.ndarray, jnp.ndarray] | Denormalized array |
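inv_maximin inverts the scaling with \(x = \hat x\,(x_{\max} - x_{\min}) + x_{\min}\). The sketch, again assuming the standard min-max formula and the same shape conventions, checks that the inverse undoes the forward transform; both function names are illustrative.

```python
import numpy as np

def _bounds(minmax, ndim):
    mm = np.asarray(minmax, dtype=float)
    lo, hi = mm[:, 0], mm[:, 1]
    if ndim == 2:  # (n_features, n_samples): broadcast the bounds down the columns
        lo, hi = lo[:, None], hi[:, None]
    return lo, hi

def inv_maximin_sketch(output_data, minmax):
    """Undo min-max scaling: x = x_hat * (max - min) + min, feature by feature."""
    y = np.asarray(output_data, dtype=float)
    lo, hi = _bounds(minmax, y.ndim)
    return y * (hi - lo) + lo

def maximin_sketch(input_data, minmax):
    x = np.asarray(input_data, dtype=float)
    lo, hi = _bounds(minmax, x.ndim)
    return (x - lo) / (hi - lo)

minmax = np.array([[0.0, 10.0], [-1.0, 1.0]])
x = np.array([[2.0, 7.5, 10.0], [-0.5, 0.0, 1.0]])  # 2 features, 3 samples
restored = inv_maximin_sketch(maximin_sketch(x, minmax), minmax)
print(np.allclose(restored, x))
```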