Optimizing through an ODE solve and re-creating MTK Problems
Solving an ODE as part of an OptimizationProblem's loss function is a common scenario. In this example, we will go through an efficient way to model such scenarios using ModelingToolkit.jl.
First, we build the ODE to be solved. For this example, we will use a Lotka-Volterra model:
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
D(y) ~ (δ * x - γ) * y]
@mtkcompile odesys = System(eqs, t)
To create the "data" for optimization, we will solve the system with a known set of parameters.
using OrdinaryDiffEq
odeprob = ODEProblem(
odesys, [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0], (0.0, 10.0))
timesteps = 0.0:0.1:10.0
sol = solve(odeprob, Tsit5(); saveat = timesteps)
data = Array(sol)
# add some random noise
data = data + 0.01 * randn(size(data))
Now we will create the loss function for the Optimization solve. This requires an ODEProblem that uses the parameter values passed to the loss function. Creating a new ODEProblem from scratch on every call is expensive and requires differentiating through the code generation process, which is bug-prone and unnecessary. Instead, we will leverage the remake function, which creates a copy of an existing problem with updated state/parameter values.
Note that the types of the values passed to the loss function may not agree with the types stored in the existing ODEProblem. With ForwardDiff, for example, the loss receives ForwardDiff.Dual numbers, while the problem stores Float64 values. Thus, we cannot use setp to modify the problem in-place. Here, we will use the replace function from SciMLStructures.jl, since it allows updating the entire Tunable portion of the parameter object, which contains the parameters to optimize.
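The following sketch isolates the type issue and shows how replace sidesteps it. It assumes the same Lotka-Volterra model as above and uses Float32 as a stand-in for the ForwardDiff.Dual element type, so that it runs without ForwardDiff. It is an illustration only and is not part of the tutorial's workflow.

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq
using SymbolicIndexingInterface: parameter_values
using SciMLStructures: Tunable, canonicalize
import SciMLStructures

# same Lotka-Volterra model as above
@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkcompile odesys = System(eqs, t)
odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0], (0.0, 10.0))

ps = parameter_values(odeprob)
# the first entry returned by `canonicalize` is the flat vector of tunable values
# (α, β, γ, δ in an internal order), stored here as Float64
tunables = canonicalize(Tunable(), ps)[1]
# Float32 stands in for `ForwardDiff.Dual`: `setp` would write into the existing
# Float64 storage (which cannot hold Dual numbers), whereas `replace` builds a new
# parameter object around a vector of the new element type
newps = SciMLStructures.replace(Tunable(), ps, Float32.(tunables))
# a copy of the problem that carries the re-typed parameter object
newprob = remake(odeprob; p = newps)
```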
using SymbolicIndexingInterface: parameter_values, state_values
using SciMLStructures: Tunable, canonicalize, replace, replace!
using PreallocationTools
function loss(x, p)
odeprob = p[1] # ODEProblem stored as parameters to avoid using global variables
ps = parameter_values(odeprob) # obtain the parameter object from the problem
diffcache = p[5]
# get an appropriately typed preallocated buffer to store the `x` values in
buffer = get_tmp(diffcache, x)
# copy the current values to this buffer
copyto!(buffer, canonicalize(Tunable(), ps)[1])
# create a copy of the parameter object with the buffer
ps = replace(Tunable(), ps, buffer)
# set the updated values in the parameter object
setter = p[4]
setter(ps, x)
# remake the problem, passing in our new parameter object
newprob = remake(odeprob; p = ps)
timesteps = p[2]
sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
truth = p[3]
data = Array(sol)
return sum((truth .- data) .^ 2) / length(truth)
end
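To see what the DiffCache lines in loss do on their own, here is a minimal sketch using PreallocationTools directly. The four-element vectors are hypothetical stand-ins for the tunable values. With a plain Float64 input, get_tmp returns the preallocated Float64 buffer itself, so repeated calls reuse the same memory. When the input contains ForwardDiff.Dual numbers, get_tmp instead returns a buffer with a Dual element type; that case is not exercised here.

```julia
using PreallocationTools

# hypothetical stand-in for the Float64 vector of tunable values
tunables = [1.5, 1.0, 3.0, 1.0]
# the cache holds a Float64 buffer plus storage for a Dual-typed version of it
diffcache = DiffCache(copy(tunables))
# hypothetical optimization variables with a Float64 element type
x = [2.0, 0.5, 2.5, 1.2]
# a Float64 input selects the plain Float64 buffer
buffer = get_tmp(diffcache, x)
copyto!(buffer, x)
# calling `get_tmp` again with Float64 input returns the very same buffer
same_buffer = get_tmp(diffcache, x)
```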
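Note how the problem, the timesteps, the true data, the setter and the DiffCache are all passed through the parameter argument p of the loss function (the parameter tuple of the OptimizationProblem). This avoids referencing non-constant global variables inside the function, which would make it type-unstable and slow it down significantly.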
We could have achieved the same result by passing remake a map of parameter values. For example, if the optimization variables in x are ordered as [α, β, γ, δ], we could have written:
remake(odeprob; p = [α => x[1], β => x[2], γ => x[3], δ => x[4]])
However, passing a symbolic map to remake is significantly slower than passing it a parameter object directly, which is why we use replace to speed up the process. In general, remake is the most flexible method, but this flexibility comes at the cost of performance.
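As a sketch, the two approaches can be checked against each other. Both should produce a problem with the same values for α, β, γ and δ, and the original problem should be left unchanged. The values in newvals are arbitrary and hypothetical. Timing is left out. To compare speed, one could wrap each remake call in BenchmarkTools' @btime.

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq
using SymbolicIndexingInterface: parameter_values, getp, setp
using SciMLStructures: Tunable, canonicalize
import SciMLStructures

# same Lotka-Volterra model as above
@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkcompile odesys = System(eqs, t)
odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0], (0.0, 10.0))

# hypothetical new values, ordered as [α, β, γ, δ]
newvals = [2.0, 0.5, 2.5, 1.2]

# flexible but slower: pass a symbolic map to `remake`
prob_map = remake(odeprob;
    p = [α => newvals[1], β => newvals[2], γ => newvals[3], δ => newvals[4]])

# faster: copy the Tunable portion, write the values with a prebuilt setter,
# then hand the finished parameter object to `remake`
setter = setp(odeprob, [α, β, γ, δ])
ps = parameter_values(odeprob)
newps = SciMLStructures.replace(Tunable(), ps, copy(canonicalize(Tunable(), ps)[1]))
setter(newps, newvals)
prob_obj = remake(odeprob; p = newps)

# getter used to read back [α, β, γ, δ] from any problem built from `odesys`
getter = getp(odeprob, [α, β, γ, δ])
```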
We can perform the optimization as below:
using Optimization
using OptimizationOptimJL
using SymbolicIndexingInterface
# manually create an OptimizationFunction to ensure usage of `ForwardDiff`, which will
# require changing the types of parameters from `Float64` to `ForwardDiff.Dual`
optfn = OptimizationFunction(loss, Optimization.AutoForwardDiff())
# function to set the parameters we are optimizing
setter = setp(odeprob, [α, β, γ, δ])
# `DiffCache` to avoid allocations.
# `copy` prevents the buffer stored by `DiffCache` from aliasing the one in
# `parameter_values(odeprob)`.
diffcache = DiffCache(copy(canonicalize(Tunable(), parameter_values(odeprob))[1]))
# parameter object is a tuple, to store differently typed objects together
optprob = OptimizationProblem(
optfn, rand(4), (odeprob, timesteps, data, setter, diffcache),
lb = zeros(4), ub = 3ones(4))
sol = solve(optprob, BFGS())
Re-creating the problem
There are multiple ways to re-create a problem with new state/parameter values. We will go over the various methods, listing their use cases.
Pure remake
This method is the most generic. It can handle symbolic maps, initializations of parameters/states dependent on each other and partial updates. However, this comes at the cost of performance. remake is also not always inferable.
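Here is a minimal sketch of a partial update with remake, assuming the same model as above. Only x and α are specified. The remaining states and parameters are expected to keep the values stored in odeprob.

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq
using SymbolicIndexingInterface: getp, getu

# same Lotka-Volterra model as above
@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkcompile odesys = System(eqs, t)
odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0], (0.0, 10.0))

# partial update with symbolic maps: only x and α are given, so y, β, γ and δ
# should keep the values stored in `odeprob`
newprob = remake(odeprob; u0 = [x => 2.0], p = [α => 1.0])
```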
remake and setp/setu
Calling remake(prob) creates a copy of the existing problem. This new problem has exactly the same types as the original one, and the remake call is fully inferred. State/parameter values can then be modified on the copy using setp and/or setu. This is most appropriate when the types of the state/parameter values do not need to change, only their values.
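The sketch below shows this pattern, assuming the same model as above. The copy is made once, and setp/setu then write new values of the same type (Float64) into it. Building the setter functions once and reusing them is what keeps repeated updates cheap, for example inside a loop.

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq
using SymbolicIndexingInterface: getp, getu, setp, setu

# same Lotka-Volterra model as above
@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkcompile odesys = System(eqs, t)
odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0], (0.0, 10.0))

# `remake` with no arguments: a copy with the same types as the original
newprob = remake(odeprob)
# build the setters once; they can be reused for every subsequent update
set_α = setp(newprob, α)
set_x = setu(newprob, x)
# the new values have the same type (Float64) as the stored ones
set_α(newprob, 2.0)
set_x(newprob, 0.5)
```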
replace and remake
replace returns a copy of a parameter object, with the appropriate portion replaced by new values. This is useful for changing the type of an entire portion, such as during the optimization process described above. remake is then used to create a copy of the problem that holds the new parameter object (and, optionally, updated state values).
remake and replace!
replace! is similar to replace, except that it operates in-place. This means that the new values must have the same types as the existing ones. This is useful when bulk parameter replacement is required without changing types, for example in optimization methods where the gradient is not computed using dual numbers (unlike the ForwardDiff-based example above).
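The sketch below shows replace! under the same assumptions as the earlier examples. First, replace makes an out-of-place copy of the parameter object so that odeprob itself is left untouched. The in-place update then overwrites the Tunable buffer with values of the same element type. Doubling every value is an arbitrary choice for illustration.

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using OrdinaryDiffEq
using SymbolicIndexingInterface: parameter_values
using SciMLStructures: Tunable, canonicalize
import SciMLStructures

# same Lotka-Volterra model as above
@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkcompile odesys = System(eqs, t)
odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0, α => 1.5, β => 1.0, γ => 3.0, δ => 1.0], (0.0, 10.0))

ps = parameter_values(odeprob)
original = copy(canonicalize(Tunable(), ps)[1])
# copy out-of-place first, so the in-place update below leaves `odeprob` untouched
newps = SciMLStructures.replace(Tunable(), ps, copy(original))
# in-place bulk update: the new values share the buffer's element type (Float64)
SciMLStructures.replace!(Tunable(), newps, 2 .* original)
# a problem that uses the updated parameter object
newprob = remake(odeprob; p = newps)
```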