Examples¶
Basic Cosmology Calculations¶
This example demonstrates how to use the w0waCDMCosmology class to compute various cosmological quantities and visualize them.
Setting up the Cosmology¶
First, let's create a cosmology instance with standard ΛCDM parameters (Planck 2018):
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jaxace
# Create a cosmology instance with Planck 2018 parameters
cosmo = jaxace.w0waCDMCosmology(
ln10As=3.044, # ln(10^10 A_s)
ns=0.9649, # Scalar spectral index
h=0.6736, # Hubble parameter
omega_b=0.02237, # Baryon density
omega_c=0.1200, # CDM density
m_nu=0.06, # Sum of neutrino masses in eV
w0=-1.0, # Dark energy equation of state
wa=0.0 # Dark energy evolution parameter
)
Computing Distances¶
Now let's compute the comoving distance for a range of redshifts:
# Create redshift array from 0 to 3
z = jnp.linspace(0.01, 3.0, 100)
# Compute comoving distance
r_comoving = cosmo.r_z(z)
# Also compute angular diameter and luminosity distances
dA = cosmo.dA_z(z)
dL = cosmo.dL_z(z)
Computing Growth Functions¶
Let's compute the linear growth factor and growth rate:
# Compute growth factor D(z) and growth rate f(z)
D = cosmo.D_z(z)
f = cosmo.f_z(z)
# We can also get both at once for efficiency
D_both, f_both = cosmo.D_f_z(z)
Visualizing the Results¶
Now let's create plots to visualize these cosmological quantities:
# Create a figure with subplots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
# Plot 1: Distance measures
ax = axes[0, 0]
ax.plot(z, r_comoving, label='Comoving distance $r(z)$', linewidth=2)
ax.plot(z, dA, label='Angular diameter distance $d_A(z)$', linewidth=2)
ax.plot(z, dL, label='Luminosity distance $d_L(z)$', linewidth=2)
ax.set_xlabel('Redshift $z$')
ax.set_ylabel('Distance [Mpc]')
ax.set_title('Cosmological Distances')
ax.legend()
ax.grid(True, alpha=0.3)
# Plot 2: Hubble parameter
ax = axes[0, 1]
E_z = jaxace.E_z(z, cosmo.omega_b + cosmo.omega_c, cosmo.h,
mν=cosmo.m_nu, w0=cosmo.w0, wa=cosmo.wa)
H_z = 100 * cosmo.h * E_z # H(z) in km/s/Mpc
ax.plot(z, H_z, color='darkblue', linewidth=2)
ax.set_xlabel('Redshift $z$')
ax.set_ylabel('$H(z)$ [km/s/Mpc]')
ax.set_title('Hubble Parameter Evolution')
ax.grid(True, alpha=0.3)
# Plot 3: Growth factor
ax = axes[1, 0]
ax.plot(z, D, color='darkgreen', linewidth=2)
ax.set_xlabel('Redshift $z$')
ax.set_ylabel('$D(z)$')
ax.set_title('Linear Growth Factor')
ax.grid(True, alpha=0.3)
ax.invert_xaxis() # Convention: D grows as z decreases
# Plot 4: Growth rate
ax = axes[1, 1]
ax.plot(z, f, color='darkred', linewidth=2)
ax.set_xlabel('Redshift $z$')
ax.set_ylabel('$f(z) = d\\ln D / d\\ln a$')
ax.set_title('Growth Rate')
ax.grid(True, alpha=0.3)
ax.set_ylim([0.4, 1.0])
plt.tight_layout()
plt.show()

