Getting Started with GPU-Accelerated Differential Equations in Julia
The two ways to accelerate ODE solvers with GPUs
There are two very different ways that one can accelerate an ODE solution with GPUs. In the first case, u is very big and f is very expensive but very structured, and you use GPUs to accelerate the computation of f itself. In the second case, u is very small, but you want to solve the ODE f over many different initial conditions (u0) or parameters p. In that case, you can use GPUs to parallelize over the different parameters and initial conditions. In other words:
Type of Problem | SciML Solution |
---|---|
Accelerate a big ODE | Use CUDA.jl's CuArray as u0 |
Solve the same ODE with many u0 and p | Use DiffEqGPU.jl's EnsembleGPUArray and EnsembleGPUKernel |
Supported GPUs
SciML's GPU support extends to a wide array of hardware, including:
GPU Manufacturer | GPU Kernel Language | Julia Support Package | Backend Type |
---|---|---|---|
NVIDIA | CUDA | CUDA.jl | CUDA.CUDABackend() |
AMD | ROCm | AMDGPU.jl | AMDGPU.ROCBackend() |
Intel | oneAPI | oneAPI.jl | oneAPI.oneAPIBackend() |
Apple (M-Series) | Metal | Metal.jl | Metal.MetalBackend() |
For this tutorial we demonstrate the CUDA backend for NVIDIA GPUs, though any of the other GPUs can be used by swapping in the corresponding Julia package, its array type, and its backend from the table above.
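Because the backend is carried by the array type, one way to keep a script portable is to route every array through a single conversion hook. The following is a minimal sketch assuming only Base Julia and LinearAlgebra; make_linear_data and to_device are hypothetical names, not part of SciML. With the default to_device = identity everything stays on the CPU; passing cu from CUDA.jl (or, presumably, the array constructor of another package, such as ROCArray, MtlArray, or oneArray) would move u0 and A to a GPU without changing the rest of the code. Float32 is used throughout because GPUs are generally faster in single precision, and Metal does not support Float64 at all.

```julia
using LinearAlgebra

# Hypothetical helper: builds the data for the linear ODE du/dt = A*u.
# `to_device` is an assumed hook; `identity` keeps everything on the CPU,
# while e.g. `cu` from CUDA.jl would place u0 and A on an NVIDIA GPU.
function make_linear_data(n; to_device = identity)
    u0 = to_device(rand(Float32, n))        # initial condition
    A = to_device(randn(Float32, n, n))     # dense system matrix
    rhs!(du, u, p, t) = mul!(du, A, u)      # in-place right-hand side, no allocations
    return rhs!, u0, A
end

rhs!, u0, A = make_linear_data(100)
du = similar(u0)                            # same array type and eltype as u0
rhs!(du, u0, nothing, 0.0f0)
```

The right-hand side never mentions a concrete array type, so only the conversion hook changes when the hardware does.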
Simple Example of Within-Method GPU Parallelism
The following is a quick and dirty example of doing within-method GPU parallelism. Let's say we had a simple but large ODE with many linear algebra or map/broadcast operations:
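using OrdinaryDiffEq, LinearAlgebra
u0 = rand(1000)                   # 1000-dimensional initial condition
A = randn(1000, 1000)             # dense matrix: the ODE is du/dt = A*u
f(du, u, p, t) = mul!(du, A, u)   # in-place right-hand side
prob = ODEProblem(f, u0, (0.0, 1.0))
sol = solve(prob, Tsit5())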
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 50-element Vector{Float64}:
0.0
0.0015478761855679305
0.00648580863398825
0.013842416823521709
0.022391643841718323
0.03325242939710622
0.04561502822320631
0.05901456341425282
0.07299580669091357
0.087920728422522
⋮
0.8051958359073149
0.8303464269733638
0.855220989124135
0.8807960944818966
0.906252317657025
0.9322591011728141
0.9588072465756187
0.9814855255785365
1.0
u: 50-element Vector{Vector{Float64}}:
[0.6077993951026277, 0.7492427325037554, 0.24030948236125937, 0.7876203548496596, 0.9790896488924092, 0.39466961581632154, 0.8624005981887551, 0.44259694848939535, 0.8924620923795857, 0.02291634363312045 … 0.9575521471392613, 0.8803300650478935, 0.5201892514317958, 0.9654442939425254, 0.36768574375921526, 0.6396059294669751, 0.901286174640774, 0.685994963053012, 0.27133784540119166, 0.3416276233225435]
[0.6388217639528583, 0.7550178758102967, 0.23851886858256294, 0.798157641408362, 0.9991230598196218, 0.3979320485800568, 0.8199383091064789, 0.41758245795126836, 0.8625227443916994, 0.025110186525423902 … 0.9928820975413687, 0.9188273924595998, 0.5424605821203632, 0.9989534800028482, 0.34863444594913573, 0.6129502249468205, 0.9449488603002396, 0.6859185020749737, 0.2372386250598567, 0.35672114399690125]
[0.7604156908506836, 0.7674550530644948, 0.23847720394552546, 0.8294452834399094, 1.0804114475300428, 0.4060703085606619, 0.6888640957065337, 0.3343223619420136, 0.7636075351231503, 0.03724209063173313 … 1.1052247509058286, 1.036939414566552, 0.590016004446043, 1.0983914945208484, 0.2796647248494942, 0.5103533632849603, 1.0769710826307772, 0.6990725769164804, 0.11904315092961476, 0.41837714984868773]
[1.0053064029594083, 0.7754094885668386, 0.25381901276200136, 0.8727080609515534, 1.2506728414728154, 0.41008832071754375, 0.5034713184150569, 0.19682926309607676, 0.6106487534432536, 0.0715403130063657 … 1.2729780841888372, 1.1972501755009486, 0.592849788111663, 1.2227273732202086, 0.15679756465134848, 0.31151327457108635, 1.2523068815326168, 0.7558685670894918, -0.08699326886042917, 0.5490906533963484]
[1.3834519386374002, 0.7837838905457066, 0.2921861003267462, 0.9236774476427589, 1.523181116666089, 0.3997205500175946, 0.2955086213223752, 0.009936882166138456, 0.4342715986146158, 0.1385042704272604 … 1.472789062410583, 1.35732431035077, 0.4905464950458207, 1.3244911127993384, -0.007538677655676027, 0.02053802821052491, 1.4211369646984968, 0.8788149364361183, -0.37901171481451906, 0.7632432858301461]
[1.9996969576951213, 0.8237528265358053, 0.36398372741091023, 0.9962401787667565, 1.989912694960366, 0.36144105137963456, 0.022407587889207218, -0.27997753933805464, 0.2327477796056976, 0.26713486066052694 … 1.7455467459548604, 1.5225296611969998, 0.19108356494503692, 1.3716391061004376, -0.22617082700556285, -0.4215580154598785, 1.5721740596708869, 1.1317783194275162, -0.8466567366911893, 1.1446223361289465]
[2.863933540547281, 0.9572029196534585, 0.4555527992887772, 1.0899709291588249, 2.7048696099110843, 0.29049292086911155, -0.34906580016617744, -0.693400928174069, 0.06481141247194418, 0.46686840196762325 … 2.1045157539194497, 1.6791929519403357, -0.3901551134623551, 1.2835669148227058, -0.4352059108610356, -0.989436360202414, 1.6318025361215416, 1.576996213459419, -1.5362375063508247, 1.7571106395927933]
[3.9567238329930747, 1.2660033518194682, 0.5242731272720043, 1.1857324242261225, 3.7570311466950703, 0.20741445738859104, -0.9259994793290318, -1.2517796356900284, -0.015227556801584132, 0.7251706951816024 … 2.585605057722956, 1.8683327115008561, -1.3278904296850627, 0.9695710114192313, -0.5103261403065463, -1.64077600938374, 1.4987013015224453, 2.300299770684521, -2.5125883888250757, 2.6902049052817136]
[5.212509354712602, 1.8219352088265837, 0.4927626775971167, 1.2217430428676943, 5.287165498600609, 0.1684122496688786, -1.9099209691049166, -1.9601273540869935, 0.0068975008072079645, 0.9971546368338595 … 3.2216661701781475, 2.1984737250394413, -2.688231906006248, 0.3335252094962419, -0.2468236831160612, -2.330373319496057, 1.009343445746315, 3.401298931576341, -3.8451962977980796, 4.05284494696332]
[6.597347622680693, 2.707730246048976, 0.23350582843453324, 1.0593978321294257, 7.657144179123195, 0.2683507530696783, -3.743003371209908, -2.850509989437023, 0.049158725518221195, 1.2230149083895065 … 4.053428323483684, 2.895916786228373, -4.661277250011878, -0.7774081803390297, 0.7082635349031607, -3.0958792890184026, -0.17535764379219534, 5.073413283353828, -5.723384836635318, 6.104970574138457]
⋮
[8.764780066224067e8, 1.896679198856089e10, -2.8791238716648205e10, -1.7782386084904808e10, -1.8629943618088207e10, 1.6352128582481981e10, -2.013634606630096e10, 9.378994001048164e9, 3.250327838955565e10, -2.4188887993906605e10 … -2.0476771292206867e10, -8.801964896424888e8, -6.591287844944542e9, -1.0345209058982435e10, 1.5391788719636024e10, 3.4634498633198557e9, -1.5679486188023586e10, -1.213285938302461e9, -5.438045934829378e9, -1.1518680809893606e10]
[-1.6008137867406118e9, 4.436306900109785e10, -7.133880987032213e10, -3.267709320324597e10, -4.100510774819068e10, 3.507435874752564e10, -4.109720486029508e10, 3.1143379966602604e10, 7.211345057356902e10, -4.872831003357185e10 … -4.42749636134423e10, -6.410373034463417e9, -1.1352339308446444e10, -2.023518090277868e10, 2.6040479824811573e10, 5.334322576628781e9, -2.8138574175721382e10, -7.436936219045354e9, -1.536871597666749e10, -3.2057632585082157e10]
[-9.835055910048376e9, 1.0096923135993665e11, -1.6803850623723938e11, -5.5607741471781364e10, -8.78295963571415e10, 7.405439571856248e10, -8.166362597352007e10, 9.049337091938423e10, 1.570788747232032e11, -9.518209685019202e10 … -9.493617992720575e10, -2.421290453960685e10, -1.8516986299478905e10, -3.8494069151425415e10, 3.99757041361387e10, 6.87526800357321e9, -4.7974744954544e10, -2.3632620006238934e10, -3.703165440635093e10, -8.192947141839993e10]
[-3.3092087238827965e10, 2.3126596324069577e11, -3.9204980268073303e11, -8.584692704455705e10, -1.8883942613339667e11, 1.5867585142554782e11, -1.622942803308893e11, 2.509502540181117e11, 3.453400326746502e11, -1.842867903222202e11 … -2.0802521888540536e11, -7.667313323238606e10, -2.9094192895382225e10, -7.28300164926125e10, 5.1896106956949196e10, 4.975465607817764e9, -7.821226403444156e10, -6.299240672838114e10, -8.268893822661942e10, -2.0338509542303232e11]
[-9.044451266344685e10, 5.194873853485014e11, -8.856619721051328e11, -1.0532240735501501e11, -3.9829126738653033e11, 3.3694374364964685e11, -3.1529565367391327e11, 6.557717021798854e11, 7.451969398253005e11, -3.4395225415398e11 … -4.5426926671591754e11, -2.1497354388960907e11, -4.345809137183459e10, -1.3344021671751703e11, 3.82089145460287e10, -1.0247641642202696e10, -1.1689857185952869e11, -1.485384955832559e11, -1.692423883829415e11, -4.807822887195532e11]
[-2.267247062902883e11, 1.1697967852823088e12, -1.9858339572089553e12, -4.6619944134888565e10, -8.424662306146082e11, 7.233687913797831e11, -6.084511182028693e11, 1.6752083134825513e12, 1.606733596390494e12, -6.224138541052672e11 … -1.0090562928663315e12, -5.698888225018145e11, -6.314178299136891e10, -2.3762883304224597e11, -7.506665077293513e10, -6.784138030055833e10, -1.527465096926599e11, -3.263220885460888e11, -3.2348788757574664e11, -1.1128434354892239e12]
[-5.396515870699504e11, 2.639962114912662e12, -4.4247767716484375e12, 3.1226918592045764e11, -1.7919549974816555e12, 1.5683319551611003e12, -1.1619329367779023e12, 4.2045624098930684e12, 3.4505821619469956e12, -1.068930513344212e12 … -2.278776586292099e12, -1.4535976419461814e12, -9.191225782225375e10, -3.997871976301755e11, -5.0405311840302344e11, -2.4812008650601465e11, -1.4285096377187524e11, -6.705233099948209e11, -5.661276078077296e11, -2.5238877370180205e12]
[-1.0888619612530238e12, 5.2329957377794e12, -8.627492497950489e12, 1.298170123182634e12, -3.399601510927614e12, 3.019585725089798e12, -1.9742412504288433e12, 8.99387505490768e12, 6.516214041688621e12, -1.5666202424681304e12 … -4.568349218894518e12, -3.1201715963489346e12, -1.294639298214924e11, -5.640713802356516e11, -1.517735827705361e12, -6.418491697608444e11, -1.79862808914855e10, -1.1562044987971702e12, -8.153635928282788e11, -4.933624652748248e12]
[-1.895812881330054e12, 9.083221209112744e12, -1.4725653427386322e13, 3.190050714478752e12, -5.732181029828195e12, 5.128172678169889e12, -2.9866649838152583e12, 1.646762355726429e13, 1.08148198792512e13, -1.950916864883442e12 … -8.056223336936267e12, -5.708294163123032e12, -1.7145024869188217e11, -6.4776831302855e11, -3.3198320033962886e12, -1.314032554112901e12, 2.6037936347618207e11, -1.7023153937138574e12, -9.653660834081237e11, -8.357515822381224e12]
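Why does moving the arrays to the GPU suffice? A time-stepping method only touches the state through the user's f and through array-level operations such as broadcasts and in-place linear algebra. The sketch below makes that visible with a hand-written, fixed-step classical RK4, a deliberately simplified stand-in for Tsit5 (which is adaptive and uses more stages); rk4_fixed! and lin! are hypothetical names. Every line of the loop is either a call to f! or a fused broadcast, and the stage buffers are created with similar(u), so if u and A were GPU arrays each of those operations would run on the GPU.

```julia
using LinearAlgebra

# Simplified sketch: fixed-step classical RK4, NOT the adaptive Tsit5 used above.
# Stage buffers come from `similar(u)`, so they live wherever `u` lives.
function rk4_fixed!(f!, u, p, tspan, dt)
    k1, k2, k3, k4, tmp = ntuple(_ -> similar(u), 5)
    t = tspan[1]
    nsteps = round(Int, (tspan[2] - tspan[1]) / dt)
    for _ in 1:nsteps
        f!(k1, u, p, t)
        @. tmp = u + dt / 2 * k1            # fused broadcast: a single kernel on a GPU
        f!(k2, tmp, p, t + dt / 2)
        @. tmp = u + dt / 2 * k2
        f!(k3, tmp, p, t + dt / 2)
        @. tmp = u + dt * k3
        f!(k4, tmp, p, t + dt)
        @. u += dt / 6 * (k1 + 2k2 + 2k3 + k4)
        t += dt
    end
    return u
end

# Usage on the CPU with a small diagonal system whose exact solution is known.
A = [-1.0 0.0; 0.0 -2.0]
lin!(du, u, p, t) = mul!(du, A, u)
u = rk4_fixed!(lin!, [1.0, 1.0], nothing, (0.0, 1.0), 0.01)
```

For this diagonal system the exact solution at t = 1 is (e^-1, e^-2), which the fixed-step result matches closely.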
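Translating this to a GPU-based solve of the ODE simply requires moving the arrays for the initial condition, the parameters, and any arrays the right-hand side closes over (here A) to the GPU. The solver allocates its internal caches from u0, so they end up on the GPU as well. This looks like: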
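using OrdinaryDiffEq, CUDA, LinearAlgebra
u0 = cu(rand(1000))               # cu moves the data to the GPU (converting to Float32)
A = cu(randn(1000, 1000))         # the matrix used inside f must live on the GPU too
f(du, u, p, t) = mul!(du, A, u)   # mul! on CuArrays dispatches to CUBLAS
prob = ODEProblem(f, u0, (0.0f0, 1.0f0)) # Float32 is better on GPUs!
sol = solve(prob, Tsit5())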
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 50-element Vector{Float32}:
0.0
0.0012238872
0.0035123606
0.007574204
0.0141317025
0.022441624
0.032786697
0.044370756
0.057975534
0.072406396
⋮
0.8123368
0.83820784
0.8635944
0.88990456
0.914737
0.9414715
0.9677188
0.9943564
1.0
u: 50-element Vector{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}:
Float32[0.07975651, 0.79987717, 0.15339984, 0.44715297, 0.057252217, 0.017083757, 0.397226, 0.37008578, 0.61970097, 0.31732622 … 0.7707357, 0.51477575, 0.9488768, 0.8984697, 0.6454726, 0.23771797, 0.015754933, 0.8835519, 0.14337541, 0.8914004]
Float32[0.091294326, 0.79488426, 0.16077894, 0.46531612, 0.042521283, 0.022484621, 0.41546035, 0.35169297, 0.6409096, 0.29587573 … 0.7650729, 0.5296531, 0.95839334, 0.89492947, 0.6625467, 0.20936961, 0.026015624, 0.89660525, 0.1544463, 0.92302203]
Float32[0.114186615, 0.78224206, 0.17290637, 0.5025619, 0.011553634, 0.03309998, 0.44805008, 0.3156221, 0.68108934, 0.25510186 … 0.7512275, 0.56094134, 0.97495365, 0.8904496, 0.69106406, 0.15726943, 0.04654314, 0.9167866, 0.17067073, 0.983254]
Float32[0.15831761, 0.7495799, 0.18813594, 0.5789597, -0.054026656, 0.053368658, 0.5013973, 0.24553536, 0.7551918, 0.18121126 … 0.716143, 0.6275383, 1.0009681, 0.8888348, 0.731154, 0.06826759, 0.087696426, 0.9380107, 0.18462086, 1.0938436]
Float32[0.23556055, 0.670697, 0.19117078, 0.72898, -0.1863768, 0.08889933, 0.57666373, 0.11282618, 0.88755006, 0.06080608 … 0.63031274, 0.76434344, 1.036005, 0.9011437, 0.7695284, -0.06339563, 0.16844058, 0.92819405, 0.16427559, 1.2824914]
Float32[0.3349104, 0.5268481, 0.14399192, 0.9637455, -0.39342663, 0.13628027, 0.6551657, -0.09844001, 1.0917639, -0.08507974 … 0.4653609, 0.9873055, 1.0716101, 0.9383143, 0.77739954, -0.20181294, 0.29899853, 0.82583314, 0.05298618, 1.5380032]
Float32[0.44189632, 0.28092495, -0.023511056, 1.3204851, -0.6906421, 0.19388293, 0.7288427, -0.44552433, 1.4329518, -0.23460549 … 0.15584946, 1.3353711, 1.1035643, 1.0097601, 0.7374969, -0.31466138, 0.50858957, 0.5353016, -0.24510631, 1.8727479]
Float32[0.5123232, -0.087102264, -0.40441057, 1.7962462, -1.0271978, 0.24979705, 0.7777053, -0.96750784, 1.9790694, -0.31548086 … -0.370383, 1.8042051, 1.1133202, 1.1107806, 0.6521072, -0.3396446, 0.8058779, -0.036394976, -0.8423278, 2.2423425]
Float32[0.4969156, -0.66848576, -1.1984875, 2.440554, -1.3213347, 0.29623276, 0.7730812, -1.7838187, 2.9263222, -0.1956353 … -1.3362389, 2.4380078, 1.0525147, 1.2359463, 0.54094386, -0.20112997, 1.2393956, -1.08271, -2.0068703, 2.5991454]
Float32[0.35627759, -1.528888, -2.5688074, 3.1885526, -1.3208227, 0.31716573, 0.64322233, -2.894179, 4.417735, 0.37391642 … -2.9813879, 3.1795447, 0.8228836, 1.3337202, 0.4759468, 0.16606079, 1.8130995, -2.6827354, -3.964981, 2.7478895]
⋮
Float32[-2.9611823f10, -7.206831f9, 4.877548f9, -7.6340726f9, -1.9562246f10, -3.07769f10, 4.3311944f10, 1.0876393f10, -1.0741676f10, -8.717373f9 … 6.465729f9, 3.471865f10, 1.8261494f9, 1.8211381f10, -1.0422555f10, -9.655502f9, 8.361296f9, -1.8834127f10, 9.820999f9, -2.0596848f8]
Float32[-6.8715753f10, -1.0922034f10, 1.5217509f10, -2.1514357f10, -4.1637315f10, -6.724589f10, 8.992277f10, 2.061411f10, -3.2678451f10, -1.4891718f10 … 1.1013754f10, 8.7965204f10, 6.329102f9, 3.817686f10, -2.9388624f10, -1.6565554f10, 1.774597f10, -4.7151383f10, 2.9557037f10, -5.2396816f8]
Float32[-1.5629435f11, -1.0514547f10, 4.1878036f10, -5.335225f10, -8.608189f10, -1.4399508f11, 1.7814594f11, 3.7791662f10, -8.996295f10, -2.375732f10 … 1.6613748f10, 2.1484424f11, 1.583286f10, 7.604724f10, -7.630109f10, -2.6657223f10, 3.8433223f10, -1.1061219f11, 8.115341f10, 2.1229366f9]
Float32[-3.6276217f11, 1.1681056f10, 1.1311899f11, -1.2594249f11, -1.7937478f11, -3.1575586f11, 3.4853306f11, 6.8063035f10, -2.4367129f11, -3.683709f10 … 2.068339f10, 5.3148172f11, 3.2616524f10, 1.4841466f11, -1.9763138f11, -4.1359827f10, 8.93784f10, -2.5506721f11, 2.1882511f11, 2.126189f10]
Float32[-7.9303154f11, 1.0146961f11, 2.8045456f11, -2.6570852f11, -3.5108264f11, -6.6031295f11, 6.308077f11, 1.1076454f11, -6.02342f11, -5.629563f10 … 1.4491229f10, 1.2276578f12, 4.8370266f10, 2.6485626f11, -4.7574303f11, -6.0631642f10, 2.0548513f11, -5.365526f11, 5.359557f11, 9.654057f10]
Float32[-1.8102049f12, 4.2070386f11, 7.303765f11, -5.5667196f11, -7.021809f11, -1.4542008f12, 1.1302492f12, 1.6000944f11, -1.5495299f12, -1.0247669f11 … -2.6377503f10, 2.9688503f12, 2.076084f10, 4.583979f11, -1.210582f12, -9.411356f10, 5.1658555f11, -1.1354007f12, 1.3561716f12, 3.820381f11]
Float32[-3.992539f12, 1.3466187f12, 1.8426713f12, -1.0775456f12, -1.3322446f12, -3.1341203f12, 1.8455957f12, 1.3984404f11, -3.8253725f12, -2.4099963f11 … -1.5258001f11, 6.943643f12, -2.2043084f11, 6.995851f11, -3.009561f12, -1.680115f11, 1.2900035f12, -2.2383829f12, 3.2726083f12, 1.2777775f12]
Float32[-8.7161653f12, 3.910039f12, 4.660013f12, -1.9545968f12, -2.4085036f12, -6.7578825f12, 2.604322f12, -2.2778308f11, -9.379219f12, -7.0784877f11 … -4.7313817f11, 1.617926f13, -1.2377624f12, 8.485303f11, -7.55691f12, -3.8321023f11, 3.2537594f12, -4.1568368f12, 7.794969f12, 3.95339f12]
Float32[-1.0252236f13, 4.8512473f12, 5.6647347f12, -2.1924683f12, -2.7029576f12, -7.937989f12, 2.701903f12, -4.1457546f11, -1.1315447f13, -8.9959485f11 … -5.818322f11, 1.9316185f13, -1.6720943f12, 8.297132f11, -9.181175f12, -4.683762f11, 3.9525669f12, -4.6874996f12, 9.340705f12, 4.9736974f12]
Notice that the solution values sol[i] are CUDA-based arrays, which can be moved back to the CPU using Array(sol[i]).
More details on effective use of within-method GPU parallelism can be found in the within-method GPU parallelism tutorial.
Example of Parameter-Parallelism with GPU Ensemble Methods
On the other side of the spectrum, what if we want to solve tons of small ODEs? For this use case, we would use the ensemble methods to solve the same ODE many times with different parameters. This looks like:
using DiffEqGPU, OrdinaryDiffEq, StaticArrays, CUDA

# Out-of-place Lorenz system returning a static vector (the form used with EnsembleGPUKernel here)
function lorenz(u, p, t)
    σ = p[1]
    ρ = p[2]
    β = p[3]
    du1 = σ * (u[2] - u[1])
    du2 = u[1] * (ρ - u[3]) - u[2]
    du3 = u[1] * u[2] - β * u[3]
    return SVector{3}(du1, du2, du3)
end

u0 = @SVector [1.0f0; 0.0f0; 0.0f0]
tspan = (0.0f0, 10.0f0)
p = @SVector [10.0f0, 28.0f0, 8 / 3.0f0]
prob = ODEProblem{false}(lorenz, u0, tspan, p)   # {false}: out-of-place form
# Each trajectory gets the parameters scaled by random factors in [0, 1)
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)
sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()),
    trajectories = 10_000)
EnsembleSolution Solution of length 10000 with uType:
SciMLBase.ODESolution{Float32, 2, SubArray{StaticArraysCore.SVector{3, Float32}, 1, Matrix{StaticArraysCore.SVector{3, Float32}}, Tuple{UnitRange{Int64}, Int64}, true}, Nothing, Nothing, SubArray{Float32, 1, Matrix{Float32}, Tuple{UnitRange{Int64}, Int64}, true}, Nothing, DiffEqGPU.ImmutableODEProblem{StaticArraysCore.SVector{3, Float32}, Tuple{Float32, Float32}, false, StaticArraysCore.SVector{3, Float32}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.lorenz), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, GPUTsit5, SciMLBase.LinearInterpolation{SubArray{Float32, 1, Matrix{Float32}, Tuple{UnitRange{Int64}, Int64}, true}, SubArray{StaticArraysCore.SVector{3, Float32}, 1, Matrix{StaticArraysCore.SVector{3, Float32}}, Tuple{UnitRange{Int64}, Int64}, true}}, Nothing, Nothing}
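Roughly speaking, EnsembleGPUKernel treats each trajectory as an independent small problem whose out-of-place right-hand side returns an immutable static vector. The following CPU-only sketch mimics that structure using plain NTuples in place of SVectors, so it needs only Julia's standard library. A fixed-step RK4 (a stand-in for GPUTsit5, which is adaptive) integrates one parameter set without heap-allocated work arrays, and map runs it over many randomly scaled parameter sets, just as prob_func does above. The names lorenz_tuple and rk4_oop, the step size, the seed, and the trajectory count are illustrative assumptions; on a GPU, the per-trajectory function is the kind of work that would be handled by a single thread.

```julia
using Random

# Out-of-place Lorenz right-hand side on a 3-tuple (stand-in for the SVector version).
function lorenz_tuple(u::NTuple{3, T}, p::NTuple{3, T}, t) where {T}
    σ, ρ, β = p
    du1 = σ * (u[2] - u[1])
    du2 = u[1] * (ρ - u[3]) - u[2]
    du3 = u[1] * u[2] - β * u[3]
    return (du1, du2, du3)
end

# Hypothetical fixed-step RK4 for small immutable states: every intermediate
# is a tuple, so no work arrays are allocated on the heap.
function rk4_oop(f, u::NTuple{N, T}, p, tspan, dt) where {N, T}
    axpy(a, x, y) = map((xi, yi) -> yi + a * xi, x, y)   # y + a*x, elementwise
    t = tspan[1]
    for _ in 1:round(Int, (tspan[2] - tspan[1]) / dt)
        k1 = f(u, p, t)
        k2 = f(axpy(dt / 2, k1, u), p, t + dt / 2)
        k3 = f(axpy(dt / 2, k2, u), p, t + dt / 2)
        k4 = f(axpy(dt, k3, u), p, t + dt)
        u = map((ui, a, b, c, d) -> ui + dt / 6 * (a + 2b + 2c + d), u, k1, k2, k3, k4)
        t += dt
    end
    return u
end

u0 = (1.0f0, 0.0f0, 0.0f0)
p = (10.0f0, 28.0f0, 8 / 3.0f0)
tspan = (0.0f0, 10.0f0)
rng = MersenneTwister(1234)
# Like prob_func above: scale each parameter by a random factor in [0, 1).
params = [p .* (rand(rng, Float32), rand(rng, Float32), rand(rng, Float32)) for _ in 1:1_000]
finals = map(pk -> rk4_oop(lorenz_tuple, u0, pk, tspan, 0.01f0), params)
```

Each element of finals is the state at t = 10 for one parameter set. Because the trajectories share no data, they can be computed in any order or all at once, which is what makes this shape of problem a good fit for one-thread-per-trajectory GPU kernels.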
To dig more into this example, see the ensemble GPU solving tutorial.