JAX TPU Benchmark & Physics Simulation
This project provides a comprehensive benchmark script (`tpus_benchmark_single-host_workload.py`) designed to measure and analyze the performance of JAX on Google Cloud TPUs. It tests a wide variety of computations, including 2D/3D matrix operations, 2D/3D FFT, and memory bandwidth, offering insight into performance scaling from single-core (JIT) to multi-core (PMAP) operation.
It also includes several physics simulation scripts accelerated with JAX, such as an N-Body black hole merger simulation, a Molecular Dynamics simulation, and a three-particle simulation in a non-uniform EM field.
Key Features
- **Multi-Mode Benchmarking**: Tests diverse operations: 2D/3D matrix ops (`jnp.dot`, `jnp.matmul`), 2D/3D FFT (`jnp.fft.fftn`), and memory bandwidth (`jnp.copy`).
- **Physics Simulations**: Includes N-Body Black Hole Merger, Molecular Dynamics (Lennard-Jones), and Three-Particle EM simulations, all accelerated with JAX.
- **Multi-Core Scaling Analysis**: Analyzes parallel-processing efficiency by comparing single-core (`jax.jit`) to multi-core (`jax.pmap`) performance.
- **System & Device Introspection**: Gathers system info (OS, CPU, RAM) and lists all available JAX devices (e.g., TPU), their types, and accelerator memory.
- **Rich Reporting & Plotting**: Uses the `rich` library for formatted console tables and `matplotlib` to automatically generate a PNG plot of benchmark results.
- **Configurable Workloads**: Allows customization of parameters via the command line, such as `--matrix_size`, `--steps`, `--precision`, and `--csv` output.
Main Benchmark Script
`tpus_benchmark_single-host_workload.py`
This is the main script of the project, designed specifically to measure JAX performance on TPUs. It tests a variety of operations to provide a comprehensive overview of processing capabilities.
The main tests include (a minimal timing sketch follows the list):
- 2D Matrix Operations (`jnp.dot`)
- 3D Tensor Operations (`jnp.matmul`)
- 2D & 3D FFT (`jnp.fft.fftn`)
- Memory Bandwidth (`jnp.copy`, `jnp.sum`)
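Benchmarks of this kind typically separate compilation from steady-state execution: JIT-compile the operation, discard a few warmup iterations, then average timed runs while calling `block_until_ready()` so JAX's asynchronous dispatch does not distort the numbers. A minimal sketch of that recipe (the `time_op` helper and the sizes below are illustrative, not the script's actual API):

```python
import time
import jax
import jax.numpy as jnp

def time_op(fn, *args, warmup=10, steps=100):
    """Average wall-clock time of a JIT-compiled op, in milliseconds."""
    fn = jax.jit(fn)
    for _ in range(warmup):                      # warmup runs trigger compilation
        fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(steps):
        fn(*args).block_until_ready()            # wait out async dispatch
    return (time.perf_counter() - start) / steps * 1e3

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (4096, 4096), dtype=jnp.float32)
ms = time_op(jnp.dot, a, a)
tflops = 2 * 4096**3 / (ms * 1e-3) / 1e12        # 2*N^3 FLOPs for an (N, N) matmul
print(f"2D matmul: {ms:.3f} ms, {tflops:.1f} TFLOPS")
```

The `-w/--warmup` and `-m/--steps` flags below correspond to the `warmup` and `steps` roles in this pattern.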
How to Run:
# (Basic) Run with all default settings
python3 tpus_benchmark_single-host_workload.py
# (Advanced) Run a lighter workload for quick testing
python3 tpus_benchmark_single-host_workload.py -w 5 -m 500 -mxs 8192 -md 64
# (Advanced) Run a full benchmark on 8 cores and export results to CSV
python3 tpus_benchmark_single-host_workload.py --max_cores 8 --csv results.csv
All Arguments:
- `-w`/`--warmup` (int, default: 10): The number of warmup runs.
- `-m`/`--steps` (int, default: 1000): The number of test iterations to average.
- `-mxs`/`--matrix_size` (int, default: 16384): The size N for (N, N) matrices.
- `-md`/`--matrix_depth` (int, default: 128): The depth D for 3D tensors (D, N, N).
- `-c`/`--conv_size` (int, default: 256): The size of the convolution input.
- `-b`/`--batch_size` (int, default: 32): The batch size.
- `--precision` (str, default: "float32"): Data precision (`float32` or `bfloat16`).
- `--max_cores` (int, default: 0): Maximum number of cores to test (0 = all).
- `--csv` (str, default: None): Filename for CSV output of results.
v4-8 Benchmark Analysis Report
The table below shows the results from running the updated `tpus_benchmark_single-host_workload.py` script on a TPU v4-8. This version includes optimizations for better single-core performance, but multi-core scaling still reveals lingering communication challenges. Results are loaded from `tpus_benchmark_single-host_workload-v4-8.csv`.
| Test | Cores | Avg. Time (ms) | TFLOPS | Bandwidth (GB/s) |
|---|---|---|---|---|
| 2D | 1 | 1.932 | 156.52 | |
| 3D | 1 | 15.564 | 155.41 | |
| 2D_FFT | 1 | 6.320 | 0.319 | |
| 3D_FFT | 1 | 80.055 | 0.302 | |
| Bandwidth | 1 | 0.080 | | 15575.37 |
| 2D | 2 | 3.903 | 154.94 | |
| 3D | 2 | 14.862 | 162.76 | |
| 2D_FFT | 2 | 7.493 | 0.537 | |
| 3D_FFT | 2 | 45.383 | 0.532 | |
| Bandwidth | 2 | 8.134 | | 307.37 |
| 2D | 4 | 5.591 | 216.32 | |
| 3D | 4 | 10.880 | 222.33 | |
| 2D_FFT | 4 | 8.382 | 0.961 | |
| 3D_FFT | 4 | 24.992 | 0.967 | |
| Bandwidth | 4 | 16.474 | | 303.50 |
Analysis:
The v4-8 results show excellent single-core (JIT) performance, especially the internal memory bandwidth, which reaches 15,575 GB/s. However, scaling to multi-core (PMAP) still reveals significant challenges in inter-core communication.
- Compute-Bound Analysis (2D/3D/FFT): Compute performance (TFLOPS) scales up with more cores. For example, 3D MatMul increases from ~155 TFLOPS (1 core) to ~222 TFLOPS (4 cores), and 3D FFT increases from 0.30 TFLOPS to 0.97 TFLOPS. This indicates that the TPU scales compute-bound tasks well.
- Bandwidth Analysis: This is the most problematic area. Bandwidth performance drops dramatically from 15,575 GB/s on 1 core to just 307 GB/s (2 cores) and 303 GB/s (4 cores). This suggests the test is no longer measuring HBM speed but is instead bottlenecked by the Inter-Core Interconnect (ICI) as `pmap` forces data exchange.
Conclusion: While v4-8 delivers excellent single-core compute performance and can scale compute tasks effectively, multi-core data bandwidth under `pmap` remains a significant challenge.
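For context on why the bandwidth number collapses under `pmap`, the two execution modes differ roughly as follows: `jax.jit` keeps the whole working set on one core's HBM, while `jax.pmap` shards the leading axis across cores, and any collective such as `psum` must cross the inter-core interconnect. A minimal sketch of the two modes (illustrative, not the script's exact code):

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()                     # e.g., 4 TPU cores
x = jnp.ones((n, 2048, 2048), dtype=jnp.float32)

# Single-core: the whole computation stays on one device's HBM.
single = jax.jit(lambda a: jnp.matmul(a, a))(x[0])

# Multi-core: one (2048, 2048) shard per core; the psum collective
# forces traffic over the inter-core interconnect (ICI).
multi = jax.pmap(lambda a: jax.lax.psum(jnp.matmul(a, a), "i"),
                 axis_name="i")(x)
```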
Physics Simulation Examples
This project includes several physics simulation scripts to demonstrate the power of JAX (e.g., `jit`, `vmap`, `grad`) in accelerating complex scientific computations.
1. Molecular Dynamics (`molecular_dynamics_jax_single-host_workload.py`)
This script implements a 2D Molecular Dynamics (MD) simulation of a Lennard-Jones fluid, written purely in JAX.
How to Run:
# (Basic) Run with default settings (N=400, 10k steps)
python3 molecular_dynamics_jax_single-host_workload.py
# (Advanced) Run a longer simulation with more particles and lower density
python3 molecular_dynamics_jax_single-host_workload.py --N 800 --rho 0.7 --prod_steps 50000 --eq_steps 20000
All Arguments:
- `--N` (int, default: 400): Number of particles.
- `--rho` (float, default: 0.8): Density.
- `--kT` (float, default: 1.0): Temperature (kT).
- `--dt` (float, default: 1e-3): Time step size.
- `--eq_steps` (int, default: 10000): Number of equilibration steps.
- `--prod_steps` (int, default: 10000): Number of production (simulation) steps.
- `--sample_every` (int, default: 100): Sample the state every N steps.
- `--seed` (int, default: 42): PRNG seed.
- `--output` (str, default: "g_r_plot.png"): Output filename for the g(r) plot.
Physics Concepts:
- Lennard-Jones Potential: Simulates the interaction force between two neutral particles, $V(r) = 4\epsilon\left[(\sigma/r)^{12} - (\sigma/r)^{6}\right]$, featuring a strong short-range repulsion (Pauli exclusion) and a weaker long-range attraction (van der Waals forces).
- Statistical Mechanics: The system is initialized with random positions and velocities (drawn according to the temperature $kT$). It undergoes an "Equilibration" phase to reach a stable state before the "Production" (data-collection) phase begins.
- Periodic Boundary Conditions: Simulates an infinite system by using a finite number of particles in a "box" that wraps around on itself.
- Radial Distribution Function $g(r)$: The final output, $g(r)$, describes the probability of finding another particle at a distance $r$ from a reference particle. Its shape reveals the state of matter (e.g., solid, liquid, gas); a sketch of a histogram-based estimator follows this list.
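As a sketch of how $g(r)$ can be estimated from sampled configurations, here is a generic 2D histogram estimator under minimum-image periodic boundaries (not necessarily the script's exact `calculate_g_r`):

```python
import jax.numpy as jnp

def radial_distribution(positions, box, n_bins=100):
    """Histogram estimate of g(r) for one 2D configuration of n particles."""
    n = positions.shape[0]
    r_max = box / 2.0                                    # half the box, for minimum image
    dr = positions[:, None, :] - positions[None, :, :]   # all pair displacements
    dr = dr - box * jnp.round(dr / box)                  # minimum-image convention
    dist = jnp.sqrt(jnp.sum(dr**2, axis=-1))
    iu = jnp.triu_indices(n, k=1)                        # unique pairs only
    hist, edges = jnp.histogram(dist[iu], bins=n_bins, range=(0.0, r_max))
    r = 0.5 * (edges[1:] + edges[:-1])                   # bin centers
    shell_area = 2.0 * jnp.pi * r * (edges[1] - edges[0])  # 2D ring area per bin
    rho = n / box**2                                     # number density
    # Normalize by the ideal-gas pair count so g(r) -> 1 at large r.
    return r, hist / (shell_area * rho * n / 2.0)
```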
Mathematical & JAX Implementation:
- Lennard-Jones Potential $V(r)$:

  $$V(r) = 4\epsilon\left[\left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^{6}\right]$$

  Here, $\epsilon$ represents the depth of the potential well (the energy scale of attraction), and $\sigma$ is the finite distance at which the inter-particle potential is zero (the size parameter). The $(\sigma/r)^{12}$ term models the steep repulsive force due to overlapping electron clouds, while the $(\sigma/r)^{6}$ term captures the attractive dispersion forces. The total potential energy of the system is the sum over all unique particle pairs: $U = \sum_{i<j} V(r_{ij})$. This is implemented in the `total_energy_fn` function, which computes the sum with vectorized JAX operations to avoid explicit loops.
- Force Calculation (via `jax.grad`): The force on each particle is the negative gradient of the potential energy with respect to its position: $F_i = -\nabla_{r_i} U$. Instead of manually deriving the analytical expression for the force from the Lennard-Jones potential (which involves differentiating each pairwise term), the script leverages JAX's automatic differentiation:

  `force_fn = jit(grad(lambda R: -total_energy_fn(R)))`

  Here, `grad` computes the gradient automatically, treating the positions `R` as the input; negating the energy yields the force. This approach is highly advantageous in JAX for physics simulations: potentials can be modified without re-deriving forces, and the result compiles efficiently on TPUs.
- Integration (Velocity Verlet): To propagate positions and velocities over time, the `verlet_step` function employs the Velocity Verlet algorithm, a second-order symplectic integrator that conserves energy better than simple Euler methods. It alternates velocity and position updates in a leapfrog manner:

  $$v\left(t + \tfrac{\Delta t}{2}\right) = v(t) + \tfrac{\Delta t}{2}\,a(t), \qquad x(t + \Delta t) = x(t) + \Delta t\, v\left(t + \tfrac{\Delta t}{2}\right), \qquad v(t + \Delta t) = v\left(t + \tfrac{\Delta t}{2}\right) + \tfrac{\Delta t}{2}\,a(t + \Delta t)$$

  This method is time-reversible and preserves the symplectic structure of Hamiltonian systems, reducing long-term energy drift in simulations. The acceleration $a = F/m$ is recomputed from the force function after each position update.
- Performance: `jax.jit` is applied to `verlet_step`, `equilibrate_fn`, `production_fn`, and `calculate_g_r`. `jax.lax.fori_loop` is used to compile the entire simulation loop onto the TPU for maximum speed.
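Putting the pieces above together, a condensed sketch of the energy/force/integrator pattern, assuming unit masses and a square periodic box (the names `total_energy_fn`, `force_fn`, and `verlet_step` follow the description above; cutoffs and neighbor lists are omitted):

```python
import jax.numpy as jnp
from jax import grad, jit

def make_md(box, epsilon=1.0, sigma=1.0):
    def total_energy_fn(R):
        """Sum of Lennard-Jones pair energies under periodic boundaries."""
        n = R.shape[0]
        dr = R[:, None, :] - R[None, :, :]
        dr = dr - box * jnp.round(dr / box)        # minimum-image convention
        r2 = jnp.sum(dr**2, axis=-1) + jnp.eye(n)  # pad the diagonal to avoid 0/0
        inv6 = (sigma**2 / r2) ** 3
        v = 4.0 * epsilon * (inv6**2 - inv6)
        return 0.5 * jnp.sum(v * (1.0 - jnp.eye(n)))   # each pair counted once

    # Force as the negative gradient of the potential energy.
    force_fn = jit(grad(lambda R: -total_energy_fn(R)))

    @jit
    def verlet_step(R, V, F, dt):
        """One Velocity Verlet step (masses set to 1)."""
        V_half = V + 0.5 * dt * F
        R_new = jnp.mod(R + dt * V_half, box)      # wrap positions into the box
        F_new = force_fn(R_new)
        V_new = V_half + 0.5 * dt * F_new
        return R_new, V_new, F_new

    return total_energy_fn, force_fn, verlet_step
```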
Example Output: Radial Distribution Function g(r)
This plot shows the computed $g(r)$. The example image provided shows a result without the distinct structure of a stable liquid, which can occur if the simulation parameters (like density `rho` or temperature `kT`) are not set to form a stable liquid structure, or if the production run is too short. A typical liquid would show distinct peaks.
2. N-Body Black Hole Merger (`nbody_bh_merger_sim_single-host_workload.py`)
This script simulates the dynamics of N bodies (e.g., 3 black holes) under their mutual gravity and computes the resulting gravitational waves (GW).
How to Run (Interactive):
# (Basic) Run the script and press Enter to accept all default values
python3 nbody_bh_merger_sim_single-host_workload.py
# -> Number of black holes (default=3): [Enter]
# -> Mass of BH1 (default=30.0): [Enter]
# -> ... (etc.)
# (Advanced) Run the script and input custom values
python3 nbody_bh_merger_sim_single-host_workload.py
# -> Number of black holes: 5
# -> Mass of BH1: 50
# -> ... (etc.)
# -> Compute Lyapunov exponent? (y/n): n
This script is **interactive** and will prompt you for parameters in the console (Number of black holes, Mass, Separation, Velocity, etc.) instead of using command-line arguments.
Physics Concepts:
- Newtonian Gravity (N-Body): Solves the classic N-body problem by calculating the gravitational force between all pairs of objects.
- Gravitational Waves (Approximation): The script computes the GW strain $h_+$ using the quadrupole approximation, summing the contributions from each orbiting pair.
- Chaos Theory: For N > 2, the system is often chaotic. The script can compute the Lyapunov exponent $\lambda$, which measures the exponential rate at which nearby trajectories diverge, quantifying the system's chaos.
Mathematical & JAX Implementation:
- ODE System: The N-body problem is formulated as a first-order system of ordinary differential equations (ODEs). The state vector concatenates all positions and velocities: $y = (\mathbf{r}_1, \dots, \mathbf{r}_N, \mathbf{v}_1, \dots, \mathbf{v}_N)$. The time derivative is $\dot{y} = (\mathbf{v}_1, \dots, \mathbf{v}_N, \mathbf{a}_1, \dots, \mathbf{a}_N)$, where $\mathbf{a}_i$ are the accelerations due to gravity. This structure allows efficient vectorized computation in JAX.
- Gravitational Acceleration $\mathbf{a}_i$:

  $$\mathbf{a}_i = \sum_{j \neq i} \frac{G m_j \left(\mathbf{r}_j - \mathbf{r}_i\right)}{\left|\mathbf{r}_j - \mathbf{r}_i\right|^{3}}$$

  This arises from Newton's law of universal gravitation, $F = G m_i m_j / r^2$, divided by $m_i$ for the acceleration; writing the unit direction vector explicitly turns the inverse-square law into the $1/r^3$ form above. Regularization (e.g., a softening term in the denominator) may be added to avoid singularities at close encounters, implemented in `pairwise_acc` using JAX's vectorized broadcasting over all pairs.
- Numerical Integration (RK4): To solve the ODE system $\dot{y} = f(y)$, the script uses the classical 4th-order Runge-Kutta (RK4) method in `rk4_step`. RK4 approximates the solution at $t + \Delta t$ via four evaluations of $f$ at intermediate points:

  $$k_1 = f(y_n), \quad k_2 = f\left(y_n + \tfrac{\Delta t}{2} k_1\right), \quad k_3 = f\left(y_n + \tfrac{\Delta t}{2} k_2\right), \quad k_4 = f\left(y_n + \Delta t\, k_3\right), \qquad y_{n+1} = y_n + \tfrac{\Delta t}{6}\left(k_1 + 2k_2 + 2k_3 + k_4\right)$$

  This provides a local error of $O(\Delta t^5)$ (global error $O(\Delta t^4)$), making it accurate and stable for non-stiff systems like Newtonian gravity.
- GW Strain $h_+$: In the post-Newtonian quadrupole approximation for inspiraling binaries, the plus-polarization strain is

  $$h_+(t) = \frac{4}{d}\left(\frac{G\mathcal{M}}{c^{2}}\right)^{5/3} \left(\frac{\pi f_{\mathrm{GW}}}{c}\right)^{2/3} \cos\Phi(t)$$

  where $\mathcal{M} = (m_1 m_2)^{3/5}/(m_1 + m_2)^{1/5}$ is the chirp mass, $f_{\mathrm{GW}}$ is the GW frequency (twice the orbital frequency, due to the quadrupole nature of the radiation), $d$ is the distance, and $\Phi(t)$ is the phase, with instantaneous angular frequency taken from Kepler's law. The amplitude scales as $\mathcal{M}^{5/3} f_{\mathrm{GW}}^{2/3}/d$. The script sums contributions over pairs, assuming far-field propagation.
- Lyapunov Exponent $\lambda$:

  $$\lambda = \lim_{t \to \infty} \frac{1}{t} \ln \frac{\lVert \delta(t) \rVert}{\lVert \delta(0) \rVert}$$

  This quantifies sensitivity to initial conditions in chaotic systems. $\delta(0)$ is a small perturbation (e.g., in position), and $\lVert \delta(t) \rVert$ is the Euclidean norm of the deviation after time $t$ between the nominal trajectory $y(t)$ and the perturbed one $\tilde{y}(t)$. A positive exponent ($\lambda > 0$) indicates chaos, with trajectories diverging as $e^{\lambda t}$. The script runs the simulations for $y$ and $\tilde{y}$ in parallel using JAX's `vmap`.
- Performance: The entire simulation loop `simulate_nbody` is JIT-compiled using `jax.jit` and `jax.lax.scan`, allowing the whole trajectory to be computed efficiently on the TPU.
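A condensed sketch of how these pieces typically fit together (the names `pairwise_acc` and `rk4_step` follow the description above; the softening constant, state layout, and the `simulate` wrapper standing in for `simulate_nbody` are illustrative assumptions):

```python
from functools import partial

import jax
import jax.numpy as jnp

G, EPS = 1.0, 1e-3                                  # gravitational constant, softening

def pairwise_acc(r, masses):
    """Gravitational acceleration on every body from all the others."""
    dr = r[None, :, :] - r[:, None, :]              # dr[i, j] = r_j - r_i
    dist3 = (jnp.sum(dr**2, axis=-1) + EPS**2) ** 1.5
    return G * jnp.sum(masses[None, :, None] * dr / dist3[..., None], axis=1)

def deriv(y, masses):
    """State y = (positions, velocities); dy/dt = (velocities, accelerations)."""
    r, v = y
    return (v, pairwise_acc(r, masses))

def rk4_step(y, masses, dt):
    """Classical 4th-order Runge-Kutta step on the (r, v) pytree."""
    add = lambda a, b, s: jax.tree_util.tree_map(lambda u, w: u + s * w, a, b)
    k1 = deriv(y, masses)
    k2 = deriv(add(y, k1, dt / 2), masses)
    k3 = deriv(add(y, k2, dt / 2), masses)
    k4 = deriv(add(y, k3, dt), masses)
    incr = jax.tree_util.tree_map(
        lambda a, b, c, d: (a + 2 * b + 2 * c + d) * dt / 6, k1, k2, k3, k4)
    return add(y, incr, 1.0)

@partial(jax.jit, static_argnums=(4,))
def simulate(r0, v0, masses, dt, n_steps):
    """Whole trajectory compiled as one lax.scan."""
    def body(y, _):
        y = rk4_step(y, masses, dt)
        return y, y
    _, traj = jax.lax.scan(body, (r0, v0), None, length=n_steps)
    return traj
```

To estimate the Lyapunov exponent, the same `simulate` can be run on a nominal and a slightly perturbed initial state (e.g., stacked and mapped with `jax.vmap`), fitting $\lambda$ from the growth of $\ln\lVert\delta(t)\rVert$.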
Example Simulation Outputs:
Particle positions (left) and cumulative GW signal (right) at a specific time.
Standalone gravitational waveform (h+ strain) over time.
3D space-time trajectories of the N-bodies, showing their paths over time.
3. Three Particles in Non-Uniform E/M Field (`three_particles_em_nonuni_single-host_workload.py`)
This is a simple simulation of 3 charged particles moving under their mutual gravity and an external, non-uniform electromagnetic field.
How to Run:
# (Basic) Run with default settings (only gravity and constant B-field)
python3 three_particles_em_nonuni_single-host_workload.py
# (Advanced) Add non-uniform B-field, E-field, and run for more steps
python3 three_particles_em_nonuni_single-host_workload.py --Bz 5.0 --Bk 0.5 --Ex 0.2 --n_steps 2000
All Arguments:
- `--dt` (float, default: 0.01): Time step size.
- `--n_steps` (int, default: 1000): Total number of steps to simulate.
- `--G` (float, default: 1.0): Gravitational constant.
- `--Bz` (float, default: 1.0): Constant component of the magnetic field (Z-axis).
- `--Bk` (float, default: 0.0): Gradient of the magnetic field along x (Bz = Bz + Bk*x).
- `--Ex` (float, default: 0.0): Electric field strength (X-axis).
- `--Ey` (float, default: 0.0): Electric field strength (Y-axis).
Physics Concepts:
- Superposition of Forces: The net force on each particle is the vector sum of gravity from the other particles and the Lorentz force from the E/M field: $\mathbf{F}_i = \mathbf{F}_{\mathrm{grav},i} + q_i\left(\mathbf{E} + \mathbf{v}_i \times \mathbf{B}\right)$.
- Lorentz Force: The force from the E/M field is $\mathbf{F} = q\left(\mathbf{E} + \mathbf{v} \times \mathbf{B}\right)$.
- Non-Uniform Field: The magnetic field is position-dependent ($B_z(x) = B_{z0} + B_k x$, per the `--Bz` and `--Bk` flags), making the dynamics more complex than simple circular or helical motion.
Mathematical & JAX Implementation:
- Total Acceleration $\mathbf{a}_i$: Newton's second law gives $\mathbf{a}_i = \mathbf{F}_i / m_i$. The `acceleration` function decomposes this into three additive components, computed in parallel:
- Gravitational ($\mathbf{a}_{\mathrm{grav}}$):

  $$\mathbf{a}_{\mathrm{grav},i} = \sum_{j \neq i} \frac{G m_j \left(\mathbf{r}_j - \mathbf{r}_i\right)}{\left|\mathbf{r}_j - \mathbf{r}_i\right|^{3}}$$

  Identical to the N-body case, this captures pairwise inverse-square attractions (in `pairwise_acc`). For small N = 3, it is computed directly without approximation.
- Electric ($\mathbf{a}_E = q\mathbf{E}/m$): A constant acceleration when $\mathbf{E}$ is uniform (e.g., a constant field); implemented in `elec_acc` as a simple scaling.
- Magnetic ($\mathbf{a}_B = q\left(\mathbf{v} \times \mathbf{B}\right)/m$): Velocity-dependent, causing deflection perpendicular to $\mathbf{v}$ and $\mathbf{B}$. It does no work ($\mathbf{F}_B \cdot \mathbf{v} = 0$), so a purely magnetic field preserves kinetic energy.
- Cross Product in 2D: The simulation is confined to the xy-plane with $\mathbf{B} = \left(0, 0, B_z(x)\right)$ (out-of-plane, position-dependent) and $\mathbf{v} = \left(v_x, v_y, 0\right)$. The cross product simplifies via the right-hand rule:

  $$\mathbf{v} \times \mathbf{B} = \left(v_y B_z,\; -v_x B_z,\; 0\right)$$

  Thus $\mathbf{a}_B = \frac{q}{m}\left(v_y B_z, -v_x B_z\right)$, inducing cyclotron motion with Larmor radius $r_L = m v_\perp / \left(|q| B_z\right)$. The field's non-uniformity introduces gradients, leading to drifts such as $\mathbf{E} \times \mathbf{B}$ and $\nabla B$ drifts. This matches the JAX code in `mag_acc`: `qm * jnp.array([vy * bz, -vx * bz])`, where `qm = q/m`.
- Integration (Leapfrog/Velocity Verlet): As in the MD simulation, the `step` function applies Velocity Verlet to the combined gravitational + Lorentz forces. This maintains second-order accuracy and symplectic behavior for the conservative gravitational part; the velocity-dependent magnetic term does no work, but it makes a straightforward Verlet update only approximately symplectic.
- Performance: `jax.vmap` is used to calculate the acceleration for all particles in parallel, and `jax.jit` compiles the entire `step` function.
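A minimal sketch of the per-particle acceleration described above (the names `pairwise_acc` and `acceleration` follow the description; `em_acc`, the parameter names, and the 2D state layout are illustrative assumptions):

```python
import jax
import jax.numpy as jnp

G = 1.0

def pairwise_acc(r, masses, eps=1e-6):
    """Pairwise inverse-square gravitational acceleration in 2D."""
    dr = r[None, :, :] - r[:, None, :]
    dist3 = (jnp.sum(dr**2, axis=-1) + eps) ** 1.5
    return G * jnp.sum(masses[None, :, None] * dr / dist3[..., None], axis=1)

def em_acc(pos, vel, q_over_m, Bz0, Bk, Ex, Ey):
    """Electric plus magnetic acceleration for a single particle."""
    bz = Bz0 + Bk * pos[0]                          # non-uniform field: Bz(x) = Bz0 + Bk*x
    elec = q_over_m * jnp.array([Ex, Ey])
    mag = q_over_m * jnp.array([vel[1] * bz, -vel[0] * bz])  # q/m * (v x B), in-plane
    return elec + mag

def acceleration(r, v, masses, q_over_m, Bz0, Bk, Ex, Ey):
    """Total acceleration: gravity plus Lorentz, vmapped over the particles."""
    grav = pairwise_acc(r, masses)
    lorentz = jax.vmap(em_acc, in_axes=(0, 0, 0, None, None, None, None))(
        r, v, q_over_m, Bz0, Bk, Ex, Ey)
    return grav + lorentz
```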
Example Initial State:
Initial positions of the three particles in the 2D plane.
4. Variational & Diffusion Monte Carlo (`vmc_dmc_jax_quantum_harmonic_oscillator.py`)
This script implements Variational Monte Carlo (VMC) followed by Diffusion Monte Carlo (DMC) to approximate the ground-state energy and wavefunction of a D-dimensional isotropic quantum harmonic oscillator, using JAX for high-performance computation. It optimizes a variational parameter $\alpha$ via stochastic gradient descent in VMC, then refines the distribution via branching diffusion in DMC. Outputs include energy convergence plots, marginal probability density histograms vs. the exact ground state, and optional GIF animations of walker distributions.
How to Run:
# (Basic) Run with default settings (N=10k walkers, 3k VMC epochs, 500 DMC steps, 3D)
python3 vmc_dmc_jax_quantum_harmonic_oscillator.py
# (Advanced) Run a longer simulation in 1D with more epochs and steps
python3 vmc_dmc_jax_quantum_harmonic_oscillator.py --n_epochs 5000 --n_dmc 1000 --dim 1
All Arguments:
- `--n_walkers` (int, default: 10000): Number of Monte Carlo walkers.
- `--n_epochs` (int, default: 3000): Number of VMC optimization epochs.
- `--n_equil` (int, default: 100): Number of equilibration steps per VMC epoch.
- `--step_size` (float, default: 2.0): Proposal step size for Metropolis-Hastings sampling.
- `--lr` (float, default: 0.02): Learning rate for the Adam optimizer in VMC.
- `--n_dmc` (int, default: 500): Number of DMC propagation steps.
- `--dmc_dt` (float, default: 0.01): Time step size for DMC diffusion/branching.
- `--dim` (int, default: 3): Spatial dimensionality of the harmonic oscillator.
- `--no-gif` (flag): Disable generation of the VMC and DMC animation GIFs (`vmc_animation.gif` and `dmc_animation.gif`).
- `--no-plot` (flag): Disable display of the final matplotlib plots (energy convergence, histograms).
Physics Concepts:
- Quantum Harmonic Oscillator (QHO): The D-dimensional isotropic QHO has the Hamiltonian $H = -\tfrac{1}{2}\nabla^2 + \tfrac{1}{2}\lvert\mathbf{x}\rvert^2$ (setting $\hbar = m = \omega = 1$), with exact ground-state energy $E_0 = D/2$ and Gaussian wavefunction $\psi_0(\mathbf{x}) \propto e^{-\lvert\mathbf{x}\rvert^2/2}$. The script approximates this via stochastic sampling of the wavefunction.
- Variational Monte Carlo (VMC): Uses a trial wavefunction $\psi_\alpha(\mathbf{x}) = e^{-\alpha \lvert\mathbf{x}\rvert^2}$ to minimize the energy expectation value $E(\alpha) = \langle \psi_\alpha \vert H \vert \psi_\alpha \rangle / \langle \psi_\alpha \vert \psi_\alpha \rangle$ via Monte Carlo integration over samples from $\lvert\psi_\alpha\rvert^2$. The parameter $\alpha$ is optimized using stochastic gradients.
- Diffusion Monte Carlo (DMC): Projects onto the ground state via imaginary-time evolution of the Schrödinger equation, interpreted as a branching random walk with drift (from the quantum force) and diffusion (from the kinetic term), biased by local-energy branching.
- Walker Sampling: "Walkers" are Monte Carlo samples representing the probability distribution. Metropolis-Hastings ensures ergodic sampling in VMC, while DMC uses resampling based on weights to focus on low-energy configurations.
Mathematical & JAX Implementation:
- Local Energy $E_L$: For a trial wavefunction $\psi$, the local energy is $E_L(\mathbf{x}) = \frac{H\psi(\mathbf{x})}{\psi(\mathbf{x})} = -\tfrac{1}{2}\frac{\nabla^2\psi}{\psi} + V(\mathbf{x})$, where $V(\mathbf{x}) = \tfrac{1}{2}\lvert\mathbf{x}\rvert^2$ is the potential. For the Gaussian trial, $\nabla\psi_\alpha/\psi_\alpha = -2\alpha\mathbf{x}$, so $\nabla^2\psi_\alpha/\psi_\alpha = -2\alpha D + 4\alpha^2\lvert\mathbf{x}\rvert^2$ and $E_L(\mathbf{x}) = \alpha D + \left(\tfrac{1}{2} - 2\alpha^2\right)\lvert\mathbf{x}\rvert^2$. The VMC energy is the Monte Carlo average $E(\alpha) \approx \tfrac{1}{N}\sum_i E_L(\mathbf{x}_i)$. For $\alpha = \tfrac{1}{2}$, $E_L = D/2$ exactly. Implemented in `local_energy` using JAX's `grad` for the Laplacian and gradient terms.
- Metropolis-Hastings Sampling (VMC): Samples from $\lvert\psi_\alpha\rvert^2$ using a symmetric random-walk proposal scaled by `step_size`, with acceptance probability $A = \min\left(1, \lvert\psi_\alpha(\mathbf{x}')\rvert^2 / \lvert\psi_\alpha(\mathbf{x})\rvert^2\right)$. Equilibration discards the initial steps to ensure sampling from the target distribution. The update is vectorized over walkers via `vmap(metropolis_step)` and looped with `fori_loop` for efficiency. The variational gradient is $\partial_\alpha E = 2\left(\langle E_L\, \partial_\alpha \ln\psi_\alpha \rangle - \langle E_L \rangle \langle \partial_\alpha \ln\psi_\alpha \rangle\right)$, optimized with Adam (`optax.adam`).
- DMC Propagation: Solves the imaginary-time Schrödinger equation $-\partial_\tau \Phi = (H - E_T)\Phi$ (with reference energy $E_T$) via a short-time Green's-function approximation: drift-diffusion with drift velocity $\mathbf{v}_D = \nabla\ln\psi_\alpha$ (the quantum force) and branching weight $w = e^{-\Delta\tau\,\left(E_L - E_T\right)}$. Walkers are resampled via multinomial choice based on normalized weights, then updated as $\mathbf{x}' = \mathbf{x} + \mathbf{v}_D\,\Delta\tau + \sqrt{\Delta\tau}\,\boldsymbol{\xi}$ ($\boldsymbol{\xi} \sim \mathcal{N}(0, I)$). The mixed estimator for the energy is the average of $E_L$ over steps, with a statistical error from its variance. Implemented with `jax.lax.scan` for the full trajectory, vectorized via `vmap` for local energy and drift.
- Performance: All core functions (`local_energy`, `metropolis_step`, `dmc_step_body`) are JIT-compiled. `vmap` parallelizes over walkers (up to 10k), and `fori_loop`/`scan` compile loops for TPU acceleration. Progress tracking uses `rich` for console output, with optional GIFs via `imageio` for visualizing walker histograms evolving toward the exact Gaussian marginal.
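A compact sketch of the VMC kernel (the names `local_energy` and `metropolis_step` follow the description above; `log_psi`, `vmc_energy`, the Gaussian proposal, and the autodiff Laplacian via $\nabla^2\psi/\psi = \nabla^2\ln\psi + \lvert\nabla\ln\psi\rvert^2$ are illustrative assumptions):

```python
from functools import partial

import jax
import jax.numpy as jnp

def log_psi(alpha, x):
    """Log of the Gaussian trial wavefunction psi_alpha(x) = exp(-alpha*|x|^2)."""
    return -alpha * jnp.sum(x**2)

def local_energy(alpha, x):
    """E_L = -(1/2) (lap psi)/psi + V, via autodiff of log psi."""
    g = jax.grad(log_psi, argnums=1)(alpha, x)
    lap = jnp.trace(jax.hessian(log_psi, argnums=1)(alpha, x))
    kinetic = -0.5 * (lap + jnp.sum(g**2))   # (lap psi)/psi = lap(log psi) + |grad log psi|^2
    return kinetic + 0.5 * jnp.sum(x**2)

def metropolis_step(alpha, x, key, step_size=2.0):
    """One symmetric random-walk Metropolis update targeting |psi_alpha|^2."""
    k1, k2 = jax.random.split(key)
    x_new = x + step_size * jax.random.normal(k1, x.shape)
    log_ratio = 2.0 * (log_psi(alpha, x_new) - log_psi(alpha, x))
    accept = jnp.log(jax.random.uniform(k2)) < log_ratio
    return jnp.where(accept, x_new, x)

@partial(jax.jit, static_argnums=(3,))
def vmc_energy(alpha, walkers, key, n_equil=100):
    """Equilibrate all walkers, then average the local energy."""
    def body(_, state):
        xs, key = state
        key, sub = jax.random.split(key)
        keys = jax.random.split(sub, xs.shape[0])
        xs = jax.vmap(metropolis_step, in_axes=(None, 0, 0))(alpha, xs, keys)
        return xs, key
    walkers, _ = jax.lax.fori_loop(0, n_equil, body, (walkers, key))
    return jnp.mean(jax.vmap(local_energy, in_axes=(None, 0))(alpha, walkers))
```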
Utility Scripts
The project includes utility scripts in the `utils/` folder:
- `check_deps.py` - Checks whether all required libraries (`jax`, `rich`, `psutil`) are installed.
- `jax_devices.py` - Lists all JAX devices visible to the system (e.g., TPU, CPU).
- `plt.py` - Uses Matplotlib and Pandas to create a summary plot of TFLOPS and average time from benchmark results.