Using JAX Features¶
One of the advantages of jaxace is that all functions are JAX-compatible, enabling automatic differentiation and JIT compilation:
# JIT compile for faster repeated calculations
@jax.jit
def compute_distances_fast(z_array, omega_c, omega_b, h):
"""JIT-compiled function for fast distance calculations."""
cosmo = jaxace.w0waCDMCosmology(
ln10As=3.044, ns=0.9649, h=h,
omega_b=omega_b, omega_c=omega_c,
m_nu=0.06, w0=-1.0, wa=0.0
)
return cosmo.r_z(z_array)
# Compute gradient with respect to cosmological parameters
grad_fn = jax.grad(lambda omega_c: compute_distances_fast(
jnp.array([1.0]), omega_c, 0.02237, 0.6736
).sum())
# Calculate derivative of comoving distance at z=1 with respect to omega_c
d_r_d_omega_c = grad_fn(0.12)
print(f"∂r(z=1)/∂Ω_c = {d_r_d_omega_c:.2f} Mpc")
Computing Jacobians of Growth Functions¶
A powerful feature of jaxace is the ability to compute Jacobians (derivatives) of cosmological quantities with respect to all input parameters simultaneously:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import jaxace
# Define a function that computes growth factor for given parameters
def growth_factor_function(params, z):
"""Compute growth factor D(z) for given cosmological parameters."""
omega_c, omega_b, h, m_nu, w0, wa = params
cosmo = jaxace.w0waCDMCosmology(
ln10As=3.044, # Keep fixed for this example
ns=0.9649, # Keep fixed for this example
h=h,
omega_b=omega_b,
omega_c=omega_c,
m_nu=m_nu,
w0=w0,
wa=wa
)
return cosmo.D_z(z)
# Define fiducial parameters (Planck 2018)
fiducial_params = jnp.array([
0.1200, # omega_c
0.02237, # omega_b
0.6736, # h
0.06, # m_nu (eV)
-1.0, # w0
0.0 # wa
])
# Redshift array
z = jnp.linspace(0.01, 3.0, 50)
# Compute Jacobian matrix: dD/dθ for all parameters at all redshifts
jacobian_fn = jax.jacobian(growth_factor_function, argnums=0)
jacobian = jacobian_fn(fiducial_params, z)
# jacobian shape is (n_z, n_params)
print(f"Jacobian shape: {jacobian.shape}")
print(f"Parameters: [omega_c, omega_b, h, m_nu, w0, wa]")
# Plot the derivatives
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
param_names = ['$\\Omega_c$', '$\\Omega_b$', '$h$', '$m_\\nu$ [eV]', '$w_0$', '$w_a$']
param_labels = ['omega_c', 'omega_b', 'h', 'm_nu', 'w0', 'wa']
for i, (ax, name, label) in enumerate(zip(axes.flat, param_names, param_labels)):
ax.plot(z, jacobian[:, i], linewidth=2.5, color=plt.cm.viridis(i/5))
ax.set_xlabel('Redshift $z$')
ax.set_ylabel(f'$\\partial D/\\partial${name}')
ax.set_title(f'Growth Factor Sensitivity to {name}')
ax.grid(True, alpha=0.3)
ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
# Add text with max sensitivity
max_idx = jnp.argmax(jnp.abs(jacobian[:, i]))
max_z = z[max_idx]
max_val = jacobian[max_idx, i]
ax.text(0.95, 0.95, f'Max at z={max_z:.2f}\\n∂D/∂{label}={max_val:.3f}',
transform=ax.transAxes, ha='right', va='top',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
plt.suptitle('Jacobian of Growth Factor D(z) w.r.t. Cosmological Parameters',
fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

Computing Jacobians for Multiple Quantities¶
We can also compute Jacobians for both growth factor D(z) and growth rate f(z) simultaneously:
# Define a function that returns both D and f
def growth_functions(params, z):
"""Compute both D(z) and f(z) for given parameters."""
omega_c, omega_b, h, m_nu, w0, wa = params
cosmo = jaxace.w0waCDMCosmology(
ln10As=3.044, ns=0.9649,
h=h, omega_b=omega_b, omega_c=omega_c,
m_nu=m_nu, w0=w0, wa=wa
)
D, f = cosmo.D_f_z(z)
return jnp.stack([D, f]) # Stack to get shape (2, n_z)
# Compute Jacobian for both quantities
jacobian_both_fn = jax.jacobian(growth_functions, argnums=0)
jacobian_both = jacobian_both_fn(fiducial_params, z)
# jacobian_both shape is (2, n_z, n_params)
# where index 0 is D(z) and index 1 is f(z)
jacobian_D = jacobian_both[0] # Shape: (n_z, n_params)
jacobian_f = jacobian_both[1] # Shape: (n_z, n_params)
# Plot comparison of sensitivities for D and f
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
for i, (ax, name, label) in enumerate(zip(axes.flat, param_names, param_labels)):
ax.plot(z, jacobian_D[:, i], label='$\\partial D/\\partial$' + name,
linewidth=2.5, color='blue')
ax.plot(z, jacobian_f[:, i], label='$\\partial f/\\partial$' + name,
linewidth=2.5, color='red', linestyle='--')
ax.set_xlabel('Redshift $z$')
ax.set_ylabel('Derivative')
ax.set_title(f'Sensitivity to {name}')
ax.legend(loc='best')
ax.grid(True, alpha=0.3)
ax.axhline(y=0, color='k', linestyle=':', alpha=0.3)
plt.suptitle('Comparison of D(z) and f(z) Sensitivities',
fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

Comparing Different Cosmologies¶
Let's compare how different dark energy models affect the growth of structure:
# Define different cosmologies
cosmologies = {
'ΛCDM': {'w0': -1.0, 'wa': 0.0},
'wCDM': {'w0': -0.9, 'wa': 0.0},
'w0waCDM': {'w0': -0.95, 'wa': 0.3},
}
# Compute growth factors for each cosmology
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
for name, params in cosmologies.items():
cosmo_test = jaxace.w0waCDMCosmology(
ln10As=3.044, ns=0.9649, h=0.6736,
omega_b=0.02237, omega_c=0.1200,
m_nu=0.06, **params
)
D_test = cosmo_test.D_z(z)
f_test = cosmo_test.f_z(z)
ax1.plot(z, D_test, label=name, linewidth=2)
ax2.plot(z, f_test, label=name, linewidth=2)
ax1.set_xlabel('Redshift $z$')
ax1.set_ylabel('$D(z)$')
ax1.set_title('Growth Factor Comparison')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.invert_xaxis()
ax2.set_xlabel('Redshift $z$')
ax2.set_ylabel('$f(z)$')
ax2.set_title('Growth Rate Comparison')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim([0.4, 1.0])
plt.tight_layout()
plt.show()

Performance Tips¶
JIT Compilation¶
For optimal performance, especially when computing many values:
# JIT compile the entire calculation pipeline
@jax.jit
def compute_all_quantities(z_array, cosmo):
"""Compute all cosmological quantities at once."""
r = cosmo.r_z(z_array)
D, f = cosmo.D_f_z(z_array)
return r, D, f
# First call will compile
z_test = jnp.linspace(0.01, 3.0, 1000)
r, D, f = compute_all_quantities(z_test, cosmo)
# Subsequent calls will be much faster
%timeit compute_all_quantities(z_test, cosmo)
Vectorization¶
All functions support vectorized inputs for efficient batch calculations:
# Compute for multiple redshifts at once (vectorized)
z_grid = jnp.logspace(-2, 0.5, 50) # z from 0.01 to ~3.16
r_grid = cosmo.r_z(z_grid)
# This is much faster than a loop:
# for zi in z_grid:
# ri = cosmo.r_z(zi) # Slower!
Next Steps¶
- Explore the API Reference for detailed function documentation
- Check the neutrino mass effects by varying
m_nu - Experiment with different dark energy models using
w0andwa - Use JAX's automatic differentiation to compute derivatives of cosmological quantities