Getting Started with SciMLSensitivity: Differentiating ODE Solutions
This tutorial assumes familiarity with DifferentialEquations.jl; if you are not familiar with it, please consult the DifferentialEquations.jl documentation.
SciMLSensitivity.jl is a tool for obtaining derivatives of equation solvers, such as differential equation solvers. These can be used in many ways, such as analyzing the local sensitivities of a system or computing the gradients of cost functions for model calibration and parameter estimation. In this tutorial, we will show how to make use of the tooling in SciMLSensitivity.jl to differentiate the ODE solvers.
SciMLSensitivity.jl applies to all equation solvers of the SciML ecosystem, such as linear solvers, nonlinear solvers, nonlinear optimization, and more. This tutorial focuses on differential equations, so please see the tutorials for the other SciMLProblem types as necessary. While the interface works similarly for all problem types, those tutorials showcase the aspects that are special to a given problem.
Setup
Let's first define a differential equation we wish to solve. We will choose the Lotka-Volterra equation. This is done via DifferentialEquations.jl using:
using OrdinaryDiffEq
function lotka_volterra!(du, u, p, t)
    du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2]
end
p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)
sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 104-element Vector{Float64}:
0.0
0.022388671774158386
0.06688455772214347
0.12204057861057453
0.1901739088897294
0.2700958843663612
0.3624899566568635
0.4663498913963812
0.5804932242040545
0.7035670559596722
⋮
9.363458919328917
9.438253960677558
9.514924295802581
9.5948773310752
9.679331554459784
9.769895481406428
9.868269555469228
9.975570635758869
10.0
u: 104-element Vector{Vector{Float64}}:
[1.0, 1.0]
[1.0117558257818347, 0.9563342092954507]
[1.0384182072226116, 0.8758683249561677]
[1.0774848848533785, 0.786875167161091]
[1.134905782974095, 0.6915813161179499]
[1.215349431258797, 0.5976695404830108]
[1.3266197064716623, 0.509348518202066]
[1.4766110896773839, 0.43133598821934244]
[1.674672297975389, 0.36648541983542293]
[1.9317152588988318, 0.31613609919958985]
⋮
[1.280491754436697, 3.2111901642776264]
[1.1439255035472293, 2.8083596202276446]
[1.0502507522603817, 2.426494591751833]
[0.9895322686715115, 2.070758390747169]
[0.9563824267282467, 1.7446017637035662]
[0.9484176841185465, 1.4490309505420775]
[0.9660834157927825, 1.1849911555290313]
[1.0122116806002588, 0.9547855316946803]
[1.0263542618083945, 0.9096831916592041]
Now let's differentiate the solution to this ODE using a few different automatic differentiation methods.
Forward-Mode Automatic Differentiation with ForwardDiff.jl
Let's say we need the derivative of the solution with respect to the initial condition u0 and its parameters p. One of the simplest ways to do this is via ForwardDiff.jl. All one needs to do is use the ForwardDiff.jl library to differentiate some function f which uses a differential equation solve inside of it. For example, let's say we want the derivative of the first component of the ODE solution with respect to these quantities at evenly spaced time points of dt = 1. We can compute this via:
using ForwardDiff
function f(x)
    _prob = remake(prob, u0 = x[1:2], p = x[3:end])
    solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 1)[1, :]
end
x = [u0; p]
dx = ForwardDiff.jacobian(f, x)
11×6 Matrix{Float64}:
1.0 0.0 0.0 0.0 0.0 0.0
2.14463 -1.1848 2.54832 -1.1848 0.477483 -0.628218
-5.88478 0.266338 -3.38158 0.266338 3.50594 -12.662
0.691824 0.3718 -0.762033 0.3718 -0.0477691 -0.278507
2.7989 -0.408784 3.80837 -0.408784 0.883252 0.914524
4.0171 -1.65424 12.3007 -1.65424 3.95659 -2.0814
-2.07453 0.851802 -7.0992 0.851802 -1.06005 -3.46806
2.63655 -0.00114306 3.54679 -0.00114306 0.872776 1.30651
7.88534 -0.610538 16.9144 -0.610538 4.3355 3.54215
-16.5707 0.866198 -36.104 0.866198 -5.67502 -19.8444
1.96602 0.188561 2.16063 0.188561 0.563199 0.939672
Let's dig into what this is saying a bit. x is a vector which concatenates the initial condition and parameters, meaning that the first 2 values are the initial conditions and the last 4 are the parameters. We use the remake function to build a function f(x) which uses these new initial conditions and parameters to solve the differential equation and return the time series of the first component.
Then ForwardDiff.jacobian(f, x) computes the Jacobian of f with respect to x. The output dx[i, j] corresponds to the derivative of the first component of the solution at time t = i - 1 with respect to x[j]. For example, dx[3, 2] is the derivative of the first component of the solution at time t = 2 with respect to u0[2].
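As a quick sanity check of this indexing convention, a central finite difference in x[2] should match dx[3, 2]. This is a minimal sketch reusing f, x, and dx from above; the step size h is an arbitrary illustrative choice:
# Minimal sketch: compare dx[3, 2] (first component at t = 2, w.r.t. u0[2])
# against a central finite difference. `f`, `x`, and `dx` are reused from
# the ForwardDiff example above; the step size `h` is an arbitrary choice.
h = 1e-6
xp = copy(x); xp[2] += h
xm = copy(x); xm[2] -= h
fd = (f(xp)[3] - f(xm)[3]) / (2h)
isapprox(fd, dx[3, 2]; rtol = 1e-4)  # should return true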
Note that since the global error is 1-2 orders of magnitude higher than the local error, we use tolerances of 1e-6 (instead of the default 1e-3) to get reasonable sensitivities.
Reverse-Mode Automatic Differentiation
The solve function is automatically compatible with AD systems like Zygote.jl, so no extra machinery is necessary: simply put solve inside a function that is differentiated by Zygote. For example, the following solves an ODE and computes the gradient of a loss function (the sum of the ODE's output at each timepoint with dt = 0.1) via the adjoint method:
using Zygote, SciMLSensitivity
function sum_of_solution(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1))
end
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)
([-39.127737527250886, -8.787495434474875], [8.304244028181543, -159.48401961551914, 75.2031622989798, -339.19516313730287])
Zygote.jl's automatic differentiation system is overloaded to allow SciMLSensitivity.jl to redefine the way the derivatives are computed, allowing trade-offs between numerical stability, memory, and compute performance, similar to how ODE solver algorithms are chosen.
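The same pattern extends directly to model calibration. As a hedged sketch, the data array below is synthetic and invented purely for illustration, but the gradient of an L2 loss against measurements is computed in exactly the same way:
# Hypothetical illustration: gradient of an L2 loss against synthetic data.
data = rand(2, 101)  # stand-in for measurements at t = 0.0:0.1:10.0
function l2loss(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sol = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1)
    sum(abs2, Array(sol) .- data)
end
du0, dp = Zygote.gradient(l2loss, u0, p)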
Choosing Sensitivity Algorithms
The algorithms for computing these derivatives are called AbstractSensitivityAlgorithms, or sensealgs for short. These are chosen by passing the sensealg keyword argument into solve. Let's demonstrate this by choosing the GaussAdjoint sensealg for the differentiation of this system:
function sum_of_solution(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = GaussAdjoint()))
end
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)
([-39.1261032497264, -8.787925705972565], [8.307610393622285, -159.48459637924043, 75.20354297813154, -339.19349676309093])
Here this computes the derivative of the output with respect to the initial condition and the derivative with respect to the parameters, respectively, using GaussAdjoint(). For more information on the choices of sensitivity algorithms, see the reference documentation on choosing sensitivity algorithms.
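To see how easily the method can be swapped, here is the same gradient computed with InterpolatingAdjoint, an adjoint choice that reuses a stored interpolation of the forward solution in the reverse pass. This is a minimal sketch; the function name is ours, chosen to avoid redefining the one above:
# Same loss as above, but with InterpolatingAdjoint as the sensealg.
# `sum_of_solution_interp` is an illustrative name, not from the tutorial.
function sum_of_solution_interp(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = InterpolatingAdjoint()))
end
Zygote.gradient(sum_of_solution_interp, u0, p)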
When Should You Use Forward or Reverse Mode?
Good question! The simple answer is: if you are differentiating a system of fewer than 100 equations, use forward mode; otherwise, use reverse mode. But it can be a lot more complicated than that! For more information, see the reference documentation on choosing sensitivity algorithms.
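As a concrete illustration of that rule of thumb: our Lotka-Volterra system has only 2 states and 4 parameters, so forward mode is a natural fit, and a forward-mode sensealg can be requested through the very same interface. A minimal sketch (the function name is ours):
# Sketch: on this small system, ForwardDiffSensitivity() computes the
# sensitivities with forward-mode dual numbers while keeping the same
# Zygote calling convention as the adjoint examples above.
function sum_of_solution_fwd(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = ForwardDiffSensitivity()))
end
Zygote.gradient(sum_of_solution_fwd, u0, p)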
And that is it! Where should you go from here?
That's all there is to the basics of differentiating the ODE solvers with SciMLSensitivity.jl. That said, check out the following tutorials to dig into more detail:
- See the ODE parameter estimation tutorial to learn how to fit the parameters of ODE systems
- See the direct sensitivity tutorial to dig into the lower level API for more performance