
Re-creating MTK problems causes side effects #3407

Closed
@ysfoo

Description


I am following the fairly recent tutorial on re-creating MTK problems. In summary, the tutorial defines a function `loss` that remakes an `ODEProblem` with the tunable portion of its parameters replaced by the input argument, and then evaluates the loss. However, calling `loss` mutates the parameters of the original `ODEProblem` object in the global scope, which should be left untouched.

MWE:

```julia
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters α β γ δ
@variables x(t) y(t)
eqs = [D(x) ~ (α - β * y) * x
       D(y) ~ (δ * x - γ) * y]
@mtkbuild odesys = ODESystem(eqs, t)

using OrdinaryDiffEq

odeprob = ODEProblem(
    odesys, [x => 1.0, y => 1.0], (0.0, 10.0), [α => 1.5, β => 1.0, γ => 3.0, δ => 1.0])
timesteps = 0.0:0.1:10.0
sol = solve(odeprob, Tsit5(); saveat = timesteps)
data = Array(sol)
# add some random noise
data = data + 0.01 * randn(size(data))

using SymbolicIndexingInterface: parameter_values, state_values
using SciMLStructures: Tunable, canonicalize, replace, replace!
using PreallocationTools

function loss(x, p)
    odeprob = p[1] # the ODEProblem is passed in via `p` to avoid using global variables
    ps = parameter_values(odeprob) # obtain the parameter object from the problem
    diffcache = p[5]
    # get an appropriately typed preallocated buffer to store the `x` values in
    buffer = get_tmp(diffcache, x)
    # copy the current values to this buffer
    copyto!(buffer, canonicalize(Tunable(), ps)[1])
    # create a copy of the parameter object with the buffer
    ps = replace(Tunable(), ps, buffer)
    # set the updated values in the parameter object
    setter = p[4]
    setter(ps, x)
    # remake the problem, passing in our new parameter object
    newprob = remake(odeprob; p = ps)
    timesteps = p[2]
    sol = solve(newprob, AutoTsit5(Rosenbrock23()); saveat = timesteps)
    truth = p[3]
    data = Array(sol)
    return sum((truth .- data) .^ 2) / length(truth)
end

using SymbolicIndexingInterface

setter = setp(odeprob, [α, β, γ, δ]);
# `DiffCache` to avoid allocations
diffcache = DiffCache(canonicalize(Tunable(), parameter_values(odeprob))[1]);

getter = getp(odeprob, [α, β, γ, δ])
getter(odeprob) # returns the original values [1.5, 1.0, 3.0, 1.0]
loss(ones(4), (odeprob, timesteps, data, setter, diffcache))
getter(odeprob) # returns ones: calling the loss function has mutated `odeprob`
```
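One plausible mechanism, assuming that `canonicalize(Tunable(), ps)[1]` returns the tunable vector stored inside `ps` itself rather than a copy (not verified here): `DiffCache(u)` keeps `u` as its plain, non-dual buffer, and `get_tmp` returns that buffer for a `Float64` input such as `ones(4)`. The `buffer` handed to `replace` would then be the original problem's own tunable vector, so `setter(ps, x)` writes straight into `odeprob`. The sketch below isolates only the `DiffCache` part with a stand-in vector; the names `tunable`, `aliased_cache` and `safe_cache` are illustrative and not part of the tutorial.

```julia
using PreallocationTools

# Stand-in for the tunable vector; by assumption this is the problem's own storage.
tunable = [1.5, 1.0, 3.0, 1.0]

# `DiffCache` stores the array it is given, without copying, as its plain buffer.
aliased_cache = DiffCache(tunable)
buf = get_tmp(aliased_cache, zeros(4))  # Float64 input, so the plain buffer is returned
buf .= 1.0                              # writing into the buffer also overwrites `tunable`

# Giving the cache its own copy breaks the aliasing.
tunable2 = [1.5, 1.0, 3.0, 1.0]
safe_cache = DiffCache(copy(tunable2))
buf2 = get_tmp(safe_cache, zeros(4))
buf2 .= 1.0                             # `tunable2` is left untouched
```

If this is indeed the cause, the analogous change in the MWE would be to build the cache from a copy, e.g. `DiffCache(copy(canonicalize(Tunable(), parameter_values(odeprob))[1]))`. This is a suggested workaround under the assumptions above, not a confirmed fix.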

Metadata


Labels

bug: Something isn't working
