Solvers

Solvers numerically integrate ODE systems from batches of initial conditions. Every solver in pybasin conforms to SolverProtocol, accepts PyTorch tensors as input, returns PyTorch tensors as output, and supports persistent disk caching -- regardless of the underlying numerical backend.

Unified Tensor Interface

All solvers accept torch.Tensor inputs and return torch.Tensor outputs. Internal conversions to JAX arrays, NumPy arrays, or other formats happen transparently. You do not need to handle backend-specific tensor types.

Available Solvers

| Class | Backend | CPU | GPU (CUDA) | Event Functions | Recommended For |
| --- | --- | --- | --- | --- | --- |
| TorchDiffEqSolver | torchdiffeq | Yes | Yes | No | Default solver -- works out of the box |
| JaxSolver | JAX/Diffrax | Yes | Yes | Yes | Fastest on GPU (requires pybasin[jax]) |
| TorchOdeSolver | torchode | Yes | Yes | No | Independent per-trajectory step sizes (requires pybasin[torchode]) |
| ScipyParallelSolver | scipy/joblib | Yes | No | No | Debugging, reference baselines |

ODE System Pairing

Each solver backend expects a specific ODE system base class:

| Solver | ODE System Class | Why |
| --- | --- | --- |
| JaxSolver | JaxODESystem | Uses pure JAX operations for JIT compilation and jax.vmap batching |
| TorchDiffEqSolver | ODESystem | Wraps a torch.nn.Module with standard PyTorch tensor operations |
| TorchOdeSolver | ODESystem | Same PyTorch interface as TorchDiffEqSolver |
| ScipyParallelSolver | NumpyODESystem | Passed directly to solve_ivp as fun; no PyTorch-to-NumPy conversion overhead |

All three ODE base classes use the ode(t, y, p) signature, where p is a flat parameter array built from the TypedDict field declaration order (see Defining ODE Systems below). JaxODESystem uses jax.numpy operations for JIT compilation and vectorized GPU execution. ODESystem uses torch operations and inherits from torch.nn.Module, so its parameters can be moved between devices with .to(device). NumpyODESystem uses plain NumPy operations; the solver passes the parameter array to scipy.integrate.solve_ivp via its args keyword, so each trajectory can receive its own parameter set without mutating class state.

When solver=None is passed to BasinStabilityEstimator, the solver is chosen automatically based on the ODE class:

  • JaxODESystem → JaxSolver (only if pybasin[jax] is installed)
  • ODESystem → TorchDiffEqSolver

If the ODE inherits from ODESystem, TorchDiffEqSolver is always selected -- even if JAX is installed. To use JaxSolver, the ODE must inherit from JaxODESystem.

Defining ODE Systems

Every ODE system in pybasin declares its parameters through a TypedDict that gives each parameter a name and type. The field declaration order of the TypedDict fixes the column order of the flat parameter array p passed to ode().

Parameter Declaration

Define a TypedDict for your parameters. The field order determines how they are packed into the flat array:

from typing import TypedDict


class PendulumParams(TypedDict):
    alpha: float  # damping coefficient
    T: float      # external torque
    K: float      # stiffness coefficient

The TypedDict field order determines the mapping between dictionary keys and array indices. With PendulumParams defined as above, calling params_to_array() produces a flat array where index 0 is alpha, index 1 is T, and index 2 is K.
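
The packing rule can be sketched in plain Python. Note that `pack_params` below is an illustrative helper written for this example, not the library's `params_to_array()`:

```python
from typing import TypedDict


class PendulumParams(TypedDict):
    alpha: float  # damping coefficient
    T: float      # external torque
    K: float      # stiffness coefficient


def pack_params(params: PendulumParams) -> list[float]:
    # The TypedDict's field declaration order fixes the column order.
    return [params[name] for name in PendulumParams.__annotations__]


p = pack_params({"alpha": 0.1, "T": 0.5, "K": 1.0})
# p[0] is alpha, p[1] is T, p[2] is K
```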

The ode(t, y, p) Signature

All three base classes receive parameters as a flat array p, passed as the third argument to ode(). This design enables parameter batching -- the solver can pass different parameter sets per trajectory without modifying the ODE code.

Access individual parameters by indexing into p. The three backends differ slightly in indexing convention:

  • JAX (JaxODESystem): use p[i] -- each trajectory sees an unbatched 1-D array because jax.vmap handles the batch dimension externally.
  • PyTorch (ODESystem): use p[..., i] -- the ellipsis broadcasts over any leading batch dimensions that torchdiffeq may introduce.
  • NumPy (NumpyODESystem): use p[i] -- each trajectory receives a plain 1-D NumPy array via solve_ivp's args.

State variables follow the same pattern. JAX ODEs index with y[i], while PyTorch ODEs use y[..., i].
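
The two indexing conventions can be illustrated with NumPy, which shares the ellipsis semantics of torch.Tensor (the arrays below are made up for illustration):

```python
import numpy as np

# Unbatched 1-D parameter array, as jax.vmap or solve_ivp's args deliver it
p_flat = np.array([0.1, 0.5, 1.0])
alpha = p_flat[0]  # plain scalar indexing

# Parameters with a leading batch dimension, as torchdiffeq may pass them
p_batched = np.array([[0.1, 0.5, 1.0],
                      [0.2, 0.5, 1.0]])
alphas = p_batched[..., 0]  # shape (2,) -- ellipsis broadcasts over the batch
```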

JAX Example

import jax.numpy as jnp
from jax import Array
from pybasin.solvers.jax_ode_system import JaxODESystem


class PendulumJaxODE(JaxODESystem[PendulumParams]):
    def __init__(self, params: PendulumParams):
        super().__init__(params)

    def ode(self, t: Array, y: Array, p: Array) -> Array:
        alpha, torque, k = p[0], p[1], p[2]
        theta, theta_dot = y[0], y[1]

        dtheta_dt = theta_dot
        dtheta_dot_dt = -alpha * theta_dot + torque - k * jnp.sin(theta)

        return jnp.array([dtheta_dt, dtheta_dot_dt])

PyTorch Example

import torch
from pybasin.solvers.torch_ode_system import ODESystem


class PendulumODE(ODESystem[PendulumParams]):
    def __init__(self, params: PendulumParams):
        super().__init__(params)

    def ode(self, t: torch.Tensor, y: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        alpha, torque, k = p[..., 0], p[..., 1], p[..., 2]
        theta = y[..., 0]
        theta_dot = y[..., 1]

        dtheta_dt = theta_dot
        dtheta_dot_dt = -alpha * theta_dot + torque - k * torch.sin(theta)

        return torch.stack([dtheta_dt, dtheta_dot_dt], dim=-1)

NumpyODESystem (scipy)

NumpyODESystem follows the same ode(t, y, p) convention as the other base classes. Parameters arrive as a flat NumPy array ordered by the TypedDict field declarations. Under the hood, ScipyParallelSolver passes p to scipy.integrate.solve_ivp through its args keyword, so each trajectory can receive different parameter values without mutating the ODE instance.

import numpy as np
from pybasin.solvers.numpy_ode_system import NumpyODESystem


class PendulumNumpyODE(NumpyODESystem[PendulumParams]):
    def ode(self, t: float, y: np.ndarray, p: np.ndarray) -> np.ndarray:
        alpha, torque, k = p[0], p[1], p[2]
        return np.array([y[1], -alpha * y[1] + torque - k * np.sin(y[0])])

Helper Methods

All ODE system classes provide these utilities:

| Method | Description |
| --- | --- |
| params_to_array() | Converts the self.params dict into a flat array ordered by the TypedDict field declarations. Returns a torch.Tensor, jax.Array, or np.ndarray depending on the base class. |
| get_str() | Returns a string representation of the ODE (auto-generated from source code by default). Used for caching and logging. |

The solver calls params_to_array() internally when no explicit params tensor is passed to integrate(). You rarely need to call it yourself.

Common Parameters

All solvers share these constructor parameters:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| t_span | tuple[float, float] | (0, 1000) | Integration interval (t_start, t_end) |
| t_steps | int | 1000 | Evaluation points in the save region |
| device | str or None | None | "cuda", "cpu", or None for auto-detect |
| rtol | float | 1e-8 | Relative tolerance for adaptive stepping |
| atol | float | 1e-6 | Absolute tolerance for adaptive stepping |
| cache_dir | str or None | ".pybasin_cache" | Cache directory path. None disables caching |
| t_eval | tuple[float, float] or None | None | Save region (save_start, save_end). Only points in this range are stored. Must be within t_span. If None, defaults to t_span (save all points). |

Generic Solver API

All solvers expose three capabilities through SolverProtocol:

integrate(ode_system, y0, params=None)

Solves the ODE system for a batch of initial conditions. The y0 tensor must be 2D with shape (batch, n_dims), where batch is the number of initial conditions and n_dims is the number of state variables. A ValueError is raised if y0 is not 2D -- the error message suggests using y0.unsqueeze(0) for single trajectories.

The optional params argument accepts a 2D tensor of shape (P, n_params) containing P parameter combinations. Column order must match the TypedDict field declaration order of the ODE system's parameter type. The solver runs every initial condition against every combination, producing B*P output trajectories (B being the batch size of y0) in IC-major order: trajectory ic*P + p carries (y0[ic], params[p]). When params is None (the default), the solver calls ode_system.params_to_array() once and reuses those parameters for all initial conditions.

Returns a tuple (t_eval, y_values):

  • t_eval has shape (t_steps,) -- the time points at which solutions are evaluated
  • y_values has shape (t_steps, B, n_dims) when params=None, or (t_steps, B*P, n_dims) when params is provided

# Default: one parameter set, B initial conditions
t_eval, y_values = solver.integrate(ode_system, y0)
# t_eval.shape   -> (t_steps,)
# y_values.shape -> (t_steps, B, n_dims)

# Parameter sweep: P combinations over B initial conditions
t_eval, y_values = solver.integrate(ode_system, y0, params=param_combos)
# y_values.shape -> (t_steps, B*P, n_dims)  -- IC-major: index ic*P+p ~ (y0[ic], params[p])
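
The IC-major ordering can be verified with a small index sketch (pure Python, independent of any solver):

```python
B, P = 3, 2  # B initial conditions, P parameter combinations

# Trajectories are produced IC-major: index ic*P + p pairs (y0[ic], params[p])
order = [(ic, p) for ic in range(B) for p in range(P)]


def trajectory_index(ic: int, p: int) -> int:
    return ic * P + p


assert order[trajectory_index(2, 1)] == (2, 1)
assert order[trajectory_index(0, 1)] == (0, 1)
```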

clone(*, device, t_steps_factor, cache_dir)

Creates a copy of the solver with optionally overridden settings. Useful for creating a high-resolution variant for plotting while keeping the original solver for computation:

plot_solver = solver.clone(t_steps_factor=10, device="cpu")

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| device | str or None | None | Override the device. None keeps current device |
| t_steps_factor | int | 1 | Multiply t_steps by this factor |
| cache_dir | str, None, or UNSET | UNSET | Override cache directory. None disables caching. UNSET keeps current setting |

device attribute

A torch.device indicating where output tensors are placed. Always reflects the normalized device (e.g. torch.device("cuda:0"), never bare "cuda").

Behavior Notes

Several behaviors apply to all solvers and happen automatically. Understanding them prevents surprises.

Device Auto-Detection

When device=None, the solver checks whether CUDA is available via torch.cuda.is_available() (or the JAX equivalent for JaxSolver). If a GPU is found, "cuda:0" is selected; otherwise the solver falls back to "cpu". The string "cuda" is always normalized to "cuda:0" internally.

Dtype and Precision

Solvers do not enforce a specific dtype. Time evaluation points are created with the same dtype as y0, so float32, float64, and other dtypes are accepted. No automatic casting is performed -- the solver logs a warning only when y0 is on a different device than the solver, not for dtype differences.

float32 for GPU workloads

GPU solvers (JaxSolver, TorchDiffEqSolver, TorchOdeSolver) perform best with float32 tensors. CUDA devices process float32 significantly faster than float64. When using float32, relax tolerances accordingly -- the default rtol=1e-8 targets float64 precision and will cause the adaptive stepper to stagnate with single-precision arithmetic. Values such as rtol=1e-5, atol=1e-6 are more appropriate for float32. Use float64 when higher accuracy is required.

ODE System Device Transfer

PyTorch-based solvers (TorchDiffEqSolver, TorchOdeSolver) call ode_system.to(self.device) before integration, moving the ODE system's nn.Module parameters to the solver's device automatically. JaxSolver does not perform this transfer because JAX ODE systems are stateless and device placement is handled through jax.device_put. ScipyParallelSolver calls .to() as well, but NumpyODESystem.to() is a no-op that exists only for interface compatibility.

Caching

Every solver caches integration results to disk so that repeated runs with identical inputs skip the numerical integration entirely. Caching is controlled through the cache_dir constructor parameter, which defaults to ".pybasin_cache". Pass None to disable it.

# Default -- cache under .pybasin_cache/ at the project root
solver = JaxSolver(t_span=(0, 1000), t_steps=5000)

# Explicit subfolder for a specific system
solver = JaxSolver(t_span=(0, 1000), t_steps=5000, cache_dir=".pybasin_cache/pendulum")

# No caching
solver = JaxSolver(t_span=(0, 1000), t_steps=5000, cache_dir=None)

Path Resolution

Relative paths (like ".pybasin_cache" or ".pybasin_cache/pendulum") are resolved from the project root, which is located by walking up the directory tree until a pyproject.toml or .git marker is found. Absolute paths are used as-is. The directory is created automatically if it does not exist.
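
The lookup can be sketched as follows; `find_project_root` is an illustrative helper mirroring the rule described above, not pybasin's implementation:

```python
from pathlib import Path


def find_project_root(start: Path) -> Path:
    # Walk up the directory tree until a pyproject.toml or .git marker is found
    for directory in (start, *start.parents):
        if (directory / "pyproject.toml").exists() or (directory / ".git").exists():
            return directory
    return start  # fallback when no marker exists anywhere above
```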

Cache Keys

The cache key is an MD5 hash built from six components: the solver class name, the ODE system's source representation (via get_str()), the ODE system parameters, the serialized y0 and t_eval tensors, and solver-specific configuration (tolerances, method, etc.). Changing any of these produces a different key, so stale results are never returned.
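
The construction can be sketched with hashlib. The layout below is illustrative only -- the real serialization of each component is internal to pybasin:

```python
import hashlib


def cache_key(solver_cls: str, ode_src: str, params: str,
              y0_bytes: bytes, t_eval_bytes: bytes, config: str) -> str:
    # Hash all six components; changing any one of them changes the key
    h = hashlib.md5()
    for part in (solver_cls.encode(), ode_src.encode(), params.encode(),
                 y0_bytes, t_eval_bytes, config.encode()):
        h.update(part)
    return h.hexdigest()


key = cache_key("JaxSolver", "def ode(...)", "alpha=0.1",
                b"\x00", b"\x01", "rtol=1e-8")
```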

Storage Format

Cached tensors are stored using safetensors, which provides fast, zero-copy loading without the security concerns of pickle. On cache load, tensors are moved to the solver's current device. Corrupted files are detected and deleted automatically rather than raising exceptions.


JaxSolver

The recommended solver for most workloads. It uses Diffrax (Kidger, 2021) for numerical integration, with jax.vmap for batch processing and JIT compilation for performance. On GPU, JaxSolver achieves near-constant integration time regardless of sample count -- roughly 11.5 seconds for N ranging from 5,000 to 100,000 in benchmark tests. It is also the only solver that supports per-trajectory event-based early termination, which is critical for systems with unbounded trajectories.

Default Solver

TorchDiffEqSolver is the default solver and ships with the core pybasin install. When JAX and Diffrax are available (pip install pybasin[jax]) and the ODE system inherits from JaxODESystem, BasinStabilityEstimator automatically selects JaxSolver. If the ODE inherits from ODESystem, TorchDiffEqSolver is used regardless of whether JAX is installed. JaxSolver delivers the best GPU performance and is the only solver supporting event functions for early trajectory termination. For CPU-only workloads at large sample sizes (N >= 100k), TorchDiffEqSolver is faster. See the Solver Comparison benchmark for detailed numbers.

JaxSolver does not inherit from the Solver base class. It implements SolverProtocol independently with its own device handling and caching logic. Two construction modes are available: a generic API for standard ODE integration, and a solver_args mode that passes arguments directly to diffrax.diffeqsolve().

Generic API

from pybasin.solvers import JaxSolver
from diffrax import Dopri5

solver = JaxSolver(
    t_span=(0, 1000),
    t_steps=5000,
    device="cuda",
    method=Dopri5(),       # Diffrax solver instance
    rtol=1e-8,
    atol=1e-6,
    max_steps=16**5,       # Maximum integrator steps
    event_fn=None,         # Optional early termination
)

Constructor Parameters

See Common Parameters for t_span, t_steps, device, rtol, atol, cache_dir, and t_eval. The following are specific to JaxSolver:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| method | Diffrax solver or None | None | Diffrax solver instance. Defaults to Dopri5() if None |
| max_steps | int | 16**5 | Maximum number of integrator steps (1,048,576) |
| event_fn | Callable or None | None | Event function for per-trajectory early termination |

Device String

Unlike PyTorch-based solvers, JaxSolver also accepts "gpu" as a device string (mapped to JAX's GPU backend). Both "cuda" and "gpu" resolve to the same GPU device.

Tensor Conversion

JaxSolver converts between PyTorch and JAX tensors at the integration boundary. On GPU, this conversion uses DLPack for zero-copy transfer -- no data is duplicated in device memory. On CPU, the conversion falls back to a NumPy intermediate. Input tensors are PyTorch; output tensors are PyTorch. You never interact with JAX arrays directly.

solver_args Mode

For advanced use cases (SDEs, CDEs, custom step-size controllers, or any configuration not exposed by the generic API), you can pass a dictionary of keyword arguments directly to diffrax.diffeqsolve():

from diffrax import Dopri5, ODETerm, PIDController, SaveAt
import jax.numpy as jnp

solver = JaxSolver(
    solver_args={
        "terms": ODETerm(lambda t, y, args: -y),
        "solver": Dopri5(),
        "t0": 0,
        "t1": 10,
        "dt0": 0.1,
        "saveat": SaveAt(ts=jnp.linspace(0, 10, 100)),
        "stepsize_controller": PIDController(rtol=1e-5, atol=1e-5),
    },
)

When solver_args is provided, all other Diffrax-specific parameters (t_span, t_steps, method, rtol, atol, max_steps, event_fn) are ignored entirely. The solver wraps each call with jax.vmap and injects y0 per trajectory -- do not include y0 in the dictionary.

Baked-in Time Points

In solver_args mode, the integration time points are determined by the saveat entry you provide. Calling clone(t_steps_factor=10) will not increase the time resolution -- the original saveat.ts is used as-is. A warning is logged if t_steps_factor > 1 in this mode.

Because solver_args bypasses all automatic setup, no ODETerm wrapping, PIDController creation, or SaveAt construction is performed. You are responsible for providing a complete and valid set of Diffrax arguments.

Event Functions

Event functions enable per-trajectory early termination, which is essential for systems where some trajectories diverge to infinity (e.g. the Lorenz system's "broken butterfly" regime). Each trajectory stops independently when the event triggers, while bounded trajectories continue integrating normally.

The event function signature is (t, y, args) -> scalar Array. Return a positive value to continue integration, or zero/negative to stop:

import jax.numpy as jnp

def lorenz_stop_event(t, y, args, **kwargs):
    """Stop integration when any state variable exceeds 200 in absolute value."""
    max_val = 200.0
    return max_val - jnp.max(jnp.abs(y))

solver = JaxSolver(
    t_span=(0, 1000),
    t_steps=4000,
    device="cuda",
    event_fn=lorenz_stop_event,
)

Internally, the event function is wrapped in a diffrax.Event(cond_fn=event_fn) and passed to diffeqsolve. For more details on handling diverging trajectories, see the Handling Unbounded Trajectories guide.

Post-event state values are inf

When an event triggers early termination, Diffrax fills the remaining saved time points (those after the event) with inf. For example, if a trajectory diverges at \(t = 50\) but saveat requests points up to \(t = 1000\), all state values for \(t > 50\) will be inf. Your feature extraction or classification code must handle this -- checking for jnp.isinf in the final state is a reliable way to detect terminated trajectories.
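
A NumPy sketch of that check (the equivalent jnp.isinf call behaves identically on JAX arrays; the array here is fabricated for illustration):

```python
import numpy as np

# y_values: (t_steps, batch, n_dims); a terminated trajectory is padded with inf
y_values = np.zeros((5, 3, 2))
y_values[3:, 1, :] = np.inf  # trajectory 1 triggered the event mid-integration

final_states = y_values[-1]                      # (batch, n_dims)
terminated = np.isinf(final_states).any(axis=-1)
# terminated -> array([False,  True, False])
```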

Event functions in solver_args mode

When using solver_args, include the event directly in the dictionary (e.g. as a diffrax.Event instance) rather than using the event_fn parameter.

Reference:

Kidger, P. (2021). On Neural Differential Equations. PhD thesis, University of Oxford. https://docs.kidger.site/diffrax/


TorchDiffEqSolver

A PyTorch-native solver built on torchdiffeq (Chen, 2018). It supports both adaptive-step and fixed-step methods, runs on CPU and CUDA, and integrates directly with ODESystem subclasses (which inherit from torch.nn.Module). At large sample sizes on CPU (N = 100,000), TorchDiffEqSolver is roughly 2x faster than JaxSolver on CPU -- though JaxSolver on GPU remains substantially faster overall.

from pybasin.solvers.torchdiffeq_solver import TorchDiffEqSolver

solver = TorchDiffEqSolver(
    t_span=(0, 1000),
    t_steps=5000,
    device="cuda",
    method="dopri5",
    rtol=1e-8,
    atol=1e-6,
)

Constructor Parameters

See Common Parameters for t_span, t_steps, device, rtol, atol, cache_dir, and t_eval. The only solver-specific parameter is method:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| method | str | "dopri5" | Integration method (see table below) |

Available Methods

| Method | Type | Description |
| --- | --- | --- |
| dopri5 | Adaptive-step | Dormand-Prince 5(4) (default) |
| dopri8 | Adaptive-step | Dormand-Prince 8(5,3) |
| bosh3 | Adaptive-step | Bogacki-Shampine 3(2) |
| euler | Fixed-step | Forward Euler |
| rk4 | Fixed-step | Classic Runge-Kutta 4 |

Integration runs under torch.no_grad(), so no gradient graph is constructed during forward integration. The solver calls ode_system.to(self.device) before integrating, which moves the ODE system's nn.Module parameters to the solver's device.

Reference:

Chen, R. T. Q. (2018). torchdiffeq. https://github.com/rtqichen/torchdiffeq


TorchOdeSolver

A parallel ODE solver built on torchode (Lienen & Günnemann, 2022). Its distinguishing feature is independent step-size control per batch element: each trajectory can advance with its own time step, avoiding the performance penalty that arises when a single stiff trajectory forces small steps for the entire batch.

from pybasin.solvers.torchode_solver import TorchOdeSolver

solver = TorchOdeSolver(
    t_span=(0, 1000),
    t_steps=5000,
    device="cuda",
    method="dopri5",
    rtol=1e-8,
    atol=1e-6,
)

Constructor Parameters

See Common Parameters for t_span, t_steps, device, rtol, atol, cache_dir, and t_eval. The only solver-specific parameter is method:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| method | str | "dopri5" | Integration method (see table below) |

Available Methods

| Method | Type | Description |
| --- | --- | --- |
| dopri5 | Adaptive-step | Dormand-Prince 5(4) (default) |
| tsit5 | Adaptive-step | Tsitouras 5(4) |
| euler | Fixed-step | Forward Euler |
| heun | Fixed-step | Heun's method |

The method string is lowercased internally, so "Dopri5" and "dopri5" are equivalent. Integration runs under torch.inference_mode(). Internally, torchode uses IntegralController for adaptive step-size selection and AutoDiffAdjoint as the solver wrapper.

Performance at Large N

Benchmark results show that TorchOdeSolver scales poorly at large sample sizes. At N = 100,000 it took roughly 310 seconds on CUDA -- compared to about 11 seconds for JaxSolver. Consider TorchOdeSolver primarily when per-trajectory step-size independence is important for correctness, not for raw throughput. See the Solver Comparison benchmark for details.

Reference:

Lienen, M., & Günnemann, S. (2022). torchode: A Parallel ODE Solver for PyTorch. The Symbiosis of Deep Learning and Differential Equations II, NeurIPS. https://openreview.net/forum?id=uiKVKTiUYB0


ScipyParallelSolver

A CPU-only solver that delegates integration to scipy.integrate.solve_ivp and parallelizes across initial conditions using joblib's loky backend. Each trajectory is solved independently in a separate worker process. Requires a NumpyODESystem subclass -- the ode_system.ode method is passed as the fun argument to solve_ivp, and the parameter array is forwarded through args=(p,). No PyTorch tensor conversions happen on the hot path. This solver is primarily useful for debugging, validating results against a well-established reference, and accessing scipy's implicit methods (Radau, BDF) for stiff systems.
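
Stripped of the parallelization and caching layers, the per-trajectory call reduces to a plain solve_ivp invocation. The sketch below is standalone and reuses the pendulum right-hand side from the NumpyODESystem example:

```python
import numpy as np
from scipy.integrate import solve_ivp


def pendulum_ode(t, y, p):
    # p is the flat parameter array forwarded via solve_ivp's args keyword
    alpha, torque, k = p
    return [y[1], -alpha * y[1] + torque - k * np.sin(y[0])]


p = np.array([0.1, 0.5, 1.0])
sol = solve_ivp(pendulum_ode, (0.0, 10.0), y0=[0.5, 0.0], args=(p,),
                method="RK45", rtol=1e-6, atol=1e-8)
# sol.y has shape (n_dims, n_time_points)
```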

from pybasin.solvers.scipy_solver import ScipyParallelSolver

solver = ScipyParallelSolver(
    t_span=(0, 1000),
    t_steps=5000,
    n_jobs=-1,           # Use all CPU cores
    method="RK45",
    rtol=1e-8,
    atol=1e-6,
    max_step=None,       # Defaults to (t_end - t_start) / 100
)

Constructor Parameters

See Common Parameters for t_span, t_steps, cache_dir, and t_eval. The following differ from or extend the common parameters:

| Parameter | Type | Default | Description |
| --- | --- | --- | --- |
| device | str or None | None | Only "cpu" is supported -- "cuda" logs a warning and falls back to CPU (see note) |
| n_jobs | int | -1 | Number of parallel workers (-1 for all CPU cores) |
| method | str | "RK45" | scipy.integrate.solve_ivp method |
| rtol | float | 1e-6 | Relative tolerance (default differs from other solvers) |
| atol | float | 1e-8 | Absolute tolerance (default differs from other solvers) |
| max_step | float or None | None | Maximum step size. Defaults to (t_end - t_start) / 100 |

CPU Only

If you pass device="cuda", the solver logs a warning and silently falls back to CPU. No error is raised. This applies to both the constructor and clone().

Available Methods

SciPy provides explicit methods (RK45, RK23, DOP853) and implicit methods for stiff problems (Radau, BDF, LSODA). See the scipy.integrate.solve_ivp documentation for the full list.

Parallelization Behavior

When batch_size > 1 and n_jobs != 1, trajectories are distributed across worker processes using the loky backend. For a single trajectory (batch_size == 1) or when n_jobs == 1, execution is sequential with no multiprocessing overhead. Because NumpyODESystem.ode is pure NumPy, each worker process can call solve_ivp without any PyTorch state, avoiding GIL contention and import overhead inside the worker.
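
The dispatch rule itself is simple enough to state as code; `runs_in_parallel` is an illustrative helper summarizing the behavior above, not a pybasin function:

```python
def runs_in_parallel(batch_size: int, n_jobs: int) -> bool:
    # Sequential for a single trajectory or n_jobs == 1; parallel otherwise
    return batch_size > 1 and n_jobs != 1


assert runs_in_parallel(batch_size=100, n_jobs=-1)
assert not runs_in_parallel(batch_size=1, n_jobs=-1)
assert not runs_in_parallel(batch_size=100, n_jobs=1)
```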


See Also