# Optimizing through an ODE solve and re-creating MTK Problems

Solving an ODE as part of an `OptimizationProblem`'s loss function is a common scenario. In this example, we will go through an efficient way to model such scenarios using ModelingToolkit.jl.

First, we build the ODE to be solved. For this example, we will use a Lotka-Volterra model:

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkcompile odesys = System(eqs, t)
```

To create the "data" for optimization, we will solve the system with a known set of parameters.

```julia
using OrdinaryDiffEq

odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0], (0.0, 10.0))
timesteps = 0.0:0.1:10.0
sol = solve(odeprob, Tsit5(); saveat = timesteps)
data = Array(sol)
# add some random noise
data = data + 0.01 * randn(size(data))
```

Now we will create the loss function for the `Optimization` solve. This requires an `ODEProblem` whose parameters take the values passed to the loss function. Creating a new `ODEProblem` on every call is expensive and requires differentiating through the code-generation process, which is bug-prone and unnecessary. Instead, we will leverage the `remake` function, which creates a copy of an existing problem with updated state/parameter values. Note that the types of the values passed to the loss function may not match the types stored in the existing `ODEProblem` (e.g. `ForwardDiff.Dual` instead of `Float64`), so we cannot use `setp` to modify the problem in-place. Here, we will use the `replace` function from SciMLStructures.jl, since it allows replacing the entire `Tunable` portion of the parameter object, which contains the parameters to optimize.

```julia
using SymbolicIndexingInterface: parameter_values, state_values
using SciMLStructures: Tunable, canonicalize, replace, replace!
using PreallocationTools

function loss(x, p)
    odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
    ps = parameter_values(odeprob) # obtain the parameter object from the problem
    diffcache = p[5]
    # get an appropriately typed preallocated buffer to store the `x` values in
    buffer = get_tmp(diffcache, x)
    # copy the current values to this buffer
    copyto!(buffer, canonicalize(Tunable(), ps)[1])
    # create a copy of the parameter object with the buffer
    ps = replace(Tunable(), ps, buffer)
    # set the updated values in the parameter object
    setter = p[4]
    setter(ps, x)
    # remake the problem, passing in our new parameter object
    newprob = remake(odeprob; p = ps)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end
```

Note how the problem, timesteps and true data are stored in the parameter object of the `OptimizationProblem`. This avoids referencing global variables in the loss function, which would slow it down significantly.

We could have done the same thing by passing `remake` a map of parameter values. For example, let us enforce that the order of ODE parameters in `x` is `[α, β, γ, δ]`. Then, we could have done:

```julia
remake(odeprob; p = [α => x[1], β => x[2], γ => x[3], δ => x[4]])
```

However, passing a symbolic map to `remake` is significantly slower than passing it a parameter object directly, since each symbol must be resolved to an index at runtime. Thus, we use `replace` to speed up the process. In general, `remake` is the most flexible method, but the flexibility comes at a cost of performance.
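
To get a rough sense of the difference, the two approaches can be benchmarked against each other. The following is a minimal sketch using BenchmarkTools.jl; exact timings will vary by machine and package versions:

```julia
using BenchmarkTools
using SymbolicIndexingInterface: parameter_values
using SciMLStructures: Tunable, canonicalize, replace

# symbolic map: every symbol must be resolved to an index inside `remake`
@btime remake($odeprob; p = [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])

# parameter object: the tunable buffer is swapped wholesale, no symbolic lookups
buffer = copy(canonicalize(Tunable(), parameter_values(odeprob))[1])
newps = replace(Tunable(), parameter_values(odeprob), buffer)
@btime remake($odeprob; p = $newps)
```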

We can perform the optimization as below:

```julia
using Optimization
using OptimizationOptimJL
using SymbolicIndexingInterface

# manually create an OptimizationFunction to ensure usage of `ForwardDiff`, which will
# require changing the types of parameters from `Float64` to `ForwardDiff.Dual`
optfn = OptimizationFunction(loss, Optimization.AutoForwardDiff())
# function to set the parameters we are optimizing
setter = setp(odeprob, [α, β, γ, δ])
# `DiffCache` to avoid allocations.
# `copy` prevents the buffer stored by `DiffCache` from aliasing the one in
# `parameter_values(odeprob)`.
diffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(odeprob))[1]))
# parameter object is a tuple, to store differently typed objects together
optprob = OptimizationProblem(
    optfn, rand(4), (odeprob, timesteps, data, setter, diffcache),
    lb = zeros(4), ub = 3ones(4))
sol = solve(optprob, BFGS())
```

## Re-creating the problem

There are multiple ways to re-create a problem with new state/parameter values. We will go over the various methods, listing their use cases.

### Pure `remake`

This method is the most generic. It can handle symbolic maps, initialization of parameters/states that depend on one another, and partial updates. However, this comes at the cost of performance, and `remake` is also not always inferable.
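
As a minimal sketch (reusing `odeprob` from above), a single `remake` call can update a subset of states and parameters through symbolic maps:

```julia
# partial, symbolic update: untouched states/parameters keep their old values
newprob = remake(odeprob; u0 = [x => 2.0], p = [α => 1.2, β => 0.8])
```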

### `remake` and `setp`/`setu`

Calling `remake(prob)` creates a copy of the existing problem. This new problem has the exact same types as the original one, and the `remake` call is fully inferred. State/parameter values can be modified after the copy is created using `setp` and/or `setu`. This is most appropriate when the types of state/parameter values do not need to be changed, only their values.
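
For example, a sketch along these lines (in practice, the setters would be constructed once and reused):

```julia
using SymbolicIndexingInterface: setp, setu

newprob = remake(odeprob) # fully inferred copy with identical types
setp(newprob, [α, β])(newprob, [1.2, 0.8]) # update parameter values in-place
setu(newprob, x)(newprob, 2.0) # update the initial value of a state
```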

### `replace` and `remake`

`replace` returns a copy of a parameter object, with the appropriate portion replaced by new values. This is useful for changing the type of an entire portion, such as during the optimization process described above. `remake` is then used to create a copy of the problem referencing the new parameter object.
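
A condensed sketch of the pattern used in the loss function above, here converting the tunable portion to `Float32` purely for illustration:

```julia
using SymbolicIndexingInterface: parameter_values, setp
using SciMLStructures: Tunable, canonicalize, replace

ps = parameter_values(odeprob)
# buffer with a new element type for the tunable portion
buffer = Float32.(canonicalize(Tunable(), ps)[1])
ps = replace(Tunable(), ps, buffer) # copy with the tunable portion replaced
setp(odeprob, [α, β, γ, δ])(ps, Float32[1.5, 1.0, 3.0, 1.0]) # assign by symbol
newprob = remake(odeprob; p = ps)
```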

### `remake` and `replace!`

`replace!` is similar to `replace`, except that it operates in-place. This means the new values must have the same types as the existing ones. It is useful when bulk parameter replacement is required without changing types, for example in optimization methods where the gradient is not computed using dual numbers (unlike the `ForwardDiff` approach demonstrated above).
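
A minimal sketch, assuming the parameter object supports `copy` and that the new values follow the canonical ordering of the tunable buffer (use `setp` to assign by symbol if the ordering is uncertain):

```julia
using SymbolicIndexingInterface: parameter_values
using SciMLStructures: Tunable, replace!

ps = copy(parameter_values(odeprob)) # copy so the original problem is untouched
replace!(Tunable(), ps, [1.5, 1.0, 3.0, 1.0]) # overwrite the tunable buffer in-place
newprob = remake(odeprob; p = ps)
```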