Solver Comparison

This benchmark compares ODE solver performance across several Python backends, Julia, and MATLAB.

Test Configuration

  • ODE: Driven damped pendulum
  • t_span: (0, 1000)
  • Tolerances: rtol=1e-8, atol=1e-6
  • Sample sizes: 100 / 200 / 500 / 1,000 / 2,000 / 5,000 / 10,000 / 20,000 / 50,000 / 100,000 initial conditions
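The benchmark does not list the pendulum parameters, so the sketch below uses hypothetical values for the damping, drive amplitude, and drive frequency; the `t_span` and tolerances match the configuration above. It integrates a single initial condition with scipy's `solve_ivp` (the scipy baseline from the results table):

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical parameters: damping, drive amplitude, drive frequency.
# The benchmark does not specify these values.
gamma, A, omega = 0.1, 1.2, 0.5

def pendulum(t, y):
    """Driven damped pendulum: theta'' + gamma*theta' + sin(theta) = A*cos(omega*t)."""
    theta, v = y
    return [v, -gamma * v - np.sin(theta) + A * np.cos(omega * t)]

# One initial condition with the benchmark's t_span and tolerances.
sol = solve_ivp(pendulum, (0, 1000), [0.1, 0.0],
                method="RK45", rtol=1e-8, atol=1e-6)
```

In the benchmark each solver integrates a whole batch of such initial conditions; the batched backends (JAX/Diffrax, torchdiffeq, torchode) vectorize this loop on CPU or GPU.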

Solvers Tested

| Solver       | Backend | Devices   | Method              |
| ------------ | ------- | --------- | ------------------- |
| MATLAB ode45 | MATLAB  | CPU       | Dormand-Prince 5(4) |
| JAX/Diffrax  | JAX     | CPU, CUDA | Dormand-Prince 5(4) |
| torchdiffeq  | PyTorch | CPU, CUDA | Dormand-Prince 5(4) |
| torchode     | PyTorch | CUDA      | Dormand-Prince 5(4) |

torchode Performance Issues

torchode was excluded from the CPU benchmarks due to severe performance issues observed in previous runs. It also degrades sharply on CUDA at large N (444.7s at N=100k in the current run, versus ~11.5s for JAX/Diffrax; earlier runs were as slow as ~1133s), indicating that it is not well optimized for GPU batch processing in this use case.

Results by Sample Size

| Solver             | N = 100          | N = 200          | N = 500          | N = 1,000        | N = 2,000        | N = 5,000         | N = 10,000         | N = 20,000         | N = 50,000          | N = 100,000         |
| ------------------ | ---------------- | ---------------- | ---------------- | ---------------- | ---------------- | ----------------- | ------------------ | ------------------ | ------------------- | ------------------- |
| MATLAB ode45       | 0.198s           | 0.287s           | 0.604s           | 1.108s           | 2.117s           | 5.827s            | 11.637s            | 20.567s            | 51.712s             | 102.258s            |
| JAX (CPU)          | 0.274s (0.721×)  | 0.336s (0.853×)  | 0.495s (1.220×)  | 0.894s (1.239×)  | 1.752s (1.208×)  | 4.769s (1.222×)   | 9.777s (1.190×)    | 20.615s (0.998×)   | 27.879s (1.855×)    | 64.133s (1.594×)    |
| JAX (CUDA)         | 11.195s (0.018×) | 11.143s (0.026×) | 11.217s (0.054×) | 11.171s (0.099×) | 11.191s (0.189×) | 11.252s (0.518×)  | 11.194s (1.040×)   | 11.032s (1.864×)   | 10.874s (4.755×)    | 11.554s (8.850×)    |
| torchdiffeq (CPU)  | 2.026s (0.098×)  | 2.212s (0.130×)  | 2.496s (0.242×)  | 2.920s (0.379×)  | 3.879s (0.546×)  | 8.122s (0.717×)   | 10.267s (1.133×)   | 15.716s (1.309×)   | 22.785s (2.270×)    | 36.376s (2.811×)    |
| torchdiffeq (CUDA) | 31.878s (0.006×) | 31.809s (0.009×) | 32.471s (0.019×) | 32.051s (0.035×) | 32.713s (0.065×) | 32.686s (0.178×)  | 32.268s (0.361×)   | 32.383s (0.635×)   | 30.692s (1.685×)    | 30.143s (3.392×)    |
| torchode (CUDA)    | 27.637s (0.007×) | 27.954s (0.010×) | 28.933s (0.021×) | 28.683s (0.039×) | 27.700s (0.076×) | 29.733s (0.196×)  | 29.729s (0.391×)   | 33.206s (0.619×)   | 40.019s (1.292×)    | 444.743s (0.230×)   |
| Julia (CPU)        | 0.060s (3.301×)  | 0.056s (5.094×)  | 0.099s (6.088×)  | 0.153s (7.220×)  | 0.272s (7.790×)  | 0.654s (8.905×)   | 1.326s (8.775×)    | 2.606s (7.894×)    | 6.359s (8.132×)     | 12.365s (8.270×)    |
| Julia (GPU)        | 0.051s (3.862×)  | 0.020s (14.593×) | 0.023s (26.466×) | 0.028s (39.038×) | 0.033s (64.815×) | 0.049s (118.782×) | 0.104s (112.124×)  | 0.192s (107.015×)  | 0.384s (134.542×)   | 0.741s (138.028×)   |
| scipy (CPU)        | 0.753s (0.262×)  | 1.342s (0.214×)  | 3.024s (0.200×)  | 6.011s (0.184×)  | 12.717s (0.166×) | 32.411s (0.180×)  | 64.495s (0.180×)   | —                  | —                   | —                   |

Parenthesized values are speedups relative to MATLAB ode45 at the same N. scipy was not run beyond N = 10,000.
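The parenthesized multipliers in the table are straightforward ratios against the MATLAB ode45 baseline. A quick check against the N = 100,000 column:

```python
# Speedup relative to MATLAB ode45: speedup = t_matlab / t_solver.
# Times taken from the N = 100,000 column of the results table.
t_matlab = 102.258
times = {
    "JAX (CPU)": 64.133,
    "JAX (CUDA)": 11.554,
    "torchdiffeq (CPU)": 36.376,
}
speedups = {name: round(t_matlab / t, 3) for name, t in times.items()}
# Reproduces the table's 1.594x, 8.850x, and 2.811x entries.
```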

Scaling Plot

*(Figure: Solver Scaling: wall-clock time versus number of initial conditions N for each solver and device.)*

Key Findings

  1. JAX/Diffrax (CPU) is the fastest Python-backend option at small to medium N (up to roughly 10k samples)
  2. JAX/Diffrax (CUDA) runs in near-constant time (~11–11.5s) regardless of N, making it 8.9× faster than MATLAB at N = 100k
  3. torchdiffeq scales reasonably well on both CPU and CUDA
  4. Among the Python backends, GPU acceleration only pays off at large sample sizes: JAX/Diffrax (CUDA) first overtakes MATLAB near N = 10k and reaches ~4.8× at N = 50k
  5. At small N, fixed GPU overhead makes the CPU solvers faster

Recommendations

JAX/Diffrax should be the default solver for pybasin. When a GPU is available, it delivers near-constant integration time regardless of sample size, making it the clear choice for large workloads; at small N the CPU backend remains faster.

Additionally, JAX/Diffrax is the only solver that supports event-based termination with individual trajectory stopping. This is critical for systems with unbounded trajectories (e.g., Lorenz "broken butterfly"), where some initial conditions diverge to infinity. With JAX events, each trajectory stops independently when it exceeds a threshold, while bounded trajectories continue integrating. Other solvers either stop all trajectories simultaneously or require workarounds like zero masking. See the Handling Unbounded Trajectories guide for details.

For CPU-only systems, the choice depends on scale. At smaller sample sizes (N ≤ 10k), JAX/Diffrax on CPU is the fastest option. At larger scales (N = 100k), however, torchdiffeq on CPU offers a meaningful ~1.8× improvement over JAX/Diffrax CPU (36s vs 64s). While this difference is negligible for single runs, it becomes significant for parameter studies that evaluate many ODE configurations: a study testing 50 parameter combinations at N = 100k would save roughly 23 minutes by using torchdiffeq instead of JAX/Diffrax on CPU, though JAX/Diffrax on GPU would complete the same workload in about 10 minutes.
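The parameter-study estimate follows directly from the N = 100k timings in the results table; a quick sanity check:

```python
# Back-of-envelope check of the parameter-study estimate,
# using the N = 100,000 timings from the results table (seconds per run).
runs = 50
jax_cpu, torchdiffeq_cpu, jax_gpu = 64.133, 36.376, 11.554

saved_min = runs * (jax_cpu - torchdiffeq_cpu) / 60  # CPU savings: torchdiffeq vs JAX
gpu_total_min = runs * jax_gpu / 60                  # full study on JAX/Diffrax GPU
# saved_min is ~23 minutes; gpu_total_min is ~10 minutes.
```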

Hardware

Benchmarks run on:

  • CPU: Intel Core Ultra 9 275HX
  • GPU: NVIDIA GeForce RTX 5070 Ti Laptop GPU (12 GB VRAM)