Description
Hi all, as posted here, my executions of DiffEqFlux.sciml_train are not optimizing, yet they iterate for the full number of iterations without error. By "not optimizing" I mean that, for whatever random initial parameters I feed to the fitting example below, the loss function does not change at any iteration and the final minimizer equals the initial parameters. This remains true despite changing the learning rate for ADAM, the number of iterations, or the solver entirely (I also tried BFGS and SOSRI), and despite setting sensealg=ForwardDiffSensitivity().
I am using the standard loss sum(abs2, predicted - observed), and across the random initializations I try I see significant variation in the loss, which makes me think that for at least some initial conditions some improvement should be possible. I have previously optimized identical systems of higher dimension, with terrible data, using MATLAB's fmincon(). All of this leads me to believe I am making some novice error, which is likely given that I am new to Julia and DiffEqFlux. Any help would be much appreciated. Note that I did just update everything and restart Atom, and the problem persists.
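For concreteness, the variants I tried looked roughly like the following (just a sketch, reusing the loss, p_init, prob, and p_curr defined further down; the exact option values varied from run to run):
# Swapping the optimizer passed to sciml_train, e.g. BFGS from Optim.jl:
result_ode = DiffEqFlux.sciml_train(loss, p_init, BFGS(), maxiters=50)
# Forcing forward-mode sensitivities in the solve call inside the loss:
sol = DifferentialEquations.solve(prob, Tsit5(), p=p_curr, saveat=tsteps,
                                  callback=cb, sensealg=ForwardDiffSensitivity())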
First, I generate the model problem:
##
using Plots, Random
using DifferentialEquations, Flux, Optim, DiffEqFlux, DiffEqSensitivity
## Define GLV and generate synthetic data from defined parameters
tsteps = 24
n = 3
function gLV!(dx, x, p, t)
    # Weird ReverseDiff (AD) bug
    if (typeof(p) != Array{Float64,1}) #&& (typeof(p) != SciMLBase.NullParameters)
        p = p.value
    end
    # unpack params
    n = length(x)
    mu = p[1:n]
    A = Array{Float64,2}(undef, 3, 3)
    for i = 1:n; A[:,i] = p[n*i + 1 : n*(i+1)]; end
    dx .= x .* (mu + A*x)
    if any(x .> 100.0) # no cc catch
        dx .= 0
    end
end
condition(u,t,integrator) = any(u .< 1e-8) #|| any(u .> 10)
function affect!(integrator)
    integrator.u[integrator.u .< 1e-8] .= 0
end
cb = ContinuousCallback(condition,affect!)
mu_true = [0.4; 0.8; 0.2].*0.1
A_true = [-1.5 0.5 -0.2; -0.7 -1.3 0.1; 0.8 -0.2 -0.9]
p_true = [mu_true; A_true[:]]
x0 = [0.1; 0.1; 0.1].*0.1
tspan = (0.0, 144.0)
prob = ODEProblem(gLV!, x0, tspan)
sol = DifferentialEquations.solve(prob, Tsit5(), p=p_true, saveat=tsteps, callback=cb)
syndata = Array(sol) + 0.002*randn(size(Array(sol)))
plot(sol, alpha=0.3)
Plots.scatter!(sol.t, syndata')
title!("Definite")
After making the model problem and generating some noisy data with the true parameters, I then initialize random parameters, define a loss function and call sciml_train.
## Random Initialize
mu_init = 0.1*rand(n)
A_init = -1*rand(n,n)
p_init = [mu_init; A_init[:]]
## Optimization and Fitting
i = 0
function loss(p_curr)
    global i
    i += 1
    sol = DifferentialEquations.solve(prob, Tsit5(), p=p_curr, saveat=tsteps, callback=cb)
    loss = sum(abs2, Array(sol) - syndata)
    print("\n Loss of ", loss, " on iteration ", i)
    return loss, sol
end
result_ode = DiffEqFlux.sciml_train(loss, p_init,
                                    ADAM(0.1),
                                    maxiters=50)
print("\n")
print("\n p_init:", round.(p_init,digits=4))
print("\n p_optm:", round.(result_ode.minimizer,digits=4))
which, for a given initialization, prints essentially the same loss value on every iteration and then identical initial and "optimized" parameters, like so:
Loss of 0.013182825077950067 on iteration 1
Loss of 0.013182825077950069 on iteration 2
Loss of 0.013182825077950067 on iteration 3
Loss of 0.013182825077950069 on iteration 4
Loss of 0.013182825077950067 on iteration 5
Loss of 0.013182825077950069 on iteration 6
[... the loss keeps alternating between these same two values on every iteration ...]
Loss of 0.013182825077950067 on iteration 99
Loss of 0.013182825077950069 on iteration 100
p_init:[0.0621, 0.0467, 0.0351, -0.7881, -0.9748, -0.8404, -0.6985, -0.387, -0.8855, -0.6415, -0.5578, -0.2624]
p_optm:[0.0621, 0.0467, 0.0351, -0.7881, -0.9748, -0.8404, -0.6985, -0.387, -0.8855, -0.6415, -0.5578, -0.2624]
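For what it's worth, a quick sanity check I can run on top of the script above is a crude finite-difference probe of whether the loss responds to a parameter perturbation at all, independent of any AD backend. This is just a sketch that reuses the loss and p_init defined earlier; the names eps_fd, p_pert, and fd_grad_1 are only for illustration and were not part of the run shown above.
# Hypothetical finite-difference probe: perturb one parameter and check
# whether the loss changes at all (no automatic differentiation involved).
eps_fd = 1e-6
p_pert = copy(p_init)
p_pert[1] += eps_fd
fd_grad_1 = (loss(p_pert)[1] - loss(p_init)[1]) / eps_fd
print("\n finite-difference dloss/dp[1] ≈ ", fd_grad_1)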