Using GPUs to train Physics-Informed Neural Networks (PINNs)

In this tutorial, we use a GPU to solve the 2-dimensional PDE:

\[∂_t u(x, y, t) = ∂^2_x u(x, y, t) + ∂^2_y u(x, y, t) \, ,\]

with the initial and boundary conditions:

\[\begin{align*} u(x, y, 0) &= e^{x+y} \cos(x + y) \, ,\\ u(0, y, t) &= e^{y} \cos(y + 4t) \, ,\\ u(2, y, t) &= e^{2+y} \cos(2 + y + 4t) \, ,\\ u(x, 0, t) &= e^{x} \cos(x + 4t) \, ,\\ u(x, 2, t) &= e^{x+2} \cos(x + 2 + 4t) \, , \end{align*}\]

on the space and time domain:

\[x \in [0, 2] \, ,\ y \in [0, 2] \, , \ t \in [0, 2] \, ,\]

with physics-informed neural networks. The only major difference from the CPU case is that the initial parameters of the neural network must live on the GPU. Once they do, all internal computations take place on the GPU as well. To move them, pipe the initial parameters through the device object returned by gpu_device(), like:

using Lux, LuxCUDA, ComponentArrays, Random
const gpud = gpu_device()
inner = 25
chain = Chain(Dense(3, inner, σ), Dense(inner, inner, σ), Dense(inner, inner, σ),
    Dense(inner, inner, σ), Dense(inner, 1))
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> ComponentArray |> gpud .|> Float64
ComponentArrays.ComponentVector{Float64, CUDA.CuArray{Float64, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:100, Axis(weight = ViewAxis(1:75, ShapedAxis((25, 3))), bias = ViewAxis(76:100, Shaped1DAxis((25,))))), layer_2 = ViewAxis(101:750, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_3 = ViewAxis(751:1400, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_4 = ViewAxis(1401:2050, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_5 = ViewAxis(2051:2076, Axis(weight = ViewAxis(1:25, ShapedAxis((1, 25))), bias = ViewAxis(26:26, Shaped1DAxis((1,))))))}}}(layer_1 = (weight = [-0.15747177600860596 0.24700665473937988 -0.17961335182189941; 0.18584883213043213 -0.07854330539703369 0.8713706731796265; … ; -0.647541880607605 0.566851019859314 0.17208945751190186; 0.10922908782958984 -0.21965157985687256 0.7917412519454956], bias = [0.09723173081874847, -0.2712171971797943, -0.5282362699508667, 0.5202301740646362, -0.40492361783981323, -0.026740433648228645, 0.20084476470947266, -0.5438757538795471, 0.1294894963502884, 0.14906953275203705  …  -0.2256971001625061, 0.32213032245635986, -0.14753320813179016, -0.01162875909358263, 0.4960283041000366, 0.520872175693512, -0.1192302256822586, -0.17022325098514557, 0.032673951238393784, -0.2585638463497162]), layer_2 = (weight = [-0.09488435089588165 0.2999737560749054 … -0.327985018491745 -0.31878045201301575; 0.03468434140086174 0.006969987414777279 … 0.1868240237236023 -0.2870440185070038; … ; 0.0019017315935343504 0.1260058581829071 … 0.11222280561923981 -0.2008034735918045; -0.2804553508758545 -0.018712803721427917 … 0.2340458333492279 -0.014254603534936905], bias = [0.039938785135746, -0.1459197700023651, 0.11231382191181183, -0.12957827746868134, -0.09463846683502197, -0.13950607180595398, -0.038367677479982376, 
-0.03348133713006973, 0.19283032417297363, -0.13622453808784485  …  -0.06440231949090958, -0.1918848752975464, -0.016764570027589798, -0.14533618092536926, -0.008297515101730824, -0.192518949508667, -0.1285950392484665, -0.15603086352348328, -0.12992112338542938, -0.14187714457511902]), layer_3 = (weight = [0.041509341448545456 -0.20178192853927612 … -0.29315316677093506 0.0566522479057312; -0.1742417961359024 -0.24156826734542847 … -0.21101762354373932 -0.10332436859607697; … ; 0.011004663072526455 0.19427551329135895 … -0.3320528566837311 0.19668737053871155; 0.34584784507751465 0.037688370794057846 … 0.1016891598701477 0.007698643021285534], bias = [-0.14521443843841553, -0.14431174099445343, -0.03470446914434433, -0.10420458018779755, 0.18734070658683777, -0.02726573869585991, -0.03275416046380997, 0.18850412964820862, -0.054669834673404694, 0.03051884099841118  …  -0.16647610068321228, -0.18619652092456818, 0.14073264598846436, -0.010070609860122204, -0.1395239531993866, -0.0787605494260788, 0.16648876667022705, -0.13670405745506287, -0.0200441125780344, -0.0641426295042038]), layer_4 = (weight = [-0.16468805074691772 -0.24445761740207672 … -0.2217857390642166 -0.30997058749198914; -0.20905432105064392 0.26521390676498413 … -0.08805125951766968 0.19322925806045532; … ; -0.06150420010089874 -0.15921339392662048 … 0.3396996259689331 0.13675762712955475; -0.23976972699165344 0.04290776699781418 … -0.02535441145300865 0.21783222258090973], bias = [0.1305132359266281, 0.0022767067421227694, 0.07408668845891953, 0.08836676925420761, -0.004991746041923761, -0.040787745267152786, 0.04673285409808159, 0.16528531908988953, -0.06037449836730957, 0.10886359214782715  …  -0.1781819760799408, -0.194841668009758, -0.14001862704753876, -0.14102952182292938, 0.1943173110485077, -0.036057304590940475, 0.18695345520973206, -0.13262836635112762, -0.11959166824817657, 0.015515303239226341]), layer_5 = (weight = [0.27310070395469666 -0.18258146941661835 … 0.33848753571510315 
-0.3023024797439575], bias = [0.10128605365753174]))
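A quick way to confirm where the parameters ended up is to look at the array that backs the ComponentArray. The following is a minimal sketch, not part of the original tutorial. It assumes that gpu_device() falls back to a CPU device when no functional GPU is found, that LuxCUDA makes the CUDA module available, and it uses a smaller stand-in network (`dev` and `backing` are names chosen for illustration).

```julia
using Lux, LuxCUDA, ComponentArrays, Random

# Assumption: `gpu_device()` picks a functional GPU backend when one exists
# and otherwise returns a CPU device, so this sketch also runs without a GPU.
dev = gpu_device()

# Small stand-in network (3 inputs, 8 hidden units, 1 output).
chain = Chain(Dense(3, 8, σ), Dense(8, 1))
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> ComponentArray |> dev .|> Float64

# `getdata` returns the flat array that stores all parameters.
backing = ComponentArrays.getdata(ps)
println("backing array type: ", typeof(backing))
println("element type:       ", eltype(ps))

# Only meaningful when CUDA is usable on this machine.
if CUDA.functional()
    @assert backing isa CUDA.CuArray
end
```

If the printed backing type is a plain Array, training will silently run on the CPU, which is worth checking before comparing runtimes.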

In total, this looks like:

using NeuralPDE, Lux, LuxCUDA, Random, ComponentArrays
using Optimization
using OptimizationOptimisers
import ModelingToolkit: Interval
using Plots
using Printf

@parameters t x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2
Dt = Differential(t)
t_min = 0.0
t_max = 2.0
x_min = 0.0
x_max = 2.0
y_min = 0.0
y_max = 2.0

# 2D PDE
eq = Dt(u(t, x, y)) ~ Dxx(u(t, x, y)) + Dyy(u(t, x, y))

analytic_sol_func(t, x, y) = exp(x + y) * cos(x + y + 4t)
# Initial and boundary conditions
bcs = [u(t_min, x, y) ~ analytic_sol_func(t_min, x, y),
    u(t, x_min, y) ~ analytic_sol_func(t, x_min, y),
    u(t, x_max, y) ~ analytic_sol_func(t, x_max, y),
    u(t, x, y_min) ~ analytic_sol_func(t, x, y_min),
    u(t, x, y_max) ~ analytic_sol_func(t, x, y_max)]

# Space and time domains
domains = [t ∈ Interval(t_min, t_max),
    x ∈ Interval(x_min, x_max),
    y ∈ Interval(y_min, y_max)]

# Neural network
inner = 25
chain = Chain(Dense(3, inner, σ), Dense(inner, inner, σ), Dense(inner, inner, σ),
    Dense(inner, inner, σ), Dense(inner, 1))

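strategy = QuasiRandomTraining(100)
# Initial parameters, moved to the GPU with the `gpud` device defined above
# (const gpud = gpu_device())
ps = Lux.setup(Random.default_rng(), chain)[1]
ps = ps |> ComponentArray |> gpud .|> Float64
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)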

@named pde_system = PDESystem(eq, bcs, domains, [t, x, y], [u(t, x, y)])
prob = discretize(pde_system, discretization)
symprob = symbolic_discretize(pde_system, discretization)

callback = function (p, l)
    println("Current loss is: $l")
    return false
end

res = Optimization.solve(prob, OptimizationOptimisers.Adam(1e-2); maxiters = 2500)
retcode: Default
u: ComponentArrays.ComponentVector{Float64, CUDA.CuArray{Float64, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:100, Axis(weight = ViewAxis(1:75, ShapedAxis((25, 3))), bias = ViewAxis(76:100, Shaped1DAxis((25,))))), layer_2 = ViewAxis(101:750, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_3 = ViewAxis(751:1400, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_4 = ViewAxis(1401:2050, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_5 = ViewAxis(2051:2076, Axis(weight = ViewAxis(1:25, ShapedAxis((1, 25))), bias = ViewAxis(26:26, Shaped1DAxis((1,))))))}}}(layer_1 = (weight = [1.628618271691005 -1.0236056561136626 -0.825503347365049; -1.325001745981892 -0.5170324820321547 -0.6423900880048403; … ; -1.8769429711802414 -0.7113847233086531 -0.4228157168964858; -1.6395374577664668 -0.5421340980002323 -0.5658403752708485], bias = [1.3477526928250232, 0.7242553426119427, -1.1454681712629242, 1.128742331400384, -3.040428648792343, 0.4737792977936717, 0.8361364196190243, -0.37351269059050984, 0.16754515993411143, -0.9240114055490724  …  -1.6906797493981125, -0.19021908404035476, 0.9239904497598128, 1.1949359026263178, 0.3420964164179462, 1.7622854293104504, -0.8919738403024335, -0.033073549338966744, 1.5562989752126974, 0.9824776024104329]), layer_2 = (weight = [0.0235603948796308 1.053146086875285 … 1.6505601848998546 0.8712046563986013; -2.3431567304047496 -1.0338238867323195 … -0.9647044137770956 -1.126211997151654; … ; 1.4460147473920268 0.8681271832565168 … 0.7786154222178651 0.46307027138002876; 0.18700315839853915 -0.5693662337836158 … -0.10982389040020275 -0.25710737270321105], bias = [-0.10730521318779848, -0.13371787625039463, -0.007659838595290272, 0.08944064132719645, -0.10797614466546822, 0.05306496695918892, -0.01800117053082672, 0.09568826007672587, 
-0.05417053437067145, 0.1888393628004252  …  0.355080036016173, -0.1375392362955204, 0.05033692170800919, -0.4127297994421522, 0.003271363872455259, -0.2088541557326349, -0.12216835381700515, 0.022546245715159185, 0.11057678477677035, -0.10217736550910965]), layer_3 = (weight = [-0.32850569938570007 2.514569559201959 … -1.653936536538192 -0.9507553914192106; -1.2725186362522007 0.7493116053730294 … -1.8909497534893052 -0.7432487783366728; … ; -0.6363895763287531 1.8229413170099626 … -1.7878181710206185 -1.051816572272188; -1.5158344118357985 1.04172182757427 … -2.062047500616772 -0.44255435228020545], bias = [-0.4832770675780699, -0.3012539965567003, 0.02129377309436023, -0.32416276016667495, -0.1070255007228937, 0.3417483162831266, -0.5561280620409067, -0.5049461868751864, -0.5922658828850311, -0.0668355161727004  …  -0.3810601570421281, 0.45099630490798476, -0.27214771186614056, -0.0424703376007265, -0.12568677357780897, -0.6148399486127368, -0.38706085812612845, -0.2390157253996015, -0.18029872056387578, -0.25352755142596595]), layer_4 = (weight = [1.5915082247275434 0.6608600667747856 … 1.554273410280527 0.780589655123078; 0.6799198142286178 1.1417531915940686 … 1.1446125787648325 1.3291284967231756; … ; 1.6673674660776825 0.7594604868811006 … 1.0659755285587664 0.997413723346971; -1.9175517701791833 -0.29778996591131773 … -0.6570658284577573 -0.08241356552250806], bias = [-1.1465952919689628, -0.7943949007289886, -1.1646669676829284, 0.679232118734539, -0.619830313490605, 0.32937921653436447, -0.30476713333401234, -0.6376560019266607, 0.6732015290973096, -0.7579369224262553  …  -0.6513442481786744, -0.02849668610992922, -0.3856603548191902, -0.03571518678920574, 0.11472807606087952, -0.9816085504265238, -0.41605197034811847, 0.967582654417059, -1.058277308107995, 0.9204353136349308]), layer_5 = (weight = [-3.1783534931630744 -2.353949402937573 … -3.123489229258173 5.241638019643616], bias = [2.1032767258013543]))

We then use the remake function to rebuild the problem with the optimized parameters as the new starting point, and continue training with a lower learning rate (1e-3 instead of 1e-2):

prob = remake(prob, u0 = res.u)
res = Optimization.solve(
    prob, OptimizationOptimisers.Adam(1e-3); callback = callback, maxiters = 2500)
retcode: Default
u: ComponentArrays.ComponentVector{Float64, CUDA.CuArray{Float64, 1, CUDA.DeviceMemory}, Tuple{ComponentArrays.Axis{(layer_1 = ViewAxis(1:100, Axis(weight = ViewAxis(1:75, ShapedAxis((25, 3))), bias = ViewAxis(76:100, Shaped1DAxis((25,))))), layer_2 = ViewAxis(101:750, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_3 = ViewAxis(751:1400, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_4 = ViewAxis(1401:2050, Axis(weight = ViewAxis(1:625, ShapedAxis((25, 25))), bias = ViewAxis(626:650, Shaped1DAxis((25,))))), layer_5 = ViewAxis(2051:2076, Axis(weight = ViewAxis(1:25, ShapedAxis((1, 25))), bias = ViewAxis(26:26, Shaped1DAxis((1,))))))}}}(layer_1 = (weight = [1.6778292509875963 -0.9947279459406194 -0.799483488182213; -1.3525004842468626 -0.5138046812916414 -0.626103420666353; … ; -1.8806433681837833 -0.6857846749466058 -0.43187039902967794; -1.638316863712306 -0.529376755333831 -0.5524362635343743], bias = [1.3128871475772248, 0.7587868918036449, -1.0935396728702302, 1.1436656679943504, -3.0754235566073054, 0.4840077117953924, 0.8495786340754967, -0.4530003497929537, 0.1677947809608984, -0.9099656421606362  …  -1.6659494084725748, -0.19216962863032874, 0.9486841870832098, 1.200219346706729, 0.3696877734089161, 1.8046372389616265, -0.8725587901557489, -0.050302483395952474, 1.6030771503737513, 1.0287856138032854]), layer_2 = (weight = [-0.09972106358896525 1.0532011419127516 … 1.6485090169416714 0.8777714562179583; -2.368561736817559 -1.131072748446975 … -1.0694081246418454 -1.22520767456896; … ; 1.460178539971638 0.8204123646085546 … 0.7151483465952367 0.40311750808950786; 0.18504277786061268 -0.6149504066392703 … -0.16512662098616762 -0.30944409023292824], bias = [-0.1471984695875208, -0.11977007245195025, 0.0059144364275187985, 0.09646204690014169, -0.11922665796103625, 0.07059221461690712, 0.01205536852411003, 0.10338423006879849, 
-0.05679336923562684, 0.2076546213691613  …  0.3817293209583069, -0.13089263216189614, 0.07110360837172487, -0.3761053272775026, -9.06039314236305e-5, -0.24399651954494025, -0.10829846282764409, 0.0242208347156881, 0.11135651899548707, -0.10891662816508667]), layer_3 = (weight = [-0.32299613729434734 2.582749168511958 … -1.7038058469389101 -0.9252952260841824; -1.3565638983900943 0.6134211306653473 … -1.8591321986572302 -0.7433152184662394; … ; -0.607476340211386 1.879783668938865 … -1.7312392640835812 -1.0755387766201854; -1.619509850082069 0.8807702758004571 … -2.059729836430782 -0.44649329297151963], bias = [-0.4777612244658963, -0.30347872366379536, 0.03292677401030682, -0.31648070128151656, -0.09377431367862517, 0.359240872663168, -0.5534204714459046, -0.502650092647931, -0.597929905649346, -0.0595130351537359  …  -0.38589448029040946, 0.3770238438451814, -0.27361421423725324, 0.008241898699951764, -0.13411998859246416, -0.6259624347612607, -0.36028614604393244, -0.27782500718522335, -0.17368394982914112, -0.2653754366608459]), layer_4 = (weight = [1.6461490770442793 0.6525836210110545 … 1.588145842231959 0.7696326902754942; 0.6396079138768718 1.206200161122021 … 1.164862288977325 1.388367000345782; … ; 1.7103530314672213 0.7505545407830521 … 1.0970324384444692 0.986609955479368; -1.889023788992867 -0.20173191079457764 … -0.5875890203839074 0.013330211383438208], bias = [-1.1495069647295508, -0.8700848473534365, -1.1687023422341654, 0.7158559874454046, -0.6959751611410016, 0.29369858460470416, -0.3574923658253903, -0.6424261889696732, 0.71767766447915, -0.7936481700932517  …  -0.705290445685833, 0.04580503386380758, -0.3793791458428365, 0.030835208305864736, 0.08662669352643085, -1.0140922038136915, -0.4552659672858387, 0.9714474255781383, -1.0724561174727676, 0.9352055790828426]), layer_5 = (weight = [-3.3698802775854224 -2.5535124808087906 … -3.298961768290909 5.60478774118867], bias = [2.322761958187865]))
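The two-stage pattern above (solve, remake with the result, solve again with a smaller step) can be wrapped in a loop over learning rates. The sketch below is an illustration under stated assumptions, not the tutorial's method: it uses a small stand-in quadratic objective instead of the PINN loss, takes the AD backend from ADTypes with ForwardDiff loaded, and `staged_solve` is a hypothetical helper name.

```julia
using Optimization, OptimizationOptimisers, ForwardDiff
using ADTypes: AutoForwardDiff

# Stand-in objective (not the PINN loss): squared distance to a target `p`.
loss(u, p) = sum(abs2, u .- p)

optf = OptimizationFunction(loss, AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 2.0])

# Hypothetical helper: run Adam once per learning rate, warm-starting each
# stage from the parameters found by the previous one.
function staged_solve(prob, rates; maxiters = 500)
    res = nothing
    for lr in rates
        res = Optimization.solve(prob, OptimizationOptimisers.Adam(lr);
            maxiters = maxiters)
        prob = remake(prob, u0 = res.u)
    end
    return res
end

res = staged_solve(prob, (1e-2, 1e-3))
println("final loss: ", res.objective)
```

Note that each call to solve starts Adam with fresh optimizer state, so only the parameters carry over between stages, exactly as in the remake step above.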

Finally, we inspect the solution:

phi = discretization.phi
ts, xs, ys = [infimum(d.domain):0.1:supremum(d.domain) for d in domains]
u_real = [analytic_sol_func(t, x, y) for t in ts for x in xs for y in ys]
u_predict = [first(Array(phi([t, x, y], res.u))) for t in ts for x in xs for y in ys]
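The flat vectors u_real and u_predict computed above lend themselves to a single-number accuracy check before any plotting. The helpers below are a minimal sketch with hypothetical names (`rel_l2_error`, `max_abs_error`); they are not part of NeuralPDE, and the data are stand-ins so the block runs on its own.

```julia
using LinearAlgebra: norm

# Hypothetical helpers: summary errors between two equally sized collections.
rel_l2_error(pred, ref) = norm(pred .- ref) / norm(ref)
max_abs_error(pred, ref) = maximum(abs.(pred .- ref))

# Stand-in data; in the tutorial these would be `u_predict` and `u_real`.
ref = [1.0, 2.0, 2.0]
pred = [1.0, 2.0, 2.3]

println("relative L2 error: ", rel_l2_error(pred, ref))
println("max abs error:     ", max_abs_error(pred, ref))
```

With the tutorial's arrays, the calls would be rel_l2_error(u_predict, u_real) and max_abs_error(u_predict, u_real).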

function plot_(res)
    # Animate
    anim = @animate for (i, t) in enumerate(0:0.05:t_max)
        @info "Animating frame $i..."
        u_real = reshape([analytic_sol_func(t, x, y) for x in xs for y in ys],
            (length(xs), length(ys)))
        u_predict = reshape([Array(phi([t, x, y], res.u))[1] for x in xs for y in ys],
            length(xs), length(ys))
        u_error = abs.(u_predict .- u_real)
        title = @sprintf("predict, t = %.3f", t)
        p1 = plot(xs, ys, u_predict, st = :surface, label = "", title = title)
        title = @sprintf("real")
        p2 = plot(xs, ys, u_real, st = :surface, label = "", title = title)
        title = @sprintf("error")
        p3 = plot(xs, ys, u_error, st = :contourf, label = "", title = title)
        plot(p1, p2, p3)
    end
    gif(anim, "3pde.gif", fps = 10)
end

plot_(res)
[Animation 3pde.gif: predicted solution, analytic solution, and absolute error over time]
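The reference surface in the animation comes from analytic_sol_func, that is, u(x, y, t) = e^{x+y} cos(x + y + 4t). As a worked check that is not part of the original tutorial, this function satisfies the PDE exactly. Writing s = x + y:

\[\begin{align*} ∂_x u &= e^{s} \left[ \cos(s + 4t) - \sin(s + 4t) \right] \, ,\\ ∂^2_x u &= -2 e^{s} \sin(s + 4t) = ∂^2_y u \, ,\\ ∂_t u &= -4 e^{s} \sin(s + 4t) = ∂^2_x u + ∂^2_y u \, . \end{align*}\]

Evaluating the same expression at t = 0, x = 0, x = 2, y = 0, and y = 2 reproduces the initial and boundary conditions listed at the top of the page.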

Performance benchmarks

The following benchmarks for this 2D PDE compare GPU and CPU runtimes for different numbers of input points and different numbers of neurons per hidden layer. Each measurement is the time taken for 100 iterations.

julia> CUDA.device()

[Figure: GPU versus CPU runtime benchmarks]
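A measurement of the kind described above (time for a fixed number of iterations) can be sketched as follows. This is not the script that produced the benchmarks; `time_iterations` is a hypothetical helper, and the workload is a stand-in matrix product so that the block runs on any machine.

```julia
# Hypothetical helper (not from the tutorial): wall-clock time of `n` calls to
# `step!`, after one untimed warm-up call that absorbs JIT compilation.
function time_iterations(step!, n::Integer)
    step!()
    return @elapsed begin
        for _ in 1:n
            step!()
        end
    end
end

# Stand-in workload; a real benchmark would perform one optimizer iteration
# here instead.
work = let A = rand(200, 200)
    () -> sum(A * A)
end

elapsed = time_iterations(work, 100)
println("100 iterations took ", round(elapsed; digits = 4), " s")
```

When timing GPU code, keep in mind that kernels are launched asynchronously, so the timed call should synchronize before returning (for example with CUDA.@sync); otherwise the measurement may only capture the launch overhead.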