# Solver Comparison
This benchmark compares ODE solver performance across Python backends, Julia, and MATLAB.
## Test Configuration
- ODE: Driven damped pendulum
- t_span: (0, 1000)
- Tolerances: rtol=1e-8, atol=1e-6
- Sample sizes: 100 / 200 / 500 / 1,000 / 2,000 / 5,000 / 10,000 / 20,000 / 50,000 / 100,000 initial conditions
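The configuration above can be sketched with SciPy's `solve_ivp` for a single initial condition. The pendulum parameters (`GAMMA`, `A`, `OMEGA`) are illustrative assumptions; the benchmark does not state its exact forcing terms.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Driven damped pendulum: theta'' + GAMMA*theta' + sin(theta) = A*cos(OMEGA*t)
# Parameter values are illustrative assumptions, not the benchmark's exact ones.
GAMMA, A, OMEGA = 0.1, 1.2, 2.0 / 3.0

def pendulum(t, y):
    theta, v = y
    return [v, -GAMMA * v - np.sin(theta) + A * np.cos(OMEGA * t)]

# One initial condition with the benchmark's t_span and tolerances.
sol = solve_ivp(pendulum, (0, 1000), [0.5, 0.0],
                method="RK45", rtol=1e-8, atol=1e-6)
print(sol.success, sol.y.shape[0])  # 2 state variables (theta, theta')
```

The batched benchmarks repeat this integration over many initial conditions at once.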
## Solvers Tested
| Solver | Backend | Devices | Method |
|---|---|---|---|
| MATLAB ode45 | MATLAB | CPU | Dormand-Prince 5(4) |
| JAX/Diffrax | JAX | CPU, CUDA | Dormand-Prince 5(4) |
| torchdiffeq | PyTorch | CPU, CUDA | Dormand-Prince 5(4) |
| torchode | PyTorch | CUDA | Dormand-Prince 5(4) |

Julia (CPU and GPU) and SciPy's `solve_ivp` (RK45, also Dormand-Prince 5(4)) were benchmarked as well; their timings appear in the results table below.
## torchode Performance Issues
torchode was excluded from the CPU benchmarks after severe performance problems in earlier runs. It also degrades sharply at large N on CUDA (~445s at N = 100,000 versus ~11.6s for JAX/Diffrax), indicating it is not well optimized for GPU batch processing in this use case.
## Results by Sample Size

Times are wall-clock seconds; values in parentheses are speedup relative to MATLAB ode45 (MATLAB time ÷ solver time). A dash means the configuration was not run.
| Solver | N = 100 | N = 200 | N = 500 | N = 1,000 | N = 2,000 | N = 5,000 | N = 10,000 | N = 20,000 | N = 50,000 | N = 100,000 |
|---|---|---|---|---|---|---|---|---|---|---|
| MATLAB ode45 | 0.198s | 0.287s | 0.604s | 1.108s | 2.117s | 5.827s | 11.637s | 20.567s | 51.712s | 102.258s |
| JAX (CPU) | 0.274s (0.721×) | 0.336s (0.853×) | 0.495s (1.220×) | 0.894s (1.239×) | 1.752s (1.208×) | 4.769s (1.222×) | 9.777s (1.190×) | 20.615s (0.998×) | 27.879s (1.855×) | 64.133s (1.594×) |
| JAX (CUDA) | 11.195s (0.018×) | 11.143s (0.026×) | 11.217s (0.054×) | 11.171s (0.099×) | 11.191s (0.189×) | 11.252s (0.518×) | 11.194s (1.040×) | 11.032s (1.864×) | 10.874s (4.755×) | 11.554s (8.850×) |
| torchdiffeq (CPU) | 2.026s (0.098×) | 2.212s (0.130×) | 2.496s (0.242×) | 2.920s (0.379×) | 3.879s (0.546×) | 8.122s (0.717×) | 10.267s (1.133×) | 15.716s (1.309×) | 22.785s (2.270×) | 36.376s (2.811×) |
| torchdiffeq (CUDA) | 31.878s (0.006×) | 31.809s (0.009×) | 32.471s (0.019×) | 32.051s (0.035×) | 32.713s (0.065×) | 32.686s (0.178×) | 32.268s (0.361×) | 32.383s (0.635×) | 30.692s (1.685×) | 30.143s (3.392×) |
| torchode (CUDA) | 27.637s (0.007×) | 27.954s (0.010×) | 28.933s (0.021×) | 28.683s (0.039×) | 27.700s (0.076×) | 29.733s (0.196×) | 29.729s (0.391×) | 33.206s (0.619×) | 40.019s (1.292×) | 444.743s (0.230×) |
| Julia (CPU) | 0.060s (3.301×) | 0.056s (5.094×) | 0.099s (6.088×) | 0.153s (7.220×) | 0.272s (7.790×) | 0.654s (8.905×) | 1.326s (8.775×) | 2.606s (7.894×) | 6.359s (8.132×) | 12.365s (8.270×) |
| Julia (GPU) | 0.051s (3.862×) | 0.020s (14.593×) | 0.023s (26.466×) | 0.028s (39.038×) | 0.033s (64.815×) | 0.049s (118.782×) | 0.104s (112.124×) | 0.192s (107.015×) | 0.384s (134.542×) | 0.741s (138.028×) |
| scipy (CPU) | 0.753s (0.262×) | 1.342s (0.214×) | 3.024s (0.200×) | 6.011s (0.184×) | 12.717s (0.166×) | 32.411s (0.180×) | 64.495s (0.180×) | — | — | — |
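The parenthesized speedups can be reproduced directly from the timing columns. For example, at N = 100,000:

```python
# Speedup = MATLAB ode45 time / solver time, at N = 100,000 (seconds).
matlab = 102.258
jax_cpu = 64.133
jax_cuda = 11.554

print(round(matlab / jax_cpu, 3))   # 1.594, as in the table
print(round(matlab / jax_cuda, 3))  # 8.85 (table shows 8.850)
```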
## Scaling Plot

## Key Findings
- Among the Python backends, JAX/Diffrax (CPU) is the fastest option at small to medium N (up to roughly 10k samples)
- JAX/Diffrax (CUDA) runs in near-constant time (~11s) regardless of N, making it 8.9× faster than MATLAB ode45 at N = 100,000
- Julia is the fastest backend overall; its GPU runs lead at every N and finish N = 100,000 in under a second
- torchdiffeq scales reasonably well on both CPU and CUDA
- For the Python solvers, GPU acceleration pays off only at large sample sizes (roughly N ≥ 20,000 for JAX/Diffrax)
- At small N, kernel-launch and dispatch overhead makes the CPU solvers faster
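The near-constant CUDA times come from batching: one compiled program advances every trajectory in lock-step, so per-step dispatch overhead is paid once per step rather than once per trajectory. A minimal sketch of the idea with `jax.vmap`, using a fixed-step RK4 for brevity (the benchmark itself uses adaptive Dormand-Prince via Diffrax, and the pendulum parameters here are assumptions):

```python
import jax
import jax.numpy as jnp

# Driven damped pendulum; parameter values are illustrative assumptions.
GAMMA, A, OMEGA = 0.1, 1.2, 2.0 / 3.0

def rhs(t, y):
    theta, v = y
    return jnp.array([v, -GAMMA * v - jnp.sin(theta) + A * jnp.cos(OMEGA * t)])

def integrate(y0, t0=0.0, t1=10.0, n_steps=1000):
    """Fixed-step RK4 from t0 to t1 for a single initial condition."""
    dt = (t1 - t0) / n_steps
    def step(carry, _):
        t, y = carry
        k1 = rhs(t, y)
        k2 = rhs(t + dt / 2, y + dt / 2 * k1)
        k3 = rhs(t + dt / 2, y + dt / 2 * k2)
        k4 = rhs(t + dt, y + dt * k3)
        return (t + dt, y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)), None
    (_, y_final), _ = jax.lax.scan(step, (t0, y0), None, length=n_steps)
    return y_final

# One jitted, vmapped call integrates every initial condition in parallel,
# which is why GPU wall time stays nearly flat as the batch grows.
batch_integrate = jax.jit(jax.vmap(integrate))
y0s = jnp.stack([jnp.array([0.5, 0.0]), jnp.array([1.0, 0.0])])
finals = batch_integrate(y0s)
print(finals.shape)  # (2, 2): two trajectories, two state variables
```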
## Recommendations
JAX/Diffrax should be the default solver for pybasin. When a GPU is available, it delivers the best performance of the Python backends, with near-constant integration time regardless of sample size.
Additionally, JAX/Diffrax is the only solver that supports event-based termination with individual trajectory stopping. This is critical for systems with unbounded trajectories (e.g., Lorenz "broken butterfly"), where some initial conditions diverge to infinity. With JAX events, each trajectory stops independently when it exceeds a threshold, while bounded trajectories continue integrating. Other solvers either stop all trajectories simultaneously or require workarounds like zero masking. See the Handling Unbounded Trajectories guide for details.
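For comparison, the per-trajectory stopping concept looks like this with SciPy's terminal events (SciPy integrates one initial condition at a time, so each trajectory naturally stops on its own, just without batching). The system and threshold here are illustrative stand-ins, not the Lorenz example:

```python
import numpy as np
from scipy.integrate import solve_ivp

THRESHOLD = 100.0  # illustrative escape threshold

def rhs(t, y):
    return y  # exponential growth: every y0 > 0 eventually "escapes"

def escaped(t, y):
    return np.abs(y[0]) - THRESHOLD

escaped.terminal = True  # stop this trajectory when the event crosses zero

for y0 in ([0.1], [1.0]):
    sol = solve_ivp(rhs, (0, 50), y0, events=escaped, rtol=1e-8, atol=1e-6)
    # status == 1 means integration ended at the event, not at t = 50
    print(y0, sol.status, round(float(sol.t[-1]), 2))
```

With JAX/Diffrax the same condition is expressed as a batched event function, so independent stopping carries over to vmapped GPU runs.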
For CPU-only systems, the choice depends on scale. At smaller sample sizes (N ≤ 10k), JAX/Diffrax on CPU is the fastest Python option. At larger scales (N = 100,000), however, torchdiffeq on CPU is roughly 1.8× faster than JAX/Diffrax on CPU (36s vs 64s). While this difference is negligible for single runs, it matters for parameter studies that evaluate many ODE configurations: a study testing 50 parameter combinations at N = 100,000 would save roughly 23 minutes by using torchdiffeq instead of JAX/Diffrax on CPU, though JAX/Diffrax on GPU would complete the same workload in about 10 minutes.
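The parameter-study arithmetic, using the N = 100,000 timings from the results table:

```python
# 50 parameter combinations at N = 100,000; per-run times in seconds.
runs = 50
jax_cpu, torchdiffeq_cpu, jax_gpu = 64.133, 36.376, 11.554

saved_min = runs * (jax_cpu - torchdiffeq_cpu) / 60
gpu_total_min = runs * jax_gpu / 60
print(round(saved_min, 1))      # ~23 minutes saved by torchdiffeq on CPU
print(round(gpu_total_min, 1))  # ~10 minutes total on GPU
```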
## Hardware
Benchmarks run on:
- CPU: Intel Core Ultra 9 275HX
- GPU: NVIDIA GeForce RTX 5070 Ti Laptop GPU (12 GB VRAM)