Data Iterators and Minibatching

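It is possible to solve an optimization problem with batches using an MLUtils.DataLoader. The DataLoader is passed as the data argument of the OptimizationProblem. The number of passes over the data is set with the `epochs` keyword of Optimization.solve. All data for the batches must be passed as a tuple of arrays (here the observation matrix and the vector of time points), and each array is split into batches along its last dimension.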

Note

This example uses the OptimizationOptimisers.jl package. See the Optimisers.jl page for details on the installation and usage.

using Lux, Optimization, OptimizationOptimisers, OrdinaryDiffEq, SciMLSensitivity, MLUtils,
      Random, ComponentArrays

"""
    newtons_cooling(du, u, p, t)

In-place right-hand side of Newton's law of cooling, `dT/dt = -k * (T - T_m)`.

`u[1]` is the current temperature `T`. `p` is destructured as `(k, temp_m)`: the
cooling rate and the temperature of the surrounding medium. The derivative is
written into `du[1]`, and `t` is unused. Returns `nothing`, as is conventional for
in-place ODE functions.
"""
function newtons_cooling(du, u, p, t)
    temp = u[1]
    k, temp_m = p
    du[1] = -k * (temp - temp_m)
    return nothing
end

"""
    true_sol(du, u, p, t)

Ground-truth in-place dynamics used to generate the training data. It evaluates
`newtons_cooling` with the fixed parameters `k = log(2) / 8` and `temp_m = 100`.
The `p` argument is ignored.
"""
function true_sol(du, u, p, t)
    # A tuple destructures like the original vector in `newtons_cooling` but avoids
    # allocating a new parameter array on every right-hand-side evaluation.
    true_p = (log(2) / 8.0, 100.0)
    newtons_cooling(du, u, true_p, t)
    return nothing
end

# Neural network for the learned dynamics: 1 input -> 32 hidden units (tanh) -> 1 output.
model = Chain(Dense(1, 32, tanh), Dense(32, 1))
# Initialize parameters and states. The RNG is not explicitly seeded here, so the
# initial weights (and the printed result below) differ between runs.
ps, st = Lux.setup(Random.default_rng(), model)
# Flatten the nested parameter structure into a ComponentArray (a flat vector with
# named sub-views) for use as the optimization variable.
ps_ca = ComponentArray(ps)
# Bundle the model with its state so it can be called as `smodel(x, p)`. The
# parameters are supplied at call time, hence `nothing` here.
smodel = StatefulLuxLayer{true}(model, nothing, st)

"""
    dudt_(u, p, t)

Out-of-place neural-ODE right-hand side. The network output for state `u` under
parameters `p` is multiplied elementwise by `u`. `t` is unused.
"""
dudt_(u, p, t) = smodel(u, p) .* u

"""
    callback(state, l)

Training observer: displays the current loss `l` and always returns `false`. Under
the Optimization.jl callback convention, a `true` return would request early
stopping. `state` is not inspected.
"""
callback(state, l) = (display(l); false)

# Initial temperature, stored as Float32 (presumably to match the network's Float32
# parameters; confirm against Lux defaults).
u0 = Float32[200.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

# `datasize` evenly spaced observation times spanning `tspan`.
t = range(tspan[1], tspan[2], length = datasize)
# Generate the training data by solving the ground-truth ODE at the observation times.
# `Array` turns the solution into a (state dim) x (save points) matrix, here 1 x datasize.
true_prob = ODEProblem(true_sol, u0, tspan)
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))

# Out-of-place (`{false}`) ODE whose right-hand side is the neural network. `ps_ca` is
# only the initial parameter value; `predict_adjoint` passes `p` explicitly to `solve`.
prob = ODEProblem{false}(dudt_, u0, tspan, ps_ca)

"""
    predict_adjoint(fullp, time_batch)

Solve the neural ODE `prob` with parameters `fullp` using `Tsit5`. Returns the
solution sampled at `time_batch` as a dense array.
"""
function predict_adjoint(fullp, time_batch)
    sol = solve(prob, Tsit5(); p = fullp, saveat = time_batch)
    return Array(sol)
end

"""
    loss_adjoint(fullp, data)

Sum-of-squares loss for one minibatch. `data` is destructured as
`(observations, times)`. The model is predicted at `times` with parameters `fullp`
and compared elementwise with `observations`.
"""
function loss_adjoint(fullp, data)
    observed, times = data
    residual = observed .- predict_adjoint(fullp, times)
    return sum(abs2, residual)
end

k = 10
# Wrap the observations (a 1 x datasize matrix) and their time points in a tuple.
# DataLoader slices each element along its last dimension into batches of `k`.
train_loader = MLUtils.DataLoader((ode_data, t), batchsize = k)

# Number of passes over the data; this is the value passed to `solve` below.
numEpochs = 1000
# Loss of the untrained model over the full, unbatched dataset, for reference.
# `loss_adjoint` already returns a scalar, so no indexing is needed.
l1 = loss_adjoint(ps_ca, train_loader.data)

optfun = OptimizationFunction(
    loss_adjoint,
    Optimization.AutoZygote())
# The DataLoader is passed as the problem's data; the solver feeds one batch per iteration.
optprob = OptimizationProblem(optfun, ps_ca, train_loader)
# NOTE(review): `ncycle` is not used here; repetition is driven by `epochs`.
# This import is a candidate for removal.
using IterTools: ncycle
res1 = Optimization.solve(
    optprob, Optimisers.Adam(0.05); callback = callback, epochs = numEpochs)
retcode: Default
u: ComponentVector{Float32}(layer_1 = (weight = Float32[1.2178907; 1.3415166; … ; 0.711676; -1.0222925;;], bias = Float32[6.3584704, 4.9140496, 2.442036, 6.4238577, 4.398076, 5.540741, 3.5461493, -1.895345, -3.4434607, 1.445039  …  -4.0448446, 1.6177058, 9.765238, -2.1369996, 2.2632813, -0.67093873, -6.3219123, 0.59514165, 0.9987828, -2.6149142]), layer_2 = (weight = Float32[0.55648977 0.24586277 … -0.054084256 -0.04101506], bias = Float32[-0.8686419]))