Solvers
ODE System Classes
pybasin provides three ODE system base classes, each paired with specific solver backends:
- ODESystem -- PyTorch-based. Define ode(t, y, p) using torch operations. Works with TorchDiffEqSolver and TorchOdeSolver.
- JaxODESystem -- JAX-based. Define ode(t, y, p) using jax.numpy operations. Works with JaxSolver for JIT-compiled, GPU-optimized integration.
- NumpyODESystem -- NumPy-based. Define ode(t, y, p) using numpy operations. Required by ScipyParallelSolver; the ode method is passed to scipy.integrate.solve_ivp as fun with parameters forwarded via args.
When solver=None, BasinStabilityEstimator auto-selects the solver based on which class the ODE inherits from. See the Solvers user guide for details.
pybasin.solvers.torch_ode_system.ODESystem
Bases: AutoGetStrMixin, ABC, Module
Abstract base class for defining an ODE system.
P is a type parameter representing the parameter dictionary type.
Pass a TypedDict subclass for typed self.params access.
```python
from typing import TypedDict

import torch

class MyParams(TypedDict):
    alpha: float
    beta: float

class MyODE(ODESystem[MyParams]):
    def __init__(self, params: MyParams):
        super().__init__(params)

    def ode(self, t: torch.Tensor, y: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
        alpha, beta = p[..., 0], p[..., 1]
        return torch.zeros_like(y)  # placeholder dynamics
```
Functions
ode
abstractmethod
Right-hand side (RHS) for the ODE using PyTorch tensors.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t | Tensor | The current time (can be scalar or batch). | required |
| y | Tensor | The current state (can be shape (..., n)). | required |
| p | Tensor | Flat parameter array with shape (..., n_params). | required |
Returns:

| Type | Description |
|---|---|
| Tensor | The time derivatives with the same leading shape as y. |
params_to_array
Convert self.params to a flat tensor.
Values are ordered by the TypedDict field declaration order.
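The ordering rule can be illustrated with a plain TypedDict. Here PendulumParams and params_to_list are illustrative stand-ins, not pybasin API; they show only that the flat layout follows the class declaration order, not the insertion order of the dict:

```python
from typing import TypedDict

class PendulumParams(TypedDict):
    alpha: float
    T: float
    K: float

def params_to_list(params: PendulumParams) -> list[float]:
    # Order values by TypedDict field declaration order
    # (alpha, T, K), regardless of dict insertion order.
    return [float(params[name]) for name in PendulumParams.__annotations__]

# Dict built in a different order; flat layout is still alpha, T, K.
p = params_to_list({"K": 1.0, "alpha": 0.1, "T": 0.5})
```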
Returns:

| Type | Description |
|---|---|
| Tensor | Flat tensor of shape (n_params,). |
pybasin.solvers.numpy_ode_system.NumpyODESystem
Bases: AutoGetStrMixin, ABC
Abstract base class for numpy-based ODE systems, compatible with scipy.integrate.solve_ivp.
P is a type parameter representing the parameter dictionary type.
Instances are callable and can be passed directly as the fun argument
to solve_ivp.
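As a sketch of that pattern (Decay is a hypothetical stand-in, not a pybasin class), a callable system can be handed to solve_ivp directly, with the flat parameter array forwarded via args:

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical stand-in for a NumpyODESystem subclass: a callable
# object whose ode(t, y, p) is forwarded through __call__.
class Decay:
    def ode(self, t, y, p):
        # dy/dt = -rate * y, with rate taken from the flat array p
        return -p[0] * y

    def __call__(self, t, y, p):  # solve_ivp-compatible signature
        return self.ode(t, y, p)

system = Decay()
sol = solve_ivp(system, (0.0, 2.0), y0=[1.0], args=(np.array([0.5]),))
final = sol.y[0, -1]  # close to exp(-1)
```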
Subclasses declare parameters via a TypedDict type parameter.
Functions
ode
abstractmethod
Right-hand side (RHS) for the ODE.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t | float | Current time. | required |
| y | ndarray | Current state vector, shape (n_dims,). | required |
| p | ndarray | Flat parameter array of shape (n_params,). | required |
Returns:

| Type | Description |
|---|---|
| ndarray | Time derivatives, shape (n_dims,). |
params_to_array
Convert self.params to a flat numpy array.
Values are ordered by the TypedDict field declaration order.
Returns:

| Type | Description |
|---|---|
| ndarray | Flat array of shape (n_params,). |
pybasin.solvers.jax_ode_system.JaxODESystem
Bases: AutoGetStrMixin
Base class for defining an ODE system using pure JAX.
This class is designed for ODE systems that need maximum performance with JAX/Diffrax. Unlike the PyTorch-based ODESystem, this uses pure JAX operations that can be JIT-compiled for optimal GPU performance.
P is a type parameter representing the parameter dictionary type.
Pass a TypedDict subclass for typed self.params access.
For standard ODEs, subclass and override ode():
```python
from typing import TypedDict

import jax.numpy as jnp
from jax import Array

class MyParams(TypedDict):
    alpha: float
    beta: float

class MyJaxODE(JaxODESystem[MyParams]):
    def __init__(self, params: MyParams):
        super().__init__(params)

    def ode(self, t: Array, y: Array, p: Array) -> Array:
        alpha, beta = p[0], p[1]
        return jnp.zeros_like(y)  # placeholder dynamics
```
For SDEs or CDEs where you provide custom Diffrax terms via solver_args, overriding ode() is not required. The subclass only needs params and get_str() for caching and display:

```python
class MySDESystem(JaxODESystem[MyParams]):
    def __init__(self, params: MyParams):
        super().__init__(params)

    def get_str(self) -> str:
        return f"MySDE(alpha={self.params['alpha']})"
```
Functions
__init__
Initialize the JAX ODE system.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | P | Dictionary of ODE parameters. | required |
ode
Right-hand side (RHS) for the ODE using pure JAX operations.
Override this method for standard ODE systems. For SDEs or CDEs where
custom Diffrax terms are provided via JaxSolver(solver_args=...),
overriding this method is not required.
This method must use only JAX operations (jnp, not np or torch) to enable JIT compilation and efficient execution.
Notes:
- Use jnp operations instead of np or torch
- Avoid Python control flow that depends on array values
- This method will be JIT-compiled, so ensure it's traceable
p is the flat parameter array built from params_to_array(). Access individual parameters via p[i]. Batching is handled by vmap, so the ODE always sees unbatched 1-D arrays.
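A sketch of these constraints (rhs and its threshold logic are illustrative, not pybasin code): jnp.where replaces value-dependent Python control flow, and vmap supplies the batching so the RHS itself only ever sees 1-D arrays:

```python
import jax
import jax.numpy as jnp

# JIT-traceable RHS: jnp.where instead of a Python `if` on array
# values, which would fail under tracing.
def rhs(t, y, p):
    # saturate the drive term at p[1] without Python control flow
    drive = jnp.where(y[0] > p[1], p[1], y[0])
    return jnp.array([-p[0] * drive])

# vmap handles batching over initial conditions.
batched_rhs = jax.vmap(rhs, in_axes=(None, 0, None))
y0 = jnp.array([[0.5], [2.0]])
p = jnp.array([1.0, 1.0])
out = jax.jit(batched_rhs)(0.0, y0, p)
```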
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t | Array | The current time (scalar JAX array). | required |
| y | Array | The current state with shape (n_dims,) for a single trajectory. | required |
| p | Array | Flat parameter array with shape (n_params,), built by params_to_array(). | required |
Returns:

| Type | Description |
|---|---|
| Array | The time derivatives with the same shape as y. |
Raises:

| Type | Description |
|---|---|
| NotImplementedError | If not overridden and called directly. |
params_to_array
Convert self.params to a flat JAX array.
Values are ordered by the TypedDict field declaration order.
Returns:

| Type | Description |
|---|---|
| Array | Flat JAX array of shape (n_params,). |
to
No-op for JAX systems; device handling is done on tensors.
This method exists for API compatibility with PyTorch-based ODESystem.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| device | Any | Ignored for JAX systems. | required |
Returns:

| Type | Description |
|---|---|
| JaxODESystem[P] | Returns self. |
__call__
Make the ODE system callable for use with Diffrax.
Diffrax expects f(t, y, args) signature.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t | Array | Current time. | required |
| y | Array | Current state. | required |
| args | Any | Passed through to ode(). | None |
Returns:

| Type | Description |
|---|---|
| Array | Time derivatives. |
Solver Protocol
pybasin.protocols.SolverProtocol
Bases: Protocol
Protocol defining the common interface for ODE solvers.
Two implementations exist: Solver (PyTorch-based) and JaxSolver (JAX-based). Structural typing allows both to satisfy this protocol without explicit inheritance, though classes may inherit from it to declare conformance explicitly.
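A miniature of the same structural-typing pattern (IntegratorProtocol and EulerToy are illustrative, not pybasin classes): the class satisfies the protocol purely by shape, with no inheritance required:

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class IntegratorProtocol(Protocol):
    def integrate(self, y0: list[float]) -> list[float]: ...

class EulerToy:
    # No inheritance from IntegratorProtocol; the matching
    # method signature alone makes it conform structurally.
    def integrate(self, y0: list[float]) -> list[float]:
        # one explicit Euler step of y' = -y with dt = 0.1
        return [y - 0.1 * y for y in y0]

solver: IntegratorProtocol = EulerToy()  # type-checks structurally
result = solver.integrate([1.0])
```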
Attributes:

| Name | Type | Description |
|---|---|---|
| device | device | Device for output tensors. |
Functions
__init__
```python
__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: Any = None,
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
) -> None
```
Initialize the solver with integration parameters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t_span | tuple[float, float] | Tuple (t_start, t_end) defining the integration interval. | (0, 1000) |
| t_steps | int | Number of evaluation points in the save region. | 1000 |
| device | str \| None | Device to use ('cuda', 'cpu', 'gpu', or None for auto-detect). | None |
| method | Any | Integration method (solver-specific). | None |
| rtol | float | Relative tolerance (used by adaptive-step methods only). | 1e-08 |
| atol | float | Absolute tolerance (used by adaptive-step methods only). | 1e-06 |
| cache_dir | str \| None | Directory for caching integration results. Relative paths are resolved from the project root. | DEFAULT_CACHE_DIR |
| t_eval | tuple[float, float] \| None | Optional save region (start, end) within t_span. | None |
integrate
```python
integrate(
    ode_system: ODESystemProtocol,
    y0: Tensor,
    params: Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]
```
Solve the ODE system and return the evaluation time points and solution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ode_system | ODESystemProtocol | An instance of an ODE system (ODESystem or JaxODESystem). | required |
| y0 | Tensor | Initial conditions with shape (batch, n_dims). | required |
| params | Tensor \| None | Optional parameter array. When None, parameters are taken from the ODE system. | None |
Returns:

| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | Tuple (t_eval, y_values) where y_values has shape (t_steps, batch, n_dims). |
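The (t_steps, batch, n_dims) layout can be indexed as follows (shapes only, illustrative; a zero array stands in for a real solution):

```python
import numpy as np

# y_values[time, batch, dim]: time is the leading axis.
t_steps, batch, n_dims = 1000, 5, 3
y_values = np.zeros((t_steps, batch, n_dims))

traj = y_values[:, 2, 0]     # dimension 0 of batch element 2, shape (t_steps,)
final_states = y_values[-1]  # all end states, shape (batch, n_dims)
```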
clone
```python
clone(
    *,
    device: str | None = None,
    t_steps_factor: int = 1,
    cache_dir: str | None | object = UNSET,
) -> SolverProtocol
```
Create a copy of this solver, optionally overriding device, resolution, or caching.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| device | str \| None | Target device ('cpu', 'cuda', 'gpu'). If None, keeps the current device. | None |
| t_steps_factor | int | Multiply the number of evaluation points by this factor (e.g. 10 for smoother plotting). Defaults to 1 (no change). | 1 |
| cache_dir | str \| None \| object | Override cache directory. Pass None to disable caching; leave as UNSET to keep the current setting. | UNSET |
Returns:

| Type | Description |
|---|---|
| SolverProtocol | New solver instance. |
Solver Implementations
pybasin.solvers.jax_solver.JaxSolver
Bases: SolverProtocol, DisplayNameMixin
High-performance ODE solver using JAX and Diffrax for native JAX ODE systems.
This solver is optimized for JaxODESystem instances and provides the fastest integration performance by avoiding any PyTorch callbacks. It uses JIT compilation and vmap for efficient batch processing.
The interface is compatible with the other solvers: it accepts and returns PyTorch tensors, but internally uses JAX for performance.
See also: Diffrax documentation
Citation:
```bibtex
@phdthesis{kidger2021on,
  title={{O}n {N}eural {D}ifferential {E}quations},
  author={Patrick Kidger},
  year={2021},
  school={University of Oxford},
}
```
Example usage:
Overload 1 -- generic API for standard ODEs:

```python
from pybasin.solvers.jax_ode_system import JaxODESystem
from pybasin.solvers import JaxSolver
import torch

class MyODE(JaxODESystem):
    def ode(self, t, y, p):
        return -y  # simple decay

    def get_str(self):
        return "decay"

solver = JaxSolver(t_span=(0, 10), t_steps=100)
y0 = torch.tensor([[1.0, 2.0]])  # batch=1, dims=2
t, y = solver.integrate(MyODE({}), y0)
```
Overload 2 -- direct Diffrax control via solver_args:

Pass native Diffrax arguments directly to diffeqsolve. This is useful for SDEs, CDEs, or any advanced Diffrax configuration.

Note: When using solver_args, the integration time points are baked into saveat.ts at construction time. The t_steps_factor parameter of clone() has no effect in this mode; the actual integration still uses the original saveat.
```python
from diffrax import Dopri5, ODETerm, PIDController, SaveAt
import jax.numpy as jnp

solver = JaxSolver(
    solver_args={
        "terms": ODETerm(lambda t, y, args: -y),
        "solver": Dopri5(),
        "t0": 0,
        "t1": 10,
        "dt0": 0.1,
        "saveat": SaveAt(ts=jnp.linspace(0, 10, 100)),
        "stepsize_controller": PIDController(rtol=1e-5, atol=1e-5),
    },
)
```
Functions
__init__
```python
__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: Any | None = None,
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    max_steps: int = DEFAULT_MAX_STEPS,
    event_fn: Callable[[Any, Array, Any], Array] | None = None,
    *,
    t_eval: tuple[float, float] | None = None,
) -> None

__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: Any | None = None,
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    max_steps: int = DEFAULT_MAX_STEPS,
    event_fn: Callable[[Any, Array, Any], Array] | None = None,
    *,
    solver_args: dict[str, Any] | None = None,
    t_eval: tuple[float, float] | None = None,
)
```
Initialize JaxSolver.

Can be called in two ways:

- Generic API with named parameters for standard ODE integration: JaxSolver(t_span=(0, 10), t_steps=100, rtol=1e-8, ...)
- Direct Diffrax control via solver_args for full access to diffeqsolve kwargs (SDEs, CDEs, custom step-size controllers, etc.): JaxSolver(solver_args={"terms": ..., "solver": ..., "t0": ..., ...})

The two interfaces are mutually exclusive at the type level.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t_span | tuple[float, float] | Tuple (t_start, t_end) defining the integration interval. | (0, 1000) |
| t_steps | int | Number of evaluation points in the save region. | 1000 |
| device | str \| None | Device to use ('cuda', 'gpu', 'cpu', or None for auto-detect). | None |
| method | Any \| None | Diffrax solver instance (e.g., Dopri5(), Tsit5()). Defaults to Dopri5() if None. | None |
| rtol | float | Relative tolerance (used by adaptive-step methods only). | 1e-08 |
| atol | float | Absolute tolerance (used by adaptive-step methods only). | 1e-06 |
| max_steps | int | Maximum number of integrator steps. | DEFAULT_MAX_STEPS |
| cache_dir | str \| None | Directory for caching integration results. Relative paths are resolved from the project root. | DEFAULT_CACHE_DIR |
| event_fn | Callable[[Any, Array, Any], Array] \| None | Optional event function for early termination. Should return positive when integration should continue, negative/zero to stop. Signature: (t, y, args) -> scalar Array. | None |
| solver_args | dict[str, Any] \| None | Dict of kwargs passed directly to diffeqsolve. | None |
| t_eval | tuple[float, float] \| None | Optional save region (start, end) within t_span. | None |
clone
```python
clone(
    *,
    device: str | None = None,
    t_steps_factor: int = 1,
    cache_dir: str | None | object = UNSET,
) -> JaxSolver
```
Create a copy of this solver, optionally overriding device, resolution, or caching.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| device | str \| None | Target device ('cpu', 'cuda', 'gpu'). If None, keeps the current device. | None |
| t_steps_factor | int | Multiply the number of evaluation points by this factor. Ignored when solver_args is used (the integration times are fixed at construction). | 1 |
| cache_dir | str \| None \| object | Override cache directory. Pass None to disable caching; leave as UNSET to keep the current setting. | UNSET |
Returns:

| Type | Description |
|---|---|
| JaxSolver | New JaxSolver instance. |
integrate
```python
integrate(
    ode_system: ODESystemProtocol,
    y0: Tensor,
    params: Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]
```
Solve the ODE system and return the evaluation time points and solution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| ode_system | ODESystemProtocol | An instance of JaxODESystem. | required |
| y0 | Tensor | Initial conditions as PyTorch tensor with shape (batch, n_dims). | required |
| params | Tensor \| None | Optional 2-D tensor of shape (batch, n_params). | None |
Returns:

| Type | Description |
|---|---|
| tuple[Tensor, Tensor] | Tuple (t_eval, y_values) as PyTorch tensors, where y_values has shape (t_steps, batch, n_dims). |
pybasin.solvers.TorchDiffEqSolver
Bases: Solver
Differentiable ODE solver with full GPU support and O(1)-memory backpropagation.
Uses the adjoint method for memory-efficient gradient computation through ODE solutions. Supports adaptive-step (dopri5, dopri8, bosh3) and fixed-step (euler, rk4) methods.
See also: torchdiffeq GitHub
Citation:
```bibtex
@misc{torchdiffeq,
  author={Chen, Ricky T. Q.},
  title={torchdiffeq},
  year={2018},
  url={https://github.com/rtqichen/torchdiffeq},
}
```
Functions
__init__
```python
__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: str = "dopri5",
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
)
```
Initialize TorchDiffEqSolver.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t_span | tuple[float, float] | Tuple (t_start, t_end) defining the integration interval. | (0, 1000) |
| t_steps | int | Number of evaluation points in the save region. | 1000 |
| device | str \| None | Device to use ('cuda', 'cpu', or None for auto-detect). | None |
| method | str | Integration method from torchdiffeq.odeint. | 'dopri5' |
| rtol | float | Relative tolerance (used by adaptive-step methods only). | 1e-08 |
| atol | float | Absolute tolerance (used by adaptive-step methods only). | 1e-06 |
| cache_dir | str \| None | Directory for caching integration results. | DEFAULT_CACHE_DIR |
| t_eval | tuple[float, float] \| None | Optional save region (start, end) within t_span. | None |
pybasin.solvers.TorchOdeSolver
Bases: Solver
Parallel ODE solver with independent step sizes per batch element.
Compatible with PyTorch's JIT compiler for performance optimization. Unlike other solvers, torchode can take different step sizes for each sample in a batch, avoiding performance traps for problems of varying stiffness.
See also: torchode documentation
Citation:
```bibtex
@inproceedings{lienen2022torchode,
  title = {torchode: A Parallel {ODE} Solver for PyTorch},
  author = {Marten Lienen and Stephan G{\"u}nnemann},
  booktitle = {The Symbiosis of Deep Learning and Differential Equations II, NeurIPS},
  year = {2022},
  url = {https://openreview.net/forum?id=uiKVKTiUYB0}
}
```
Functions
__init__
```python
__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    method: str = "dopri5",
    rtol: float = 1e-08,
    atol: float = 1e-06,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
)
```
Initialize TorchOdeSolver.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t_span | tuple[float, float] | Tuple (t_start, t_end) defining the integration interval. | (0, 1000) |
| t_steps | int | Number of evaluation points in the save region. | 1000 |
| device | str \| None | Device to use ('cuda', 'cpu', or None for auto-detect). | None |
| method | str | Integration method ('dopri5', 'tsit5', 'euler', 'heun'). | 'dopri5' |
| rtol | float | Relative tolerance (used by adaptive-step methods only). | 1e-08 |
| atol | float | Absolute tolerance (used by adaptive-step methods only). | 1e-06 |
| cache_dir | str \| None | Directory for caching integration results. | DEFAULT_CACHE_DIR |
| t_eval | tuple[float, float] \| None | Optional save region (start, end) within t_span. | None |
pybasin.solvers.ScipyParallelSolver
Bases: Solver
ODE solver using joblib's parallel processing with scipy's solve_ivp.
Uses multiprocessing (loky backend) to solve multiple initial conditions in parallel. Each worker solves one trajectory at a time using scipy's solve_ivp.
Requires a NumpyODESystem subclass. The ODE is passed directly to solve_ivp with no PyTorch-to-NumPy conversion overhead.
See also: scipy.integrate.solve_ivp
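The per-trajectory fan-out can be sketched with joblib directly (this mirrors the described pattern, not pybasin internals; the decay RHS is illustrative):

```python
import numpy as np
from joblib import Parallel, delayed
from scipy.integrate import solve_ivp

# One solve_ivp call per initial condition, fanned out over
# worker processes via the loky backend.
def _solve_one(y0):
    sol = solve_ivp(lambda t, y: -y, (0.0, 1.0), y0, t_eval=[1.0])
    return sol.y[:, -1]

y0_batch = [[1.0], [2.0], [3.0]]
finals = Parallel(n_jobs=2, backend="loky")(
    delayed(_solve_one)(y0) for y0 in y0_batch
)
```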
Functions
__init__
```python
__init__(
    t_span: tuple[float, float] = (0, 1000),
    t_steps: int = 1000,
    device: str | None = None,
    n_jobs: int = -1,
    method: str = "RK45",
    rtol: float = 1e-06,
    atol: float = 1e-08,
    max_step: float | None = None,
    cache_dir: str | None = DEFAULT_CACHE_DIR,
    t_eval: tuple[float, float] | None = None,
)
```
Initialize ScipyParallelSolver.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| t_span | tuple[float, float] | Tuple (t_start, t_end) defining the integration interval. | (0, 1000) |
| t_steps | int | Number of evaluation points in the save region. | 1000 |
| device | str \| None | Device to use (only 'cpu' supported). | None |
| n_jobs | int | Number of parallel jobs (-1 for all CPUs). | -1 |
| method | str | Integration method ('RK45', 'RK23', 'DOP853', 'Radau', 'BDF', 'LSODA', etc.). | 'RK45' |
| rtol | float | Relative tolerance (used by adaptive-step methods only). | 1e-06 |
| atol | float | Absolute tolerance (used by adaptive-step methods only). | 1e-08 |
| max_step | float \| None | Maximum step size for the solver. | None |
| cache_dir | str \| None | Directory for caching integration results. | DEFAULT_CACHE_DIR |
| t_eval | tuple[float, float] \| None | Optional save region (start, end) within t_span. | None |