Solving ODEs with Physics-Informed Neural Networks (PINNs)
It is highly recommended you first read the solving ordinary differential equations with DifferentialEquations.jl tutorial before reading this tutorial.
This tutorial is an introduction to using physics-informed neural networks (PINNs) for solving ordinary differential equations (ODEs). In contrast to the later parts of this documentation, which use the symbolic interface, here we will focus on the simplified NNODE, which uses the ODEProblem specification for the ODE.
Mathematically, the ODEProblem defines a problem:

\[u' = f(u,p,t)\]

for $t \in (t_0,t_f)$ with an initial condition $u(t_0) = u_0$. With physics-informed neural networks, we choose a neural network architecture NN to represent the solution u and seek parameters p such that NN' = f(NN,p,t) for all points in the domain. When this is satisfied sufficiently closely, NN is a solution to the differential equation.
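Concretely, the network parameters are found by minimizing a loss built from the residual of the ODE at a set of collocation points $t_1, \dots, t_N$ in the time span, roughly of the form:

\[\mathcal{L}(\theta) = \frac{1}{N} \sum_{i=1}^{N} \left\| \frac{d\,NN_\theta}{dt}(t_i) - f(NN_\theta(t_i), p, t_i) \right\|^2\]

with the initial condition handled either by an extra penalty term or by constructing the trial solution so that it satisfies $u(t_0) = u_0$ automatically; the exact form depends on the solver options.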
Solving an ODE with NNODE
Let's solve a simple ODE:
\[u' = \cos(2\pi t)\]
for $t \in (0,1)$ and $u_0 = 0$ with NNODE. First, we define an ODEProblem as we would for defining an ODE using the DifferentialEquations.jl interface. This looks like:
using NeuralPDE

linear(u, p, t) = cos(t * 2 * pi)    # right-hand side f(u, p, t)
tspan = (0.0, 1.0)                   # time span (t0, tf)
u0 = 0.0                             # initial condition u(t0)
prob = ODEProblem(linear, u0, tspan)
ODEProblem with uType Float64 and tType Float64. In-place: false
Non-trivial mass matrix: false
timespan: (0.0, 1.0)
u0: 0.0
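For reference, this ODE has a simple closed-form solution, which will be handy for sanity-checking the trained network later. The helper below is not part of the solver setup, just a convenience defined here:

analytic_sol(t) = sin(2 * pi * t) / (2 * pi)   # u(t) = sin(2πt)/(2π) satisfies u' = cos(2πt), u(0) = 0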
Now, to define the NNODE solver, we must choose a neural network architecture. To do this, we will use Lux.jl to define a multilayer perceptron (MLP) with one hidden layer of 5 nodes and a sigmoid activation function. This looks like:
using Lux, Random
rng = Random.default_rng()
Random.seed!(rng, 0)
chain = Chain(Dense(1, 5, σ), Dense(5, 1))
ps, st = Lux.setup(rng, chain) |> Lux.f64
((layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.4946011304855347; -1.0391809940338135;;], bias = [-0.458548903465271, -0.8280583620071411, -0.38509929180145264, 0.32322537899017334, -0.32623517513275146]), layer_2 = (weight = [0.5656673908233643 -0.605137288570404 … 0.3129439055919647 0.22128699719905853], bias = [-0.11007555574178696])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
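As a quick, optional sanity check, the untrained network can be evaluated directly. A Lux model is called with the input, the parameters, and the state, and returns the output together with the (possibly updated) state:

y, _ = chain([0.5], ps, st)   # evaluate the untrained MLP at t = 0.5; y is a 1-element vector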
Now we must choose an optimizer to define the NNODE solver. A common choice is Adam, with a tunable learning rate, which we will set to 0.1. In general, this rate should be decreased if the solver's loss is unsteady (i.e. sometimes rises "too much"), but should otherwise be as large as possible for efficiency. We use the Adam implementation from OptimizationOptimisers.jl. Thus, the definition of the NNODE solver is as follows:
using OptimizationOptimisers
opt = Adam(0.1)
alg = NNODE(chain, opt, init_params = ps)
NNODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Nothing, Bool, Nothing, Bool, Nothing, Base.Pairs{Symbol, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_2::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}}, Tuple{Symbol}, @NamedTuple{init_params::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_2::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}}}}}(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 5, σ), layer_2 = Dense(5 => 1)), nothing), Optimisers.Adam(eta=0.1, beta=(0.9, 0.999), epsilon=1.0e-8), nothing, false, true, nothing, false, nothing, Base.Pairs{Symbol, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_2::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}}, Tuple{Symbol}, @NamedTuple{init_params::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_2::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}}}}(:init_params => (layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.4946011304855347; -1.0391809940338135;;], bias = [-0.458548903465271, -0.8280583620071411, -0.38509929180145264, 0.32322537899017334, -0.32623517513275146]), layer_2 = (weight = [0.5656673908233643 -0.605137288570404 … 0.3129439055919647 0.22128699719905853], bias = [-0.11007555574178696]))))
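As mentioned above, if the loss turns out to be unsteady during training, decreasing the learning rate is a reasonable first adjustment. For example (the names below are illustrative, not part of the tutorial):

opt_small = Adam(0.01)                                  # smaller learning rate for a steadier loss
alg_small = NNODE(chain, opt_small, init_params = ps)   # same network, more conservative optimizer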
Once these pieces are together, we call solve just like with any other ODEProblem. Let's turn on verbose so we can see the loss over time during the training process:
sol = solve(prob, alg, verbose = true, maxiters = 2000, saveat = 0.01)
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0:0.01:1.0
u: 101-element Vector{Float64}:
0.0
0.01112240122398573
0.021904046987844532
0.032335514822810725
0.04240720278041413
0.05210932899262939
0.061431934124297975
0.07036488731528198
0.0788978962481144
0.08702052200797418
⋮
-0.0766297392144032
-0.0680232830374271
-0.05903887813722844
-0.04969270809723622
-0.04000108184238552
-0.02998037694228652
-0.019646985759390594
-0.009017264484199386
0.0018925149093164814
Now let's compare the predictions from the learned network with the ground truth, which we can obtain by numerically solving the ODE:
using OrdinaryDiffEq, Plots
ground_truth = solve(prob, Tsit5(), saveat = 0.01)
plot(ground_truth, label = "ground truth")
plot!(sol.t, sol.u, label = "pred")
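Beyond eyeballing the plot, a quick quantitative check is to compare the two solutions pointwise. A minimal sketch, assuming both were saved on the same saveat grid as above:

max_err = maximum(abs.(sol.u .- ground_truth.u))   # largest pointwise deviation from the reference solve
println("maximum absolute error: ", max_err)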
And that's it: the neural network solution was computed by training the neural network and returned in the standard DifferentialEquations.jl ODESolution format. For more information on handling the solution, consult here.