Getting Started with SciMLSensitivity: Differentiating ODE Solutions

Warning

This tutorial assumes familiarity with DifferentialEquations.jl. If you are not familiar with it, please consult the DifferentialEquations.jl documentation.

SciMLSensitivity.jl is a tool for obtaining derivatives of the solutions computed by equation solvers, such as differential equation solvers. These derivatives can be used in many ways, for example to analyze the local sensitivities of a system or to compute the gradients of cost functions for model calibration and parameter estimation. In this tutorial, we will show how to use the tooling in SciMLSensitivity.jl to differentiate ODE solutions.

Note

SciMLSensitivity.jl applies to all equation solvers of the SciML ecosystem, such as linear solvers, nonlinear solvers, optimization solvers, and more. This tutorial focuses on differential equations, so please see the tutorials on the other SciMLProblem types as needed. While the interface works similarly for all problem types, each of those tutorials showcases the aspects that are specific to its problem type.

Setup

Let's first define a differential equation we wish to solve. We will choose the Lotka-Volterra equations. This is done via OrdinaryDiffEq.jl, the ODE solver component of DifferentialEquations.jl, using:

using OrdinaryDiffEq

function lotka_volterra!(du, u, p, t)
    du[1] = dx = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = dy = -p[3] * u[2] + p[4] * u[1] * u[2]
end
p = [1.5, 1.0, 3.0, 1.0];
u0 = [1.0; 1.0];
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)
sol = solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6)
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 104-element Vector{Float64}:
  0.0
  0.022388671774158386
  0.06688455772214347
  0.12204057861057453
  0.1901739088897294
  0.2700958843663612
  0.3624899566568635
  0.4663498913963812
  0.5804932242040545
  0.7035670559596722
  ⋮
  9.363458919328917
  9.438253960677558
  9.514924295802581
  9.5948773310752
  9.679331554459784
  9.769895481406428
  9.868269555469228
  9.975570635758869
 10.0
u: 104-element Vector{Vector{Float64}}:
 [1.0, 1.0]
 [1.0117558257818347, 0.9563342092954507]
 [1.0384182072226116, 0.8758683249561677]
 [1.0774848848533785, 0.786875167161091]
 [1.134905782974095, 0.6915813161179499]
 [1.215349431258797, 0.5976695404830108]
 [1.3266197064716623, 0.509348518202066]
 [1.4766110896773839, 0.43133598821934244]
 [1.674672297975389, 0.36648541983542293]
 [1.9317152588988318, 0.31613609919958985]
 ⋮
 [1.280491754436697, 3.2111901642776264]
 [1.1439255035472293, 2.8083596202276446]
 [1.0502507522603817, 2.426494591751833]
 [0.9895322686715115, 2.070758390747169]
 [0.9563824267282467, 1.7446017637035662]
 [0.9484176841185465, 1.4490309505420775]
 [0.9660834157927825, 1.1849911555290313]
 [1.0122116806002588, 0.9547855316946803]
 [1.0263542618083945, 0.9096831916592041]
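For reference, lotka_volterra! encodes the following system. The state is written as u = (x, y) and the parameters as p = (p_1, p_2, p_3, p_4). The dx and dy names in the code correspond to the two right-hand sides:

```latex
\begin{aligned}
\frac{dx}{dt} &= p_1 x - p_2 x y, \\
\frac{dy}{dt} &= -p_3 y + p_4 x y,
\end{aligned}
\qquad u(0) = (1, 1), \quad t \in [0, 10].
```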

Now let's differentiate the solution to this ODE using a few different automatic differentiation methods.

Forward-Mode Automatic Differentiation with ForwardDiff.jl

Let's say we need the derivative of the solution with respect to the initial condition u0 and the parameters p. One of the simplest ways to get it is via ForwardDiff.jl: all one needs to do is use ForwardDiff.jl to differentiate a function f that calls solve internally. For example, suppose we want the derivative of the first component of the ODE solution with respect to these quantities at the evenly spaced time points t = 0, 1, ..., 10 (dt = 1). We can compute this via:

using ForwardDiff

function f(x)
    _prob = remake(prob, u0 = x[1:2], p = x[3:end])
    solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 1)[1, :]
end
x = [u0; p]
dx = ForwardDiff.jacobian(f, x)
11×6 Matrix{Float64}:
   1.0        0.0           0.0        0.0          0.0          0.0
   2.14463   -1.1848        2.54832   -1.1848       0.477483    -0.628218
  -5.88478    0.266338     -3.38158    0.266338     3.50594    -12.662
   0.691824   0.3718       -0.762033   0.3718      -0.0477691   -0.278507
   2.7989    -0.408784      3.80837   -0.408784     0.883252     0.914524
   4.0171    -1.65424      12.3007    -1.65424      3.95659     -2.0814
  -2.07453    0.851802     -7.0992     0.851802    -1.06005     -3.46806
   2.63655   -0.00114306    3.54679   -0.00114306   0.872776     1.30651
   7.88534   -0.610538     16.9144    -0.610538     4.3355       3.54215
 -16.5707     0.866198    -36.104      0.866198    -5.67502    -19.8444
   1.96602    0.188561      2.16063    0.188561     0.563199     0.939672

Let's dig into what this code does. x is a vector that concatenates the initial condition and the parameters: the first 2 values are the initial conditions and the last 4 are the parameters. Inside f(x), the remake function builds a new problem with these initial conditions and parameters. f then solves it and returns the time series of the first component.
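The following is a minimal sketch of what remake does inside f(x). It assumes the same problem setup as above, and the values in x are arbitrary illustration values. remake returns a new ODEProblem in which the given fields are replaced. Every other field, such as tspan, is copied from the original problem, and the original problem is left unchanged.

```julia
using OrdinaryDiffEq

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Illustrative new [u0; p]: only the initial condition differs from prob
x = [2.0, 0.5, 1.5, 1.0, 3.0, 1.0]
_prob = remake(prob, u0 = x[1:2], p = x[3:end])

println(_prob.u0)     # [2.0, 0.5]
println(_prob.p)      # [1.5, 1.0, 3.0, 1.0]
println(_prob.tspan)  # (0.0, 10.0), copied from prob
println(prob.u0)      # [1.0, 1.0], the original problem is not mutated
```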

Then ForwardDiff.jacobian(f, x) computes the Jacobian of f with respect to x. f maps the 6 entries of x to the 11 saved values of the first component, so the result is an 11×6 matrix. dx[i,j] is the derivative of the first component of the solution at time t = i-1 with respect to x[j]. For example, dx[2,3] is the derivative of the first component of the solution at time t = 1 with respect to p[1].
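To make the indexing convention concrete, the sketch below recomputes the Jacobian and checks one column against a central finite difference. Row i corresponds to the save time t = i-1 and column j to x[j]. Several details are illustrative choices, not part of the tutorial's setup: the step size h = 1e-4, the 1e-10 reference tolerances, and the tol keyword added to f.

```julia
using OrdinaryDiffEq, ForwardDiff

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)

# First solution component at t = 0, 1, ..., 10 as a function of x = [u0; p].
# `tol` is an added keyword so the finite-difference check can use tighter tolerances.
function f(x; tol = 1e-6)
    _prob = remake(prob, u0 = x[1:2], p = x[3:end])
    solve(_prob, Tsit5(), reltol = tol, abstol = tol, saveat = 1)[1, :]
end

x = [u0; p]
dx = ForwardDiff.jacobian(f, x)   # 11×6: row i is time t = i - 1, column j is x[j]

# Row 1 is t = 0, where u(0) = u0, so only the derivative w.r.t. u0[1] is nonzero (= 1)
println(dx[1, :])

# Central finite difference for column 3, i.e. the derivative with respect to p[1]
h = 1e-4
e3 = zeros(length(x))
e3[3] = h
fd_col3 = (f(x .+ e3; tol = 1e-10) .- f(x .- e3; tol = 1e-10)) ./ (2h)

# dx[2, 3] and fd_col3[2] both approximate d(u1 at t = 1)/d(p[1])
println((dx[2, 3], fd_col3[2]))
```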

Note

Since the global error is typically 1-2 orders of magnitude higher than the local error, we use tolerances of reltol = abstol = 1e-6 (instead of the default relative tolerance of 1e-3) to get reasonable sensitivities.
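One way to see the effect of the tolerances is to compare Jacobians computed at the default tolerances and at 1e-6 against a reference computed at 1e-10. This sketch assumes the 1e-10 solve is accurate enough to serve as that reference. The helper first_component_jacobian is hypothetical and is defined only for this comparison.

```julia
using OrdinaryDiffEq, ForwardDiff, LinearAlgebra

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)

# Hypothetical helper: Jacobian of the first component at t = 0, 1, ..., 10 with
# respect to [u0; p], forwarding any solver keyword arguments (e.g. tolerances)
function first_component_jacobian(; solver_kwargs...)
    f(x) = solve(remake(prob, u0 = x[1:2], p = x[3:end]), Tsit5();
        saveat = 1, solver_kwargs...)[1, :]
    return ForwardDiff.jacobian(f, [u0; p])
end

J_ref = first_component_jacobian(reltol = 1e-10, abstol = 1e-10)  # reference
J_default = first_component_jacobian()                             # solver defaults
J_tight = first_component_jacobian(reltol = 1e-6, abstol = 1e-6)   # as in the tutorial

err_default = norm(J_default - J_ref)
err_tight = norm(J_tight - J_ref)
println((err_default, err_tight))
```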

Reverse-Mode Automatic Differentiation

The solve function is automatically compatible with AD systems like Zygote.jl, so no extra machinery is needed. Put solve inside the function that Zygote differentiates and load SciMLSensitivity.jl, which provides the derivative rules. For example, the following solves the ODE and uses the adjoint method to compute the gradient of a loss function. The loss is the sum of the ODE solution's values at time points spaced dt = 0.1 apart:

using Zygote, SciMLSensitivity

function sum_of_solution(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1))
end
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)
([-39.127737527250886, -8.787495434474875], [8.304244028181543, -159.48401961551914, 75.2031622989798, -339.19516313730287])

SciMLSensitivity.jl overloads Zygote.jl's automatic differentiation of solve, which lets it redefine the way the derivatives are computed. This allows trade-offs between numerical stability, memory, and compute performance, similar to how ODE solver algorithms are chosen.
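As a sketch of the parameter-estimation use case mentioned at the start, the same pattern works for a sum-of-squares loss against data. Here the "data" are assumed to be synthetic, generated by solving at the true parameters. The gradient should therefore be approximately zero at those parameters and nonzero elsewhere. l2_loss and the perturbed parameter values are hypothetical illustration choices. Array(sol) converts the saved states into a matrix with one column per saved time.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
p_true = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, [1.0, 1.0], (0.0, 10.0), p_true)

# Synthetic "measurements" (an assumption for illustration): the model solved at p_true,
# stored as a 2×101 matrix with one column per saved time t = 0.0, 0.1, ..., 10.0
data = Array(solve(prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1))

# Hypothetical sum-of-squares misfit between the model at parameters p and the data
function l2_loss(p)
    _prob = remake(prob, p = p)
    pred = Array(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1))
    return sum(abs2, pred .- data)
end

g_true, = Zygote.gradient(l2_loss, p_true)               # approximately zero: exact fit
g_off, = Zygote.gradient(l2_loss, [1.4, 1.1, 3.1, 0.9])  # nonzero away from p_true
println(g_true)
println(g_off)
```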

Choosing Sensitivity Algorithms

The algorithms for computing these derivatives are called AbstractSensitivityAlgorithms, or sensealgs for short. They are chosen by passing the sensealg keyword argument to solve. Let's demonstrate this by choosing the GaussAdjoint sensealg for the differentiation of this system:

function sum_of_solution(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = GaussAdjoint()))
end
du01, dp1 = Zygote.gradient(sum_of_solution, u0, p)
([-39.1261032497264, -8.787925705972565], [8.307610393622285, -159.48459637924043, 75.20354297813154, -339.19349676309093])

This computes the derivatives of the output with respect to the initial condition and with respect to the parameters, respectively, using GaussAdjoint(). For more information on the choice of sensitivity algorithm, see the reference documentation on choosing sensitivity algorithms.
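The sensealg only changes how the derivative is computed, so different choices should agree up to solver tolerance. The sketch below assumes the same setup as above. It passes the sensealg as an extra argument, an illustrative change to sum_of_solution. It then compares GaussAdjoint, QuadratureAdjoint, and InterpolatingAdjoint.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)

# Same loss as in the tutorial, but with the sensealg passed in as an extra argument
function sum_of_solution(u0, p, sensealg)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = sensealg))
end

sensealgs = (GaussAdjoint(), QuadratureAdjoint(), InterpolatingAdjoint())

# Gradient with respect to [u0; p] for each sensealg
grads = map(sensealgs) do alg
    du0, dp = Zygote.gradient((u0, p) -> sum_of_solution(u0, p, alg), u0, p)
    [du0; dp]
end

for (alg, g) in zip(sensealgs, grads)
    println(nameof(typeof(alg)), ": ", g)
end
```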

Note

ForwardDiff.jl's automatic differentiation system ignores the sensitivity algorithms.

When Should You Use Forward or Reverse Mode?

Good question! The simple answer is: if you are differentiating a system of fewer than 100 equations, use forward mode; otherwise, use reverse mode. But it can be a lot more complicated than that! For more information, see the reference documentation on choosing sensitivity algorithms.
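One way to check this rule of thumb on your own problem is to compute the same gradient both ways and time them. The sketch below assumes the same setup as above. It computes the gradient of sum_of_solution with ForwardDiff.jl, over the concatenated vector [u0; p], and with Zygote.jl. The two results should agree up to solver tolerance. Timing is left out because it is machine dependent; wrapping each call in @time, or @btime from BenchmarkTools.jl, is one option.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote, ForwardDiff

function lotka_volterra!(du, u, p, t)
    du[1] = p[1] * u[1] - p[2] * u[1] * u[2]
    du[2] = -p[3] * u[2] + p[4] * u[1] * u[2]
    return nothing
end
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lotka_volterra!, u0, (0.0, 10.0), p)

function sum_of_solution(u0, p)
    _prob = remake(prob, u0 = u0, p = p)
    sum(solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1))
end

# Forward mode: ForwardDiff.jl over the concatenated input vector x = [u0; p]
x = [u0; p]
g_forward = ForwardDiff.gradient(x -> sum_of_solution(x[1:2], x[3:end]), x)

# Reverse mode: Zygote.jl, with the adjoint rules provided by SciMLSensitivity.jl
du0, dp = Zygote.gradient(sum_of_solution, u0, p)
g_reverse = [du0; dp]

println(g_forward)
println(g_reverse)
```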

And that is it! Where should you go from here?

That's all there is to the basics of differentiating the ODE solvers with SciMLSensitivity.jl. That said, check out the following tutorials to dig into more detail: