JAX Project Overview

JAX TPU Benchmark & Physics Simulation πŸš€

This project provides a comprehensive benchmark script (`tpus_benchmark_single-host_workload.py`) designed to measure and analyze the performance of JAX on Google Cloud TPUs. It tests a wide variety of computations, including 2D/3D matrix operations, 2D/3D FFT, and memory bandwidth, offering insights into performance scaling from single-core (JIT) to multi-core (PMAP) operations.


It also includes several physics simulation scripts accelerated with JAX, such as an N-Body black hole merger simulation, a Molecular Dynamics simulation, and a three-particle simulation in a non-uniform EM field.

✨ Key Features

Multi-Mode Benchmarking

Tests diverse operations: 2D/3D Matrix Ops (`jnp.dot`, `jnp.matmul`), 2D/3D FFT (`jnp.fft.fftn`), and Memory Bandwidth (`jnp.copy`).

Physics Simulations

Includes N-Body Black Hole Merger, Molecular Dynamics (Lennard-Jones), and Three-Particle EM simulations, all accelerated with JAX.

Multi-Core Scaling Analysis

Analyzes parallel processing efficiency by comparing single-core (`jax.jit`) to multi-core (`jax.pmap`) performance.

System & Device Introspection

Gathers system info (OS, CPU, RAM) and lists all available JAX devices (e.g., TPU), their types, and accelerator memory.

Rich Reporting & Plotting

Uses the `rich` library for formatted console tables and `matplotlib` to automatically generate a PNG plot of benchmark results.

Configurable Workloads

Allows customizing workload parameters from the command line, such as `--matrix_size`, `--steps`, `--precision`, and `--csv` output.

Main Benchmark Script

tpus_benchmark_single-host_workload.py

This is the main script of the project, designed specifically to measure JAX performance on TPUs. It tests a variety of operations to provide a comprehensive overview of processing capabilities.

The main tests include:

  • 2D Matrix Operations (`jnp.dot`)
  • 3D Tensor Operations (`jnp.matmul`)
  • 2D & 3D FFT (`jnp.fft.fftn`)
  • Memory Bandwidth (`jnp.copy`, `jnp.sum`)
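
Each test follows the standard JAX benchmarking recipe: warm up first (which also triggers compilation), then time with `block_until_ready()` so JAX's asynchronous dispatch does not skew the numbers. A minimal sketch of the pattern, with simplified names (the actual script adds device handling and more options):

```python
import time
import jax
import jax.numpy as jnp

def benchmark(fn, *args, warmup=10, steps=100):
    """Average wall-clock time (ms) of a JAX op, excluding compilation."""
    for _ in range(warmup):
        fn(*args).block_until_ready()        # warmup: compile + stabilize clocks
    t0 = time.perf_counter()
    for _ in range(steps):
        fn(*args).block_until_ready()        # block: JAX dispatch is asynchronous
    return (time.perf_counter() - t0) / steps * 1e3

n = 4096
x = jnp.ones((n, n), dtype=jnp.float32)
matmul = jax.jit(lambda a: jnp.dot(a, a))

avg_ms = benchmark(matmul, x)
tflops = (2 * n**3) / (avg_ms / 1e3) / 1e12  # an (N,N)@(N,N) matmul costs ~2*N^3 FLOPs
print(f"2D matmul: {avg_ms:.3f} ms, {tflops:.2f} TFLOPS")
```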

How to Run:

# (Basic) Run with all default settings
python3 tpus_benchmark_single-host_workload.py

# (Advanced) Run a lighter workload for quick testing
python3 tpus_benchmark_single-host_workload.py -w 5 -m 500 -mxs 8192 -md 64

# (Advanced) Run a full benchmark on 8 cores and export results to CSV
python3 tpus_benchmark_single-host_workload.py --max_cores 8 --csv results.csv

All Arguments:

  • -w / --warmup (int, default: 10): The number of "warmup" runs.
  • -m / --steps (int, default: 1000): The number of test iterations to average.
  • -mxs / --matrix_size (int, default: 16384): The size (N) for (N, N) matrices.
  • -md / --matrix_depth (int, default: 128): The depth (D) for 3D tensors (D, N, N).
  • -c / --conv_size (int, default: 256): The size of the convolution input.
  • -b / --batch_size (int, default: 32): The batch size.
  • --precision (str, default: "float32"): Data precision (float32 or bfloat16).
  • --max_cores (int, default: 0): Maximum number of cores to test (0 = all).
  • --csv (str, default: None): Filename to output results to a CSV file.

πŸ“Š v4-8 Benchmark Analysis Report

The table below shows the results from running the updated tpus_benchmark_single-host_workload.py script on a TPU v4-8. This run shows strong single-core performance, but multi-core scaling still reveals lingering communication challenges. Results are loaded from tpus_benchmark_single-host_workload-v4-8.csv.

| Test      | Cores | Avg Time (ms) | TFLOPS | Bandwidth (GB/s) |
|-----------|------:|--------------:|-------:|-----------------:|
| 2D        | 1     | 1.932         | 156.52 |                  |
| 3D        | 1     | 15.564        | 155.41 |                  |
| 2D_FFT    | 1     | 6.320         | 0.32   |                  |
| 3D_FFT    | 1     | 80.055        | 0.30   |                  |
| Bandwidth | 1     | 0.080         |        | 15575.37         |
| 2D        | 2     | 3.903         | 154.94 |                  |
| 3D        | 2     | 14.862        | 162.76 |                  |
| 2D_FFT    | 2     | 7.493         | 0.54   |                  |
| 3D_FFT    | 2     | 45.383        | 0.53   |                  |
| Bandwidth | 2     | 8.134         |        | 307.37           |
| 2D        | 4     | 5.591         | 216.32 |                  |
| 3D        | 4     | 10.880        | 222.33 |                  |
| 2D_FFT    | 4     | 8.382         | 0.96   |                  |
| 3D_FFT    | 4     | 24.992        | 0.97   |                  |
| Bandwidth | 4     | 16.474        |        | 303.50           |

v4-8 Benchmark Performance Plot

Analysis:

Version 4-8 shows excellent single-core (JIT) performance, with the Bandwidth test reaching 15,575 GB/s. However, scaling to multi-core (PMAP) still reveals significant challenges in inter-core communication.

  • Compute-Bound Analysis (2D/3D/FFT): Compute performance (TFLOPS) scales up with more cores. For example, 3D MatMul increases from ~155 TFLOPS (1 core) to ~222 TFLOPS (4 cores), and 3D FFT increases from 0.30 TFLOPS to 0.97 TFLOPS. This indicates that the TPU scales compute-bound tasks well.
  • Bandwidth Analysis: This is the most problematic area. Bandwidth performance drops dramatically from 15,575 GB/s on 1 core to just 307 GB/s on 2 cores and 303 GB/s on 4 cores. This suggests the test is no longer measuring HBM speed but is instead bottlenecked by the Inter-Core Interconnect (ICI) as `pmap` forces data exchange between cores (see the sketch below).
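
One way to picture the difference (a sketch only, not the script's actual bandwidth test; the `psum` collective here is purely illustrative):

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()
x = jnp.ones((n_dev, 2048, 2048), dtype=jnp.float32)

# Single core: a jitted copy touches only the local HBM.
copy_jit = jax.jit(lambda a: a.copy())
copy_jit(x[0]).block_until_ready()

# Multi-core: pmap shards one slice per core; a collective such as psum
# forces every core to exchange data over the Inter-Core Interconnect (ICI),
# so the measured "bandwidth" reflects ICI, not HBM.
exchange = jax.pmap(lambda a: jax.lax.psum(a, axis_name="i"), axis_name="i")
exchange(x).block_until_ready()
```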

Conclusion: While v4-8 delivers excellent single-core compute performance and can scale compute tasks effectively, multi-core data bandwidth under `pmap` remains a significant challenge.

πŸ”¬ Physics Simulation Examples

This project includes several physics simulation scripts to demonstrate the power of JAX (e.g., `jit`, `vmap`, `grad`) in accelerating complex scientific computations.

1. Molecular Dynamics (molecular_dynamics_jax_single-host_workload.py)

This script implements a 2D Molecular Dynamics (MD) simulation of a Lennard-Jones fluid, written purely in JAX.

How to Run:

# (Basic) Run with default settings (N=400, 10k steps)
python3 molecular_dynamics_jax_single-host_workload.py

# (Advanced) Run a longer simulation with more particles and lower density
python3 molecular_dynamics_jax_single-host_workload.py --N 800 --rho 0.7 --prod_steps 50000 --eq_steps 20000

All Arguments:

  • --N (int, default: 400): Number of particles.
  • --rho (float, default: 0.8): Density.
  • --kT (float, default: 1.0): Temperature (kT).
  • --dt (float, default: 1e-3): Time step size.
  • --eq_steps (int, default: 10000): Number of equilibration steps.
  • --prod_steps (int, default: 10000): Number of production (simulation) steps.
  • --sample_every (int, default: 100): Sample the state every N steps.
  • --seed (int, default: 42): PRNG seed.
  • --output (str, default: "g_r_plot.png"): Output filename for the g(r) plot.

Physics Concepts:

  • Lennard-Jones Potential: Simulates the interaction between two neutral particles, described by the potential $U(r)$, featuring a strong short-range repulsion (Pauli exclusion) and a weaker long-range attraction (van der Waals force).
  • Statistical Mechanics: The system is initialized with random positions and velocities (based on the temperature $kT$). It undergoes an "Equilibration" phase to reach a stable state before the "Production" (data collection) phase begins.
  • Periodic Boundary Conditions: Simulates an infinite system by using a finite number of particles in a "box" that wraps around on itself.
  • Radial Distribution Function $g(r)$: The final output, $g(r)$, describes the probability of finding another particle at a distance $r$ from a reference particle. Its shape reveals the state of matter (e.g., solid, liquid, gas).

Mathematical & JAX Implementation:

  • Lennard-Jones Potential $U(r)$:
    $$U(r) = 4\epsilon \left[ \left( \frac{\sigma}{r} \right)^{12} - \left( \frac{\sigma}{r} \right)^{6} \right]$$
    Here, $\epsilon$ is the depth of the potential well (the energy scale of the attraction), and $\sigma$ is the finite distance at which the inter-particle potential is zero (the size parameter). The $r^{-12}$ term models the steep repulsive force due to overlapping electron clouds, while the $r^{-6}$ term captures the attractive dispersion forces. The total potential energy of the system is the sum over all unique particle pairs: $U_{\text{total}} = \sum_{i < j} U(r_{ij})$. This is implemented in the `total_energy_fn` function, which computes the sum with vectorized JAX operations rather than explicit loops (see the sketch after this list).
  • Force Calculation (via `jax.grad`): The force on each particle is the negative gradient of the potential energy with respect to its position: $\vec{F}_i = -\nabla_i U_{\text{total}}$. Instead of manually deriving the analytical force expression from the Lennard-Jones potential (which requires differentiating every pairwise term), the script leverages JAX's automatic differentiation:
    force_fn = jit(grad(lambda R: -total_energy_fn(R)))
    Here, `grad` computes the gradient via reverse-mode automatic differentiation (a vector-Jacobian product under the hood), treating the positions $\vec{R}$ as the input; negating the energy before differentiating yields the force directly. This approach is highly advantageous for physics simulations, as potentials can be modified without re-deriving the forces, and the result compiles efficiently on TPUs.
  • Integration (Velocity Verlet): To propagate the positions $\vec{R}$ and velocities $\vec{V}$ over time, the `verlet_step` function employs the Velocity Verlet algorithm, a second-order symplectic integrator that conserves energy better than simpler Euler methods. It interleaves half-step velocity updates with full-step position updates:
    $$\begin{align*} \vec{V}\left(t + \tfrac{1}{2} \Delta t \right) &= \vec{V}(t) + \tfrac{1}{2} \vec{A}(t)\, \Delta t, \\ \vec{R}(t + \Delta t) &= \vec{R}(t) + \vec{V}\left(t + \tfrac{1}{2} \Delta t \right) \Delta t, \\ \vec{A}(t + \Delta t) &= \frac{\vec{F}(\vec{R}(t + \Delta t))}{m}, \\ \vec{V}(t + \Delta t) &= \vec{V}\left(t + \tfrac{1}{2} \Delta t \right) + \tfrac{1}{2} \vec{A}(t + \Delta t)\, \Delta t. \end{align*}$$
    This method is time-reversible and preserves the symplectic structure of Hamiltonian systems, reducing long-term energy drift. The acceleration $\vec{A} = \vec{F}/m$ is recomputed from the force function at each step.
  • Performance: `jax.jit` is applied to `verlet_step`, `equilibrate_fn`, `production_fn`, and `calculate_g_r`, and `jax.lax.fori_loop` compiles the entire simulation loop onto the TPU for maximum speed.
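
Putting these pieces together, here is a minimal sketch of the pattern, assuming a 2D system and a hypothetical box size `BOX` (the script's real functions take more parameters):

```python
import jax
import jax.numpy as jnp

BOX = 22.36  # hypothetical box side length (set by N and rho in the real script)

def total_energy_fn(R, sigma=1.0, epsilon=1.0):
    """Total Lennard-Jones energy over all unique pairs, fully vectorized."""
    dR = R[:, None, :] - R[None, :, :]       # (N, N, 2) displacement vectors
    dR = dR - BOX * jnp.round(dR / BOX)      # minimum-image periodic boundaries
    r2 = jnp.sum(dR**2, axis=-1)
    mask = 1.0 - jnp.eye(R.shape[0])         # zero out self-interactions
    r2 = jnp.where(mask > 0.0, r2, 1.0)      # dummy distance keeps gradients finite
    inv_r6 = (sigma**2 / r2) ** 3
    u = 4.0 * epsilon * (inv_r6**2 - inv_r6)
    return 0.5 * jnp.sum(u * mask)           # each pair is counted twice, so halve

# F = -grad(U): differentiate the negated energy, exactly as described above.
force_fn = jax.jit(jax.grad(lambda R: -total_energy_fn(R)))

@jax.jit
def verlet_step(R, V, F, dt=1e-3, m=1.0):
    """One Velocity Verlet step: half-kick, drift, new force, half-kick."""
    V_half = V + 0.5 * (F / m) * dt
    R_new = (R + V_half * dt) % BOX          # wrap positions back into the box
    F_new = force_fn(R_new)
    V_new = V_half + 0.5 * (F_new / m) * dt
    return R_new, V_new, F_new
```

Wrapping this step in `jax.lax.fori_loop` keeps the whole equilibration and production loop on-device, which is where the large speedups come from.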

Example Output: Radial Distribution Function

This plot shows $g(r)$. The example image provided shows a result where $g(r) \approx 0$, which can occur if the simulation parameters (like density `rho` or temperature `kT`) are not set to form a stable liquid structure, or if the production run is too short. A typical liquid would show distinct peaks.

Radial Distribution Function (g(r)) Plot

2. N-Body Black Hole Merger (nbody_bh_merger_sim_single-host_workload.py)

This script simulates the dynamics of N bodies (e.g., 3 black holes) under their mutual gravity and computes the resulting gravitational waves (GW).

How to Run (Interactive):

# (Basic) Run the script and press Enter to accept all default values
python3 nbody_bh_merger_sim_single-host_workload.py
# -> Number of black holes (default=3): [Enter]
# -> Mass of BH1 (default=30.0): [Enter]
# -> ... (etc.)

# (Advanced) Run the script and input custom values
python3 nbody_bh_merger_sim_single-host_workload.py
# -> Number of black holes: 5
# -> Mass of BH1: 50
# -> ... (etc.)
# -> Compute Lyapunov exponent? (y/n): n

This script is **interactive** and will prompt you for parameters in the console (Number of black holes, Mass, Separation, Velocity, etc.) instead of using command-line arguments.

Physics Concepts:

  • Newtonian Gravity (N-Body): Solves the classic N-body problem by calculating the gravitational force between all pairs of objects.
  • Gravitational Waves (Approximation): The script computes the GW strain $h_+$ using the quadrupole approximation, summing the contributions from each orbiting pair.
  • Chaos Theory: For $N > 2$, the system is often chaotic. The script can compute the Lyapunov exponent $\lambda$, which measures the exponential rate at which nearby trajectories diverge, quantifying the system's chaos.

Mathematical & JAX Implementation:

  • ODE System: The N-body problem is formulated as a first-order system of ordinary differential equations (ODEs). The state vector concatenates all positions and velocities: $Y = [\vec{r}_1, \dots, \vec{r}_N, \vec{v}_1, \dots, \vec{v}_N]^T \in \mathbb{R}^{6N}$. Its time derivative is $\frac{dY}{dt} = [\vec{v}_1, \dots, \vec{v}_N, \vec{a}_1, \dots, \vec{a}_N]^T$, where the $\vec{a}_i$ are the accelerations due to gravity. This structure allows efficient vectorized computation in JAX.
  • Gravitational Acceleration $\vec{a}_{g,i}$:
    $$\vec{a}_{g,i} = \sum_{j \neq i} G m_j \frac{\vec{r}_j - \vec{r}_i}{|\vec{r}_j - \vec{r}_i|^3}$$
    This follows from Newton's law of universal gravitation, $\vec{F}_{ij} = -G m_i m_j \frac{\vec{r}_i - \vec{r}_j}{|\vec{r}_i - \vec{r}_j|^3}$, divided by $m_i$ to obtain acceleration; the $1/r^2$ force law appears in $1/r^3$ form because the unnormalized displacement vector sits in the numerator. Regularization (e.g., softening) may be added to avoid singularities at close encounters. The pairwise sum is implemented in `pairwise_acc` using JAX's vectorized broadcasting over all pairs.
  • Numerical Integration (RK4): To solve the ODE system $\frac{dY}{dt} = f(Y, t)$, the script uses the classical 4th-order Runge-Kutta (RK4) method in `rk4_step`, which combines four evaluations of $f$ at intermediate points:
    $$k_1 = f(Y, t), \qquad k_2 = f\left(Y + \tfrac{\Delta t}{2} k_1,\; t + \tfrac{\Delta t}{2}\right),$$
    $$k_3 = f\left(Y + \tfrac{\Delta t}{2} k_2,\; t + \tfrac{\Delta t}{2}\right), \qquad k_4 = f\left(Y + \Delta t\, k_3,\; t + \Delta t\right),$$
    $$Y(t + \Delta t) = Y(t) + \frac{\Delta t}{6} \left(k_1 + 2k_2 + 2k_3 + k_4\right).$$
    This gives a local error of $O(\Delta t^5)$, making it accurate and stable for non-stiff systems like Newtonian gravity. (A sketch of `rk4_step` with `jax.lax.scan` follows this list.)
  • GW Strain $h_+$: In the post-Newtonian quadrupole approximation for inspiraling binaries, the plus-polarization strain is
    $$h_+ = \frac{4}{D} \left( \frac{G \mathcal{M}}{c^2} \right)^{5/3} \left( \frac{\pi f_{\text{GW}}}{c} \right)^{2/3} \cos \Phi(t),$$
    where $\mathcal{M} = \frac{(m_1 m_2)^{3/5}}{(m_1 + m_2)^{1/5}}$ is the chirp mass, $f_{\text{GW}} = 2 f_{\text{orb}}$ is the GW frequency (twice the orbital frequency, due to the quadrupole symmetry), $D$ is the distance, and $\Phi(t) = 2 \int \omega(t)\, dt$ is the phase with instantaneous angular frequency $\omega = \sqrt{G(m_i + m_j)/r^3}$ from Kepler's law. The amplitude scales as $A \propto (\mathcal{M} \omega)^{2/3}$. The script sums the contributions over all pairs, assuming far-field propagation.
  • Lyapunov Exponent $\lambda$:
    $$\lambda \approx \frac{1}{t} \ln \left( \frac{|\delta(t)|}{|\delta(0)|} \right)$$
    This quantifies sensitivity to initial conditions in chaotic systems. $\delta(0)$ is a small perturbation (e.g., $10^{-10}$ in position), and $|\delta(t)|$ is the Euclidean norm of the deviation after time $t$ between the nominal trajectory $Y(t)$ and the perturbed trajectory $Y'(t)$. A value $\lambda > 0$ indicates chaos, with nearby trajectories diverging as $e^{\lambda t}$. The script evolves $Y$ and $Y'$ in parallel using JAX's `vmap`.
  • Performance: The entire simulation loop `simulate_nbody` is JIT-compiled using `jax.jit` and `jax.lax.scan`, allowing the whole trajectory to be computed efficiently on the TPU.
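
As referenced above, here is a sketch of the RK4 step and the scan-compiled loop, assuming a softened pairwise acceleration; `deriv`, `rk4_step`, and `simulate` are simplified stand-ins for the script's actual functions:

```python
from functools import partial
import jax
import jax.numpy as jnp

G = 1.0      # gravitational constant (simulation units)
SOFT = 1e-3  # hypothetical softening length for close encounters

def deriv(Y, masses):
    """dY/dt for the N-body ODE; Y = (positions (N, 3), velocities (N, 3))."""
    r, v = Y
    dr = r[None, :, :] - r[:, None, :]               # r_j - r_i for all pairs
    d3 = (jnp.sum(dr**2, axis=-1) + SOFT**2) ** 1.5  # softened |r|^3
    acc = G * jnp.sum(masses[None, :, None] * dr / d3[..., None], axis=1)
    return (v, acc)

def rk4_step(Y, masses, dt):
    """Classical 4th-order Runge-Kutta step on the (positions, velocities) pair."""
    axpy = lambda Y, k, s: jax.tree_util.tree_map(lambda y, ki: y + s * ki, Y, k)
    k1 = deriv(Y, masses)
    k2 = deriv(axpy(Y, k1, dt / 2), masses)
    k3 = deriv(axpy(Y, k2, dt / 2), masses)
    k4 = deriv(axpy(Y, k3, dt), masses)
    return jax.tree_util.tree_map(
        lambda y, a, b, c, d: y + (dt / 6) * (a + 2 * b + 2 * c + d),
        Y, k1, k2, k3, k4)

@partial(jax.jit, static_argnames=("n_steps",))
def simulate(Y0, masses, dt=1e-3, n_steps=1000):
    """Roll out the whole trajectory on-device with lax.scan."""
    def body(Y, _):
        Y_next = rk4_step(Y, masses, dt)
        return Y_next, Y_next[0]                     # record positions each step
    _, positions = jax.lax.scan(body, Y0, None, length=n_steps)
    return positions                                 # (n_steps, N, 3)
```

Because `lax.scan` compiles the loop body once and runs it entirely on the accelerator, the full trajectory never round-trips through Python.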

Example Simulation Outputs:

N-Body Positions and Cumulative GW

Particle positions (left) and cumulative GW signal (right) at a specific time.

N-Body Gravitational Waveform Plot

Standalone gravitational waveform (h+ strain) over time.

N-Body 3D Trajectories Plot

3D space-time trajectories of the N-bodies, showing their paths over time.

3. Three Particles in Non-Uniform E/M Field (three_particles_em_nonuni_single-host_workload.py)

This is a simple simulation of 3 charged particles moving under their mutual gravity and an external, non-uniform electromagnetic field.

How to Run:

# (Basic) Run with default settings (only gravity and constant B-field)
python3 three_particles_em_nonuni_single-host_workload.py

# (Advanced) Add non-uniform B-field, E-field, and run for more steps
python3 three_particles_em_nonuni_single-host_workload.py --Bz 5.0 --Bk 0.5 --Ex 0.2 --n_steps 2000

All Arguments:

  • --dt (float, default: 0.01): Time step size.
  • --n_steps (int, default: 1000): Total number of steps to simulate.
  • --G (float, default: 1.0): Gravitational constant.
  • --Bz (float, default: 1.0): Constant component of the magnetic field (Z-axis).
  • --Bk (float, default: 0.0): Gradient of the magnetic field along x (the field becomes B_z(x) = Bz + Bk * x).
  • --Ex (float, default: 0.0): Electric field strength (X-axis).
  • --Ey (float, default: 0.0): Electric field strength (Y-axis).

Physics Concepts:

  • Superposition of Forces: The net force on each particle is the vector sum of the gravity from the other particles and the Lorentz force from the E/M field: $\vec{F}_i = \sum_{j \neq i} \vec{F}_{g,ij} + \vec{F}_{\text{Lorentz},i}$.
  • Lorentz Force: The force from the E/M field is $\vec{F}_{\text{Lorentz}} = q(\vec{E} + \vec{v} \times \vec{B})$.
  • Non-Uniform Field: The magnetic field $B_z$ is position-dependent ($B_z = B_0 + B_k \cdot x$), making the dynamics more complex than simple circular or helical motion.

Mathematical & JAX Implementation:

  • Total Acceleration $\vec{a} = \vec{F}/m$: Newton's second law gives $\vec{a}_i = \vec{F}_i / m_i$. The `acceleration` function decomposes this into three additive components, computed in parallel:
    • Gravitational ($a_g$):
      $$\vec{a}_{g,i} = \sum_{j \neq i} G m_j \frac{\vec{r}_j - \vec{r}_i}{|\vec{r}_j - \vec{r}_i|^3}$$
      Identical to the N-body case, this captures the pairwise inverse-square attractions (in `pairwise_acc`). For a small system (N = 3) it is computed directly, without approximation.
    • Electric ($a_E$): $\vec{a}_E = \frac{q}{m} \vec{E}$. A constant acceleration when $\vec{E}$ is uniform; implemented in `elec_acc` as a simple scaling.
    • Magnetic ($a_B$): $\vec{a}_B = \frac{q}{m} (\vec{v} \times \vec{B})$. Velocity-dependent, deflecting each particle perpendicular to $\vec{v}$ and $\vec{B}$. It does no work ($\vec{a}_B \cdot \vec{v} = 0$), so a pure magnetic field preserves kinetic energy.
  • Cross Product in 2D: The simulation lives in the xy-plane with $\vec{B} = [0, 0, B_z(x)]$ (out-of-plane, position-dependent) and $\vec{v} = [v_x, v_y, 0]$. The cross product simplifies to
    $$\vec{v} \times \vec{B} = \det \begin{vmatrix} \hat{i} & \hat{j} & \hat{k} \\ v_x & v_y & 0 \\ 0 & 0 & B_z \end{vmatrix} = \hat{i}(v_y B_z) - \hat{j}(v_x B_z) + \hat{k}(0) = [v_y B_z,\; -v_x B_z,\; 0].$$
    Thus $\vec{a}_B = \frac{q}{m} [v_y B_z, -v_x B_z]$, inducing cyclotron motion with radius $r_c = \frac{m v_\perp}{q B_z}$. The non-uniformity $B_z(x) = B_0 + B_k x$ introduces field gradients, which give rise to drifts such as the $\vec{E} \times \vec{B}$ and gradient drifts. This matches the JAX code in `mag_acc`: `qm * jnp.array([vy * bz, -vx * bz])`, where `qm = q/m`.
  • Integration (Leapfrog/Velocity Verlet): As in the MD simulation, the `step` function applies Velocity Verlet to the combined gravitational + Lorentz force. The integrator keeps second-order accuracy and its symplectic properties for the conservative gravitational part; the magnetic term does no work, but its velocity dependence falls outside the standard symplectic analysis.
  • Performance: `jax.vmap` computes the acceleration for all particles in parallel, and `jax.jit` compiles the entire `step` function (see the sketch below).
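
A sketch of that decomposition, with the per-particle Lorentz term vectorized via `vmap` as described. The names `grav_acc` and `lorentz_acc` are illustrative, and the constants mirror the CLI defaults above:

```python
import jax
import jax.numpy as jnp

G, B0, Bk, Ex, Ey = 1.0, 1.0, 0.0, 0.0, 0.0   # CLI defaults: --G, --Bz, --Bk, --Ex, --Ey

def grav_acc(r, m):
    """Pairwise inverse-square gravity for all particles at once."""
    dr = r[None, :, :] - r[:, None, :]              # (N, N, 2)
    d3 = (jnp.sum(dr**2, axis=-1) + 1e-12) ** 1.5   # tiny epsilon avoids 0/0 on the diagonal
    return G * jnp.sum(m[None, :, None] * dr / d3[..., None], axis=1)

def lorentz_acc(ri, vi, qmi):
    """Electric + magnetic acceleration for one particle (2D, B along z)."""
    bz = B0 + Bk * ri[0]                            # non-uniform B_z(x) = B0 + Bk*x
    a_E = qmi * jnp.array([Ex, Ey])
    a_B = qmi * jnp.array([vi[1] * bz, -vi[0] * bz])  # q/m * (v x B) in 2D
    return a_E + a_B

@jax.jit
def step(r, v, q, m, dt=0.01):
    """One Velocity Verlet step with gravity + Lorentz forces."""
    acc = lambda r, v: grav_acc(r, m) + jax.vmap(lorentz_acc)(r, v, q / m)
    a = acc(r, v)
    v_half = v + 0.5 * a * dt
    r_new = r + v_half * dt
    a_new = acc(r_new, v_half)    # magnetic term evaluated at the half-step velocity
    v_new = v_half + 0.5 * a_new * dt
    return r_new, v_new
```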

Example Initial State:

Three particle initial positions

Initial positions of the three particles in the 2D plane.

4. Variational & Diffusion Monte Carlo (vmc_dmc_jax_quantum_harmonic_oscillator.py)

This script implements Variational Monte Carlo (VMC) followed by Diffusion Monte Carlo (DMC) to approximate the ground state energy and wavefunction of a D-dimensional isotropic quantum harmonic oscillator using JAX for high-performance computation. It optimizes a variational parameter Ξ± via stochastic gradient descent in VMC, then refines the distribution via branching diffusion in DMC. Outputs include energy convergence plots, marginal probability density histograms vs. exact ground state, and optional GIF animations of walker distributions.

How to Run:

# (Basic) Run with default settings (N=10k walkers, 3k VMC epochs, 500 DMC steps, 3D)
python3 vmc_dmc_jax_quantum_harmonic_oscillator.py

# (Advanced) Run a longer simulation in 1D with more epochs and steps
python3 vmc_dmc_jax_quantum_harmonic_oscillator.py --n_epochs 5000 --n_dmc 1000 --dim 1

All Arguments:

  • --n_walkers (int, default: 10000): Number of Monte Carlo walkers.
  • --n_epochs (int, default: 3000): Number of VMC optimization epochs.
  • --n_equil (int, default: 100): Number of equilibration steps per VMC epoch.
  • --step_size (float, default: 2.0): Proposal step size for Metropolis-Hastings sampling.
  • --lr (float, default: 0.02): Learning rate for Adam optimizer in VMC.
  • --n_dmc (int, default: 500): Number of DMC propagation steps.
  • --dmc_dt (float, default: 0.01): Time step size for DMC diffusion/branching.
  • --dim (int, default: 3): Spatial dimensionality of the harmonic oscillator.
  • --no-gif (flag): Disable generation of VMC and DMC animation GIFs (vmc_animation.gif and dmc_animation.gif).
  • --no-plot (flag): Disable display of final matplotlib plots (energy convergence, histograms).

Physics Concepts:

  • Quantum Harmonic Oscillator (QHO): The D-dimensional isotropic QHO has the Hamiltonian $H = -\frac{\hbar^2}{2m} \nabla^2 + \frac{1}{2} m \omega^2 \mathbf{r}^2$ (with $\hbar = m = \omega = 1$), exact ground-state energy $E_0 = D/2$, and Gaussian ground-state wavefunction $\Psi_0(\mathbf{r}) \propto e^{-\mathbf{r}^2/2}$. The script approximates this via stochastic sampling of the wavefunction.
  • Variational Monte Carlo (VMC): Uses a trial wavefunction $\Psi_T(\mathbf{r}; \alpha) = e^{-\alpha \mathbf{r}^2}$ and minimizes the expectation value $\langle E \rangle = \frac{\langle \Psi_T | H | \Psi_T \rangle}{\langle \Psi_T | \Psi_T \rangle}$ via Monte Carlo integration over samples drawn from $|\Psi_T|^2$, optimizing $\alpha$ with stochastic gradients.
  • Diffusion Monte Carlo (DMC): Projects onto the ground state via imaginary-time evolution of the SchrΓΆdinger equation, interpreted as a branching random walk with drift (from the quantum force) and diffusion (from the kinetic term), biased by local energy branching.
  • Walker Sampling: "Walkers" are Monte Carlo samples representing the probability distribution. Metropolis-Hastings ensures ergodic sampling in VMC, while DMC uses resampling based on weights to focus on low-energy configurations.

Mathematical & JAX Implementation:

  • Local Energy $E_L(\mathbf{r})$: For a trial wavefunction, the local energy is $E_L = \frac{H \Psi_T}{\Psi_T} = -\frac{1}{2} \nabla^2 \ln \Psi_T - \frac{1}{2} |\nabla \ln \Psi_T|^2 + V(\mathbf{r})$, where $V(\mathbf{r}) = \frac{1}{2} \mathbf{r}^2$ is the potential. For the Gaussian trial, $\ln \Psi_T = -\alpha \mathbf{r}^2$, so $\nabla \ln \Psi_T = -2\alpha \mathbf{r}$ and $\nabla^2 \ln \Psi_T = -2\alpha D$. Thus $E_L = \alpha D - 2\alpha^2 \mathbf{r}^2 + \frac{1}{2} \mathbf{r}^2$, and the VMC energy is the Monte Carlo average $\langle E_L \rangle \approx E_0$. For $\alpha = 0.5$ the local energy is constant, so $\langle E_L \rangle = 0.5 D$ exactly. Implemented in `local_energy` using JAX's `grad` for the gradient and Laplacian terms.
  • Metropolis-Hastings Sampling (VMC): Samples from $|\Psi_T|^2$ using a symmetric proposal $\mathbf{r}' = \mathbf{r} + \delta \boldsymbol{\epsilon}$ ($\delta$ = `step_size`), accepted with probability $A = \min(1, |\Psi_T(\mathbf{r}')|^2 / |\Psi_T(\mathbf{r})|^2) = \min(1, e^{2(\ln \Psi_T(\mathbf{r}') - \ln \Psi_T(\mathbf{r}))})$. Equilibration discards the initial steps to ensure sampling from the target distribution. The update is vectorized over walkers via `vmap(metropolis_step)` and looped with `fori_loop` for efficiency. The variational gradient is $\frac{\partial \langle E \rangle}{\partial \alpha} \approx 2 \left\langle (E_L - \langle E_L \rangle) \frac{\partial \ln \Psi_T}{\partial \alpha} \right\rangle$, optimized with Adam (`optax.adam`).
  • DMC Propagation: Solves the imaginary-time SchrΓΆdinger equation $\frac{\partial \Psi}{\partial \tau} = -(H - E_T) \Psi$ (with reference energy $E_T \approx E_0$) via a Green's function approximation: drift-diffusion with drift velocity $\mathbf{F} = \nabla \ln \Psi_T = -2\alpha \mathbf{r}$ and branching weight $w = e^{-(E_L - E_T) \Delta \tau}$. Walkers are resampled by multinomial choice on the normalized weights, then updated as $\mathbf{r}' = \mathbf{r} + \mathbf{F} \Delta \tau + \sqrt{\Delta \tau}\, \boldsymbol{\eta}$ with $\boldsymbol{\eta} \sim \mathcal{N}(0, 1)$. The mixed estimator for the energy is the average of $E_T$ over steps, with a statistical error from the variance. Implemented with `jax.lax.scan` for the full trajectory and `vmap` for the local energy and drift (see the sketch after this list).
  • Performance: All core functions (`local_energy`, `metropolis_step`, `dmc_step_body`) are JIT-compiled. `vmap` parallelizes over walkers (up to 10k), and `fori_loop`/`scan` compile loops for TPU acceleration. Progress tracking uses `rich` for console output, with optional GIFs via `imageio` for visualizing walker histograms evolving toward the exact Gaussian marginal.
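
A condensed sketch of the VMC and DMC building blocks described above, using the closed-form local energy; names are simplified, and the script's versions carry more bookkeeping:

```python
import jax
import jax.numpy as jnp

def log_psi(r, alpha):
    """ln Psi_T for the Gaussian trial wavefunction: -alpha * r^2."""
    return -alpha * jnp.sum(r**2)

def local_energy(r, alpha):
    """Closed form E_L = alpha*D - 2*alpha^2*r^2 + r^2/2 (one walker)."""
    r2, D = jnp.sum(r**2), r.shape[-1]
    return alpha * D - 2.0 * alpha**2 * r2 + 0.5 * r2

def metropolis_step(key, r, alpha, step_size=2.0):
    """One Metropolis-Hastings update for a single walker."""
    k1, k2 = jax.random.split(key)
    r_new = r + step_size * jax.random.normal(k1, r.shape)
    log_ratio = 2.0 * (log_psi(r_new, alpha) - log_psi(r, alpha))  # |Psi'|^2 / |Psi|^2
    accept = jnp.log(jax.random.uniform(k2)) < log_ratio
    return jnp.where(accept, r_new, r)

@jax.jit
def vmc_sweep(key, walkers, alpha):
    """Update all walkers in parallel, then average the local energy."""
    keys = jax.random.split(key, walkers.shape[0])
    walkers = jax.vmap(metropolis_step, in_axes=(0, 0, None))(keys, walkers, alpha)
    return walkers, jnp.mean(jax.vmap(local_energy, in_axes=(0, None))(walkers, alpha))

@jax.jit
def dmc_step(key, walkers, alpha, E_T, dt=0.01):
    """Drift-diffuse all walkers, then resample by the branching weights."""
    k1, k2 = jax.random.split(key)
    drift = -2.0 * alpha * walkers                  # F = grad ln Psi_T
    walkers = walkers + drift * dt + jnp.sqrt(dt) * jax.random.normal(k1, walkers.shape)
    E_L = jax.vmap(local_energy, in_axes=(0, None))(walkers, alpha)
    w = jnp.exp(-(E_L - E_T) * dt)                  # branching weights
    idx = jax.random.choice(k2, walkers.shape[0], (walkers.shape[0],), p=w / w.sum())
    return walkers[idx], jnp.mean(E_L)
```

With $\alpha = 0.5$ in 3D, `vmc_sweep` returns an energy of exactly 1.5, matching $E_0 = D/2$ (the variance of the local energy vanishes at the optimum).
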
VMC animation (vmc_animation.gif)

Utility Scripts

The project includes utility scripts in the `utils/` folder:

  • check_deps.py - Checks if all required libraries (jax, rich, psutil) are installed.
  • jax_devices.py - Lists all JAX devices visible to the system (e.g., TPU, CPU).
  • plt.py - Uses Matplotlib and Pandas to create a summary plot of TFLOPS and Avg. Time from benchmark results.