Solvers

ODE System Classes

pybasin provides three ODE system base classes, each paired with specific solver backends:

  • ODESystem -- PyTorch-based. Define ode(t, y, p) using torch operations. Works with TorchDiffEqSolver and TorchOdeSolver.
  • JaxODESystem -- JAX-based. Define ode(t, y, p) using jax.numpy operations. Works with JaxSolver for JIT-compiled, GPU-optimized integration.
  • NumpyODESystem -- NumPy-based. Define ode(t, y, p) using numpy operations. Required by ScipyParallelSolver; the ode method is passed to scipy.integrate.solve_ivp as fun with parameters forwarded via args.

When solver=None, BasinStabilityEstimator auto-selects the solver based on which class the ODE inherits from. See the Solvers user guide for details.
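A minimal sketch of what class-based dispatch of this kind could look like. The three stand-in classes and the `compatible_solvers` helper are hypothetical and only mirror the pairings listed above; the actual selection logic inside BasinStabilityEstimator is not shown on this page and may differ.

```python
# Stand-in base classes (hypothetical; the real ones live in pybasin.solvers).
class ODESystem: ...
class JaxODESystem: ...
class NumpyODESystem: ...


def compatible_solvers(system: object) -> list[str]:
    """Return the solver names paired with the system's base class."""
    if isinstance(system, JaxODESystem):
        return ["JaxSolver"]
    if isinstance(system, NumpyODESystem):
        return ["ScipyParallelSolver"]
    if isinstance(system, ODESystem):
        return ["TorchDiffEqSolver", "TorchOdeSolver"]
    raise TypeError(f"Unsupported ODE system type: {type(system).__name__}")


class Pendulum(JaxODESystem):
    """Example user-defined system inheriting from the JAX stand-in."""


print(compatible_solvers(Pendulum()))  # ['JaxSolver']
```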

pybasin.solvers.torch_ode_system.ODESystem

Bases: AutoGetStrMixin, ABC, Module

Abstract base class for defining an ODE system.

P is a type parameter representing the parameter dictionary type. Pass a TypedDict subclass for typed self.params access.

from typing import TypedDict

import torch

from pybasin.solvers.torch_ode_system import ODESystem


class MyParams(TypedDict):
    alpha: float
    beta: float


class MyODE(ODESystem[MyParams]):
    def __init__(self, params: MyParams):
        super().__init__(params)

    def ode(self, t: torch.Tensor, y: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        alpha, beta = p[..., 0], p[..., 1]
        return torch.zeros_like(y)

Functions

ode abstractmethod

ode(t: Tensor, y: Tensor, p: Tensor) -> torch.Tensor

Right-hand side (RHS) for the ODE using PyTorch tensors.

Parameters:

Name Type Description Default
t Tensor

The current time (can be scalar or batch).

required
y Tensor

The current state (can be shape (..., n)).

required
p Tensor

Flat parameter array with shape (n_params,) built by params_to_array(). Access individual parameters via p[..., i] to support batching.

required

Returns:

Type Description
Tensor

The time derivatives with the same leading shape as y.

params_to_array

params_to_array() -> torch.Tensor

Convert self.params to a flat tensor.

Values are ordered by the TypedDict field declaration order.

Returns:

Type Description
Tensor

Flat tensor of shape (n_params,).
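To illustrate what "field declaration order" means, here is a small sketch that flattens a parameter dictionary by walking the TypedDict's declared fields. It returns a plain Python list instead of a tensor, and `params_to_list` is a hypothetical helper; the library's own implementation may be written differently.

```python
from typing import TypedDict


class MyParams(TypedDict):
    alpha: float
    beta: float


def params_to_list(params: MyParams) -> list[float]:
    # Iterate over the TypedDict's declared fields, not the dict's own key order.
    return [float(params[name]) for name in MyParams.__annotations__]


# Keys are supplied in a different order than they were declared.
flat = params_to_list({"beta": 2.0, "alpha": 0.5})
print(flat)  # [0.5, 2.0]
```

Under this scheme, index 0 always refers to alpha and index 1 to beta, regardless of how the dictionary was built.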

forward

forward(t: Tensor, y: Tensor) -> torch.Tensor

Calls the ODE function in a manner consistent with nn.Module.

When _batched_params is set (by the solver for parameter batching), it is used instead of params_to_array().


pybasin.solvers.numpy_ode_system.NumpyODESystem

Bases: AutoGetStrMixin, ABC

Abstract base class for NumPy-based ODE systems, compatible with scipy.integrate.solve_ivp.

P is a type parameter representing the parameter dictionary type. Instances are callable and can be passed directly as the fun argument to solve_ivp.

Subclasses declare parameters via a TypedDict type parameter.

Functions

ode abstractmethod

ode(t: float, y: ndarray, p: ndarray) -> np.ndarray

Right-hand side (RHS) for the ODE.

Parameters:

Name Type Description Default
t float

Current time.

required
y ndarray

Current state vector, shape (n,).

required
p ndarray

Flat parameter array of shape (n_params,) built by params_to_array(). Access individual parameters via p[i].

required

Returns:

Type Description
ndarray

Time derivatives, shape (n,).
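The sketch below shows how a function with the ode(t, y, p) signature fits scipy.integrate.solve_ivp when the parameter array is forwarded through args. It uses a free function and a damped oscillator chosen purely for illustration, not a NumpyODESystem subclass.

```python
import numpy as np
from scipy.integrate import solve_ivp


def ode(t: float, y: np.ndarray, p: np.ndarray) -> np.ndarray:
    # Damped linear oscillator: y = [position, velocity], p = [damping, stiffness].
    damping, stiffness = p[0], p[1]
    return np.array([y[1], -damping * y[1] - stiffness * y[0]])


p = np.array([0.1, 1.0])   # flat parameter array, shape (n_params,)
y0 = np.array([1.0, 0.0])  # initial state, shape (n,)
t_eval = np.linspace(0.0, 10.0, 50)

# solve_ivp calls fun(t, y, *args), so args=(p,) supplies the third argument.
sol = solve_ivp(ode, (0.0, 10.0), y0, t_eval=t_eval, args=(p,), rtol=1e-6, atol=1e-8)
print(sol.y.shape)  # (2, 50)
```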

params_to_array

params_to_array() -> np.ndarray

Convert self.params to a flat numpy array.

Values are ordered by the TypedDict field declaration order.

Returns:

Type Description
ndarray

Flat array of shape (n_params,).

to

to(device: Any) -> NumpyODESystem[P]

No-op device transfer for compatibility with the Solver base class.


pybasin.solvers.jax_ode_system.JaxODESystem

Bases: AutoGetStrMixin

Base class for defining an ODE system using pure JAX.

This class is designed for ODE systems that need maximum performance with JAX/Diffrax. Unlike the PyTorch-based ODESystem, this uses pure JAX operations that can be JIT-compiled for optimal GPU performance.

P is a type parameter representing the parameter dictionary type. Pass a TypedDict subclass for typed self.params access.

For standard ODEs, subclass and override ode():

from typing import TypedDict

import jax.numpy as jnp
from jax import Array

from pybasin.solvers.jax_ode_system import JaxODESystem


class MyParams(TypedDict):
    alpha: float
    beta: float


class MyJaxODE(JaxODESystem[MyParams]):
    def __init__(self, params: MyParams):
        super().__init__(params)

    def ode(self, t: Array, y: Array, p: Array) -> Array:
        alpha, beta = p[0], p[1]
        return jnp.zeros_like(y)

For SDEs or CDEs where you provide custom Diffrax terms via solver_args, overriding ode() is not required. The subclass only needs params and get_str() for caching and display:

class MySDESystem(JaxODESystem[MyParams]):
    def __init__(self, params: MyParams):
        super().__init__(params)

    def get_str(self) -> str:
        return f"MySDE(alpha={self.params['alpha']})"

Functions

__init__

__init__(params: P) -> None

Initialize the JAX ODE system.

Parameters:

Name Type Description Default
params P

Dictionary of ODE parameters.

required

ode

ode(t: Array, y: Array, p: Array) -> Array

Right-hand side (RHS) for the ODE using pure JAX operations.

Override this method for standard ODE systems. For SDEs or CDEs where custom Diffrax terms are provided via JaxSolver(solver_args=...), overriding this method is not required.

This method must use only JAX operations (jnp, not np or torch) to enable JIT compilation and efficient execution.

Notes:

  • Use jnp operations instead of np or torch
  • Avoid Python control flow that depends on array values
  • This method will be JIT-compiled, so ensure it's traceable
  • p is the flat parameter array built from params_to_array(). Access individual parameters via p[i]. Batching is handled by vmap, so the ODE always sees unbatched 1-D arrays.

Parameters:

Name Type Description Default
t Array

The current time (scalar JAX array).

required
y Array

The current state with shape (n_dims,) for single trajectory.

required
p Array

Flat parameter array with shape (n_params,), built by params_to_array().

required

Returns:

Type Description
Array

The time derivatives with the same shape as y.

Raises:

Type Description
NotImplementedError

If not overridden and called directly.

params_to_array

params_to_array() -> Array

Convert self.params to a flat JAX array.

Values are ordered by the TypedDict field declaration order.

Returns:

Type Description
Array

Flat JAX array of shape (n_params,).

to

to(device: Any) -> JaxODESystem[P]

No-op for JAX systems - device handling is done on tensors.

This method exists for API compatibility with PyTorch-based ODESystem.

Parameters:

Name Type Description Default
device Any

Ignored for JAX systems.

required

Returns:

Type Description
JaxODESystem[P]

Returns self.

__call__

__call__(t: Array, y: Array, args: Any = None) -> Array

Make the ODE system callable for use with Diffrax.

Diffrax expects f(t, y, args) signature.

Parameters:

Name Type Description Default
t Array

Current time.

required
y Array

Current state.

required
args Any

Passed through to ode().

None

Returns:

Type Description
Array

Time derivatives.


Solver Protocol

pybasin.protocols.SolverProtocol

Bases: Protocol

Protocol defining the common interface for ODE solvers.

Two implementations exist: Solver (PyTorch-based) and JaxSolver (JAX-based). Structural typing allows both to satisfy this protocol without explicit inheritance, though classes may inherit from it to declare conformance explicitly.

Attributes:

Name Type Description
device device

Device for output tensors.

Functions

__init__

__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: Any = None,
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
) -> None

Initialize the solver with integration parameters.

Parameters:

Name Type Description Default
t_span tuple[float, float]

Tuple (t_start, t_end) defining the integration interval.

(0, 1000)
t_steps int

Number of evaluation points in the save region.

1000
device str | None

Device to use ('cuda', 'cpu', 'gpu', or None for auto-detect).

None
method Any

Integration method (solver-specific).

None
rtol float

Relative tolerance (used by adaptive-step methods only).

1e-08
atol float

Absolute tolerance (used by adaptive-step methods only).

1e-06
cache_dir str | None

Directory for caching integration results. Relative paths are resolved from the project root. None disables caching.

DEFAULT_CACHE_DIR
t_eval tuple[float, float] | None

Optional save region (save_start, save_end). Only time points in this range are stored. Must be contained within t_span. If None, defaults to t_span (save all points).

None
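As an illustration of how t_span, t_eval, and t_steps relate, the sketch below builds the saved time grid. It assumes evenly spaced points that include both endpoints of the save region, which this page does not state. `save_times` is a hypothetical helper, not part of pybasin.

```python
from typing import Optional


def save_times(
    t_span: tuple[float, float],
    t_steps: int,
    t_eval: Optional[tuple[float, float]] = None,
) -> list[float]:
    """Return t_steps evenly spaced times inside the save region."""
    save_start, save_end = t_eval if t_eval is not None else t_span
    # The save region must lie inside the integration interval.
    if not (t_span[0] <= save_start <= save_end <= t_span[1]):
        raise ValueError("t_eval must be contained within t_span")
    if t_steps < 2:
        return [save_start]
    step = (save_end - save_start) / (t_steps - 1)
    return [save_start + i * step for i in range(t_steps)]


# Integrate over [0, 1000] but keep only the last 100 time units.
ts = save_times((0.0, 1000.0), t_steps=5, t_eval=(900.0, 1000.0))
print(ts)  # [900.0, 925.0, 950.0, 975.0, 1000.0]
```

A save region like this is a common way to discard the transient and keep only the part of the trajectory near the attractor.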

integrate

integrate(
    ode_system: ODESystemProtocol,
    y0: Tensor,
    params: Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]

Solve the ODE system and return the evaluation time points and solution.

Parameters:

Name Type Description Default
ode_system ODESystemProtocol

An instance of an ODE system (ODESystem or JaxODESystem).

required
y0 Tensor

Initial conditions with shape (batch, n_dims).

required
params Tensor | None

Optional parameter array. When None, the solver calls ode_system.params_to_array() to obtain the default parameters.

None

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple (t_eval, y_values) where y_values has shape (t_steps, batch, n_dims).

clone

clone(
    *,
    device: str | None = None,
    t_steps_factor: int = 1,
    cache_dir: str | None | object = UNSET,
) -> SolverProtocol

Create a copy of this solver, optionally overriding device, resolution, or caching.

Parameters:

Name Type Description Default
device str | None

Target device ('cpu', 'cuda', 'gpu'). If None, keeps the current device.

None
t_steps_factor int

Multiply the number of evaluation points by this factor (e.g. 10 for smoother plotting). Defaults to 1 (no change).

1
cache_dir str | None | object

Override cache directory. Pass None to disable caching. If not provided, keeps the current setting.

UNSET

Returns:

Type Description
SolverProtocol

New solver instance.
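The cache_dir default of UNSET lets clone() tell "argument omitted" apart from an explicit None. A minimal sketch of that sentinel pattern follows; `SketchSolver` and its fields are hypothetical stand-ins, and the real solvers carry more state.

```python
from dataclasses import dataclass, replace
from typing import Optional

UNSET = object()  # sentinel meaning "argument not provided"


@dataclass(frozen=True)
class SketchSolver:
    t_steps: int = 1000
    device: str = "cpu"
    cache_dir: Optional[str] = "cache"

    def clone(
        self,
        *,
        device: Optional[str] = None,
        t_steps_factor: int = 1,
        cache_dir: object = UNSET,
    ) -> "SketchSolver":
        return replace(
            self,
            device=self.device if device is None else device,
            t_steps=self.t_steps * t_steps_factor,
            # UNSET keeps the current setting; an explicit None disables caching.
            cache_dir=self.cache_dir if cache_dir is UNSET else cache_dir,
        )


base = SketchSolver()
dense = base.clone(t_steps_factor=10)  # same cache_dir, 10x the points
no_cache = base.clone(cache_dir=None)  # caching disabled
print(dense.t_steps, dense.cache_dir, no_cache.cache_dir)  # 10000 cache None
```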


Solver Implementations

pybasin.solvers.jax_solver.JaxSolver

Bases: SolverProtocol, DisplayNameMixin

High-performance ODE solver using JAX and Diffrax for native JAX ODE systems.

This solver is optimized for JaxODESystem instances and provides the fastest integration performance by avoiding any PyTorch callbacks. It uses JIT compilation and vmap for efficient batch processing.

The interface is compatible with other solvers - it accepts PyTorch tensors and returns PyTorch tensors, but internally uses JAX for maximum performance.

See also: Diffrax documentation

Citation:

@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}

Example usage:

Overload 1 — generic API for standard ODEs:

from pybasin.solvers.jax_ode_system import JaxODESystem
from pybasin.solvers import JaxSolver
import torch

class MyODE(JaxODESystem):
    def ode(self, t, y, p):
        return -y  # Simple decay
    def get_str(self):
        return "decay"

solver = JaxSolver(t_span=(0, 10), t_steps=100)
y0 = torch.tensor([[1.0, 2.0]])  # batch=1, dims=2
t, y = solver.integrate(MyODE({}), y0)

Overload 2 — direct Diffrax control via solver_args:

Pass native Diffrax arguments directly to diffeqsolve. This is useful for SDEs, CDEs, or any advanced Diffrax configuration.

Note: When using solver_args, the integration time points are baked into saveat.ts at construction time. The t_steps_factor parameter of clone() has no effect in this mode; the integration still uses the original saveat.

from diffrax import Dopri5, ODETerm, PIDController, SaveAt
import jax.numpy as jnp

solver = JaxSolver(
    solver_args={
        "terms": ODETerm(lambda t, y, args: -y),
        "solver": Dopri5(),
        "t0": 0,
        "t1": 10,
        "dt0": 0.1,
        "saveat": SaveAt(ts=jnp.linspace(0, 10, 100)),
        "stepsize_controller": PIDController(rtol=1e-5, atol=1e-5),
    },
)

Functions

__init__

__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: Any | None = None,
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    max_steps: int = DEFAULT_MAX_STEPS,
    event_fn: Callable[[Any, Array, Any], Array]
    | None = None,
    *,
    t_eval: tuple[float, float] | None = None,
) -> None
__init__(
    *,
    solver_args: dict[str, Any],
    cache_dir: str | None = DEFAULT_CACHE_DIR,
) -> None
__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: Any | None = None,
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    max_steps: int = DEFAULT_MAX_STEPS,
    event_fn: Callable[[Any, Array, Any], Array]
    | None = None,
    *,
    solver_args: dict[str, Any] | None = None,
    t_eval: tuple[float, float] | None = None,
)

Initialize JaxSolver.

Can be called in two ways:

  1. Generic API with named parameters for standard ODE integration:

JaxSolver(t_span=(0, 10), t_steps=100, rtol=1e-8, ...)

  2. Direct Diffrax control via solver_args for full access to diffeqsolve kwargs (SDEs, CDEs, custom step-size controllers, etc.):

JaxSolver(solver_args={"terms": ..., "solver": ..., "t0": ..., ...})

The two interfaces are mutually exclusive at the type level.
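One way to express "mutually exclusive at the type level" is with typing.overload, as the three signatures above suggest. The sketch below uses a hypothetical `SketchSolver` with a reduced parameter list; it only shows the pattern and makes no claim about what JaxSolver does at runtime when both styles are mixed.

```python
from typing import Any, Optional, overload


class SketchSolver:
    @overload
    def __init__(self, t_span: tuple[float, float] = ..., t_steps: int = ...) -> None: ...
    @overload
    def __init__(self, *, solver_args: dict[str, Any]) -> None: ...

    def __init__(
        self,
        t_span: tuple[float, float] = (0, 1000),
        t_steps: int = 1000,
        *,
        solver_args: Optional[dict[str, Any]] = None,
    ) -> None:
        # Only this implementation runs; the overloads exist for type checkers.
        self.t_span = t_span
        self.t_steps = t_steps
        self.solver_args = solver_args
        self.uses_solver_args = solver_args is not None


generic = SketchSolver(t_span=(0, 10), t_steps=100)
direct = SketchSolver(solver_args={"t0": 0, "t1": 10})
print(generic.uses_solver_args, direct.uses_solver_args)  # False True
```

With these overloads, a static type checker rejects a call that passes both t_span and solver_args, because neither overload accepts that combination.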

Parameters:

Name Type Description Default
t_span tuple[float, float]

Tuple (t_start, t_end) defining the integration interval.

(0, 1000)
t_steps int

Number of evaluation points in the save region.

1000
device str | None

Device to use ('cuda', 'gpu', 'cpu', or None for auto-detect).

None
method Any | None

Diffrax solver instance (e.g., Dopri5(), Tsit5()). Defaults to Dopri5() if None.

None
rtol float

Relative tolerance (used by adaptive-step methods only).

1e-08
atol float

Absolute tolerance (used by adaptive-step methods only).

1e-06
max_steps int

Maximum number of integrator steps.

DEFAULT_MAX_STEPS
cache_dir str | None

Directory for caching integration results. Relative paths are resolved from the project root. None disables caching.

DEFAULT_CACHE_DIR
event_fn Callable[[Any, Array, Any], Array] | None

Optional event function for early termination. Should return positive when integration should continue, negative/zero to stop. Signature: (t, y, args) -> scalar Array.

None
solver_args dict[str, Any] | None

Dict of kwargs passed directly to diffrax.diffeqsolve(). When provided, all other Diffrax-specific parameters are ignored. Must NOT include y0 (provided per-trajectory via integrate()).

None
t_eval tuple[float, float] | None

Optional save region (save_start, save_end). Only time points in this range are stored. Must be contained within t_span. If None, defaults to t_span (save all points). Ignored in solver_args mode.

None

clone

clone(
    *,
    device: str | None = None,
    t_steps_factor: int = 1,
    cache_dir: str | None | object = UNSET,
) -> JaxSolver

Create a copy of this solver, optionally overriding device, resolution, or caching.

Parameters:

Name Type Description Default
device str | None

Target device ('cpu', 'cuda', 'gpu'). If None, keeps the current device.

None
t_steps_factor int

Multiply the number of evaluation points by this factor. Ignored for solver_args mode (saveat is baked in at construction time).

1
cache_dir str | None | object

Override cache directory. Pass None to disable caching. If not provided, keeps the current setting.

UNSET

Returns:

Type Description
JaxSolver

New JaxSolver instance.

integrate

integrate(
    ode_system: ODESystemProtocol,
    y0: Tensor,
    params: Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]

Solve the ODE system and return the evaluation time points and solution.

Parameters:

Name Type Description Default
ode_system ODESystemProtocol

An instance of JaxODESystem.

required
y0 Tensor

Initial conditions as PyTorch tensor with shape (batch, n_dims).

required
params Tensor | None

Optional 2-D tensor of shape (P, n_params) with P parameter combinations. The solver runs every IC against every combination, producing B*P output trajectories in IC-major order: trajectory ic*P + p carries (y0[ic], params[p]). When None, the ODE system's default parameters are used for all ICs (equivalent to a single-row params).

None

Returns:

Type Description
tuple[Tensor, Tensor]

Tuple (t_eval, y_values) as PyTorch tensors where y_values has shape (t_steps, B*P, n_dims).
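The IC-major layout described for params can be checked with plain index arithmetic. This sketch enumerates (ic, p) pairs for illustrative sizes B=3 and P=2 and confirms that flat index ic*P + p maps to the expected pair; it does not call the solver.

```python
from itertools import product

B, P = 3, 2  # 3 initial conditions, 2 parameter combinations

# product() varies its last argument fastest, which gives IC-major order.
pairs = list(product(range(B), range(P)))

for ic in range(B):
    for p in range(P):
        assert pairs[ic * P + p] == (ic, p)

# Recover (ic, p) from a flat trajectory index.
ic, p = divmod(5, P)
print(pairs)  # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1)]
print(ic, p)  # 2 1
```

Given the output tensor, `y_values[:, ic * P + p, :]` would therefore be the trajectory for initial condition ic under parameter combination p.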


pybasin.solvers.TorchDiffEqSolver

Bases: Solver

Differentiable ODE solver with full GPU support and O(1)-memory backpropagation.

Uses the adjoint method for memory-efficient gradient computation through ODE solutions. Supports adaptive-step (dopri5, dopri8, bosh3) and fixed-step (euler, rk4) methods.

See also: torchdiffeq GitHub

Citation:

@misc{torchdiffeq,
    author={Chen, Ricky T. Q.},
    title={torchdiffeq},
    year={2018},
    url={https://github.com/rtqichen/torchdiffeq},
}

Functions

__init__

__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: str = "dopri5",
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
)

Initialize TorchDiffEqSolver.

Parameters:

Name Type Description Default
t_span tuple[float, float]

Tuple (t_start, t_end) defining the integration interval.

(0, 1000)
t_steps int

Number of evaluation points in the save region.

1000
device str | None

Device to use ('cuda', 'cpu', or None for auto-detect).

None
method str

Integration method from torchdiffeq.odeint.

'dopri5'
rtol float

Relative tolerance (used by adaptive-step methods only).

1e-08
atol float

Absolute tolerance (used by adaptive-step methods only).

1e-06
cache_dir str | None

Directory for caching integration results. None disables caching.

DEFAULT_CACHE_DIR
t_eval tuple[float, float] | None

Optional save region (save_start, save_end). Only time points in this range are stored. Must be contained within t_span. If None, defaults to t_span.

None

clone

clone(
    *,
    device: str | None = None,
    t_steps_factor: int = 1,
    cache_dir: str | None | object = UNSET,
) -> TorchDiffEqSolver

Create a copy of this solver, optionally overriding device, resolution, or caching.


pybasin.solvers.TorchOdeSolver

Bases: Solver

Parallel ODE solver with independent step sizes per batch element.

Compatible with PyTorch's JIT compiler for performance optimization. Unlike other solvers, torchode can take different step sizes for each sample in a batch, avoiding performance traps for problems of varying stiffness.

See also: torchode documentation

Citation:

@inproceedings{lienen2022torchode,
    title = {torchode: A Parallel {ODE} Solver for PyTorch},
    author = {Marten Lienen and Stephan G{\"u}nnemann},
    booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
    year = {2022},
    url = {https://openreview.net/forum?id=uiKVKTiUYB0}
}

Functions

__init__

__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: str = "dopri5",
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
)

Initialize TorchOdeSolver.

Parameters:

Name Type Description Default
t_span tuple[float, float]

Tuple (t_start, t_end) defining the integration interval.

(0, 1000)
t_steps int

Number of evaluation points in the save region.

1000
device str | None

Device to use ('cuda', 'cpu', or None for auto-detect).

None
method str

Integration method ('dopri5', 'tsit5', 'euler', 'heun').

'dopri5'
rtol float

Relative tolerance (used by adaptive-step methods only).

1e-08
atol float

Absolute tolerance (used by adaptive-step methods only).

1e-06
cache_dir str | None

Directory for caching integration results. None disables caching.

DEFAULT_CACHE_DIR
t_eval tuple[float, float] | None

Optional save region (save_start, save_end). Only time points in this range are stored. Must be contained within t_span. If None, defaults to t_span.

None

clone

clone(
    *,
    device: str | None = None,
    t_steps_factor: int = 1,
    cache_dir: str | None | object = UNSET,
) -> TorchOdeSolver

Create a copy of this solver, optionally overriding device, resolution, or caching.


pybasin.solvers.ScipyParallelSolver

Bases: Solver

ODE solver using sklearn's parallel processing with scipy's solve_ivp.

Uses multiprocessing (loky backend) to solve multiple initial conditions in parallel. Each worker solves one trajectory at a time using scipy's solve_ivp.

Requires a NumpyODESystem subclass (pybasin.solvers.numpy_ode_system.NumpyODESystem). The ODE is passed directly to solve_ivp with no PyTorch-to-NumPy conversion overhead.

See also: scipy.integrate.solve_ivp

Functions

__init__

__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    n_jobs: int = -1,
    method: str = "RK45",
    rtol: float = 1e-06,
    atol: float = 1e-08,
    max_step: float | None = None,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
)

Initialize ScipyParallelSolver.

Parameters:

Name Type Description Default
t_span tuple[float, float]

Tuple (t_start, t_end) defining the integration interval.

(0, 1000)
t_steps int

Number of evaluation points in the save region.

1000
device str | None

Device to use (only 'cpu' supported).

None
n_jobs int

Number of parallel jobs (-1 for all CPUs).

-1
method str

Integration method ('RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA', etc).

'RK45'
rtol float

Relative tolerance (used by adaptive-step methods only).

1e-06
atol float

Absolute tolerance (used by adaptive-step methods only).

1e-08
max_step float | None

Maximum step size for the solver.

None
cache_dir str | None

Directory for caching integration results. None disables caching.

DEFAULT_CACHE_DIR
t_eval tuple[float, float] | None

Optional save region (save_start, save_end). Only time points in this range are stored. Must be contained within t_span. If None, defaults to t_span.

None

clone

clone(
    *,
    device: str | None = None,
    t_steps_factor: int = 1,
    cache_dir: str | None | object = UNSET,
) -> ScipyParallelSolver

Create a copy of this solver, optionally overriding device, resolution, or caching.

Note: ScipyParallelSolver only supports CPU, so the device is always CPU.