Getting Started with GPU-Accelerated Differential Equations in Julia

The two ways to accelerate ODE solvers with GPUs

There are two very different ways to accelerate an ODE solve with GPUs. In the first, the state u is very large and the right-hand side f is expensive but highly structured (for example, built from linear algebra or broadcast operations), so the GPU is used to accelerate each evaluation of f. In the second, u is small, but the same ODE must be solved for many different initial conditions u0 or parameters p, so the GPU is used to parallelize across those initial conditions and parameters. In other words:

| Type of Problem | SciML Solution |
| --- | --- |
| Accelerate a big ODE | Use CUDA.jl's CuArray as u0 |
| Solve the same ODE with many u0 and p | Use DiffEqGPU.jl's EnsembleGPUArray and EnsembleGPUKernel |

Supported GPUs

SciML's GPU support extends to a wide array of hardware, including:

| GPU Manufacturer | GPU Kernel Language | Julia Support Package | Backend Type |
| --- | --- | --- | --- |
| NVIDIA | CUDA | CUDA.jl | CUDA.CUDABackend() |
| AMD | ROCm | AMDGPU.jl | AMDGPU.ROCBackend() |
| Intel | oneAPI | oneAPI.jl | oneAPI.oneAPIBackend() |
| Apple (M-Series) | Metal | Metal.jl | Metal.MetalBackend() |

For this tutorial we will demonstrate the CUDA backend for NVIDIA GPUs, though any of the other GPUs can be used by simply swapping out the backend choice.
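
As a rough sketch of what swapping the backend looks like, the snippet below builds an ensemble algorithm from one of the backend objects listed in the table. It assumes CUDA.jl is installed. The commented-out lines show the other backends, each of which would need its own package loaded first. The variable names `backend` and `ensemblealg` are illustrative only.

```julia
using DiffEqGPU
using CUDA   # NVIDIA; load AMDGPU, oneAPI, or Metal instead on other hardware

# Backend object from the table above; only this line is hardware-specific.
backend = CUDA.CUDABackend()
# backend = AMDGPU.ROCBackend()      # requires `using AMDGPU`
# backend = oneAPI.oneAPIBackend()   # requires `using oneAPI`
# backend = Metal.MetalBackend()     # requires `using Metal`

# The ensemble algorithm takes the backend as its argument.
ensemblealg = EnsembleGPUKernel(backend)
```

The rest of a script can then stay the same across vendors, since the hardware choice is confined to the `backend` line.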

Simple Example of Within-Method GPU Parallelism

The following is a quick and dirty example of doing within-method GPU parallelism. Let's say we had a simple but large ODE with many linear algebra or map/broadcast operations:

using OrdinaryDiffEq, LinearAlgebra
u0 = rand(1000)                 # large state vector
A = randn(1000, 1000)           # dense linear operator
f(du, u, p, t) = mul!(du, A, u) # in-place right-hand side: du = A * u
prob = ODEProblem(f, u0, (0.0, 1.0))
sol = solve(prob, Tsit5())
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 50-element Vector{Float64}:
 0.0
 0.0015478761855679305
 0.00648580863398825
 0.013842416823521709
 0.022391643841718323
 0.03325242939710622
 0.04561502822320631
 0.05901456341425282
 0.07299580669091357
 0.087920728422522
 ⋮
 0.8051958359073149
 0.8303464269733638
 0.855220989124135
 0.8807960944818966
 0.906252317657025
 0.9322591011728141
 0.9588072465756187
 0.9814855255785365
 1.0
u: 50-element Vector{Vector{Float64}}:
 [0.6077993951026277, 0.7492427325037554, 0.24030948236125937, 0.7876203548496596, 0.9790896488924092, 0.39466961581632154, 0.8624005981887551, 0.44259694848939535, 0.8924620923795857, 0.02291634363312045  …  0.9575521471392613, 0.8803300650478935, 0.5201892514317958, 0.9654442939425254, 0.36768574375921526, 0.6396059294669751, 0.901286174640774, 0.685994963053012, 0.27133784540119166, 0.3416276233225435]
 [0.6388217639528583, 0.7550178758102967, 0.23851886858256294, 0.798157641408362, 0.9991230598196218, 0.3979320485800568, 0.8199383091064789, 0.41758245795126836, 0.8625227443916994, 0.025110186525423902  …  0.9928820975413687, 0.9188273924595998, 0.5424605821203632, 0.9989534800028482, 0.34863444594913573, 0.6129502249468205, 0.9449488603002396, 0.6859185020749737, 0.2372386250598567, 0.35672114399690125]
 [0.7604156908506836, 0.7674550530644948, 0.23847720394552546, 0.8294452834399094, 1.0804114475300428, 0.4060703085606619, 0.6888640957065337, 0.3343223619420136, 0.7636075351231503, 0.03724209063173313  …  1.1052247509058286, 1.036939414566552, 0.590016004446043, 1.0983914945208484, 0.2796647248494942, 0.5103533632849603, 1.0769710826307772, 0.6990725769164804, 0.11904315092961476, 0.41837714984868773]
 [1.0053064029594083, 0.7754094885668386, 0.25381901276200136, 0.8727080609515534, 1.2506728414728154, 0.41008832071754375, 0.5034713184150569, 0.19682926309607676, 0.6106487534432536, 0.0715403130063657  …  1.2729780841888372, 1.1972501755009486, 0.592849788111663, 1.2227273732202086, 0.15679756465134848, 0.31151327457108635, 1.2523068815326168, 0.7558685670894918, -0.08699326886042917, 0.5490906533963484]
 [1.3834519386374002, 0.7837838905457066, 0.2921861003267462, 0.9236774476427589, 1.523181116666089, 0.3997205500175946, 0.2955086213223752, 0.009936882166138456, 0.4342715986146158, 0.1385042704272604  …  1.472789062410583, 1.35732431035077, 0.4905464950458207, 1.3244911127993384, -0.007538677655676027, 0.02053802821052491, 1.4211369646984968, 0.8788149364361183, -0.37901171481451906, 0.7632432858301461]
 [1.9996969576951213, 0.8237528265358053, 0.36398372741091023, 0.9962401787667565, 1.989912694960366, 0.36144105137963456, 0.022407587889207218, -0.27997753933805464, 0.2327477796056976, 0.26713486066052694  …  1.7455467459548604, 1.5225296611969998, 0.19108356494503692, 1.3716391061004376, -0.22617082700556285, -0.4215580154598785, 1.5721740596708869, 1.1317783194275162, -0.8466567366911893, 1.1446223361289465]
 [2.863933540547281, 0.9572029196534585, 0.4555527992887772, 1.0899709291588249, 2.7048696099110843, 0.29049292086911155, -0.34906580016617744, -0.693400928174069, 0.06481141247194418, 0.46686840196762325  …  2.1045157539194497, 1.6791929519403357, -0.3901551134623551, 1.2835669148227058, -0.4352059108610356, -0.989436360202414, 1.6318025361215416, 1.576996213459419, -1.5362375063508247, 1.7571106395927933]
 [3.9567238329930747, 1.2660033518194682, 0.5242731272720043, 1.1857324242261225, 3.7570311466950703, 0.20741445738859104, -0.9259994793290318, -1.2517796356900284, -0.015227556801584132, 0.7251706951816024  …  2.585605057722956, 1.8683327115008561, -1.3278904296850627, 0.9695710114192313, -0.5103261403065463, -1.64077600938374, 1.4987013015224453, 2.300299770684521, -2.5125883888250757, 2.6902049052817136]
 [5.212509354712602, 1.8219352088265837, 0.4927626775971167, 1.2217430428676943, 5.287165498600609, 0.1684122496688786, -1.9099209691049166, -1.9601273540869935, 0.0068975008072079645, 0.9971546368338595  …  3.2216661701781475, 2.1984737250394413, -2.688231906006248, 0.3335252094962419, -0.2468236831160612, -2.330373319496057, 1.009343445746315, 3.401298931576341, -3.8451962977980796, 4.05284494696332]
 [6.597347622680693, 2.707730246048976, 0.23350582843453324, 1.0593978321294257, 7.657144179123195, 0.2683507530696783, -3.743003371209908, -2.850509989437023, 0.049158725518221195, 1.2230149083895065  …  4.053428323483684, 2.895916786228373, -4.661277250011878, -0.7774081803390297, 0.7082635349031607, -3.0958792890184026, -0.17535764379219534, 5.073413283353828, -5.723384836635318, 6.104970574138457]
 ⋮
 [8.764780066224067e8, 1.896679198856089e10, -2.8791238716648205e10, -1.7782386084904808e10, -1.8629943618088207e10, 1.6352128582481981e10, -2.013634606630096e10, 9.378994001048164e9, 3.250327838955565e10, -2.4188887993906605e10  …  -2.0476771292206867e10, -8.801964896424888e8, -6.591287844944542e9, -1.0345209058982435e10, 1.5391788719636024e10, 3.4634498633198557e9, -1.5679486188023586e10, -1.213285938302461e9, -5.438045934829378e9, -1.1518680809893606e10]
 [-1.6008137867406118e9, 4.436306900109785e10, -7.133880987032213e10, -3.267709320324597e10, -4.100510774819068e10, 3.507435874752564e10, -4.109720486029508e10, 3.1143379966602604e10, 7.211345057356902e10, -4.872831003357185e10  …  -4.42749636134423e10, -6.410373034463417e9, -1.1352339308446444e10, -2.023518090277868e10, 2.6040479824811573e10, 5.334322576628781e9, -2.8138574175721382e10, -7.436936219045354e9, -1.536871597666749e10, -3.2057632585082157e10]
 [-9.835055910048376e9, 1.0096923135993665e11, -1.6803850623723938e11, -5.5607741471781364e10, -8.78295963571415e10, 7.405439571856248e10, -8.166362597352007e10, 9.049337091938423e10, 1.570788747232032e11, -9.518209685019202e10  …  -9.493617992720575e10, -2.421290453960685e10, -1.8516986299478905e10, -3.8494069151425415e10, 3.99757041361387e10, 6.87526800357321e9, -4.7974744954544e10, -2.3632620006238934e10, -3.703165440635093e10, -8.192947141839993e10]
 [-3.3092087238827965e10, 2.3126596324069577e11, -3.9204980268073303e11, -8.584692704455705e10, -1.8883942613339667e11, 1.5867585142554782e11, -1.622942803308893e11, 2.509502540181117e11, 3.453400326746502e11, -1.842867903222202e11  …  -2.0802521888540536e11, -7.667313323238606e10, -2.9094192895382225e10, -7.28300164926125e10, 5.1896106956949196e10, 4.975465607817764e9, -7.821226403444156e10, -6.299240672838114e10, -8.268893822661942e10, -2.0338509542303232e11]
 [-9.044451266344685e10, 5.194873853485014e11, -8.856619721051328e11, -1.0532240735501501e11, -3.9829126738653033e11, 3.3694374364964685e11, -3.1529565367391327e11, 6.557717021798854e11, 7.451969398253005e11, -3.4395225415398e11  …  -4.5426926671591754e11, -2.1497354388960907e11, -4.345809137183459e10, -1.3344021671751703e11, 3.82089145460287e10, -1.0247641642202696e10, -1.1689857185952869e11, -1.485384955832559e11, -1.692423883829415e11, -4.807822887195532e11]
 [-2.267247062902883e11, 1.1697967852823088e12, -1.9858339572089553e12, -4.6619944134888565e10, -8.424662306146082e11, 7.233687913797831e11, -6.084511182028693e11, 1.6752083134825513e12, 1.606733596390494e12, -6.224138541052672e11  …  -1.0090562928663315e12, -5.698888225018145e11, -6.314178299136891e10, -2.3762883304224597e11, -7.506665077293513e10, -6.784138030055833e10, -1.527465096926599e11, -3.263220885460888e11, -3.2348788757574664e11, -1.1128434354892239e12]
 [-5.396515870699504e11, 2.639962114912662e12, -4.4247767716484375e12, 3.1226918592045764e11, -1.7919549974816555e12, 1.5683319551611003e12, -1.1619329367779023e12, 4.2045624098930684e12, 3.4505821619469956e12, -1.068930513344212e12  …  -2.278776586292099e12, -1.4535976419461814e12, -9.191225782225375e10, -3.997871976301755e11, -5.0405311840302344e11, -2.4812008650601465e11, -1.4285096377187524e11, -6.705233099948209e11, -5.661276078077296e11, -2.5238877370180205e12]
 [-1.0888619612530238e12, 5.2329957377794e12, -8.627492497950489e12, 1.298170123182634e12, -3.399601510927614e12, 3.019585725089798e12, -1.9742412504288433e12, 8.99387505490768e12, 6.516214041688621e12, -1.5666202424681304e12  …  -4.568349218894518e12, -3.1201715963489346e12, -1.294639298214924e11, -5.640713802356516e11, -1.517735827705361e12, -6.418491697608444e11, -1.79862808914855e10, -1.1562044987971702e12, -8.153635928282788e11, -4.933624652748248e12]
 [-1.895812881330054e12, 9.083221209112744e12, -1.4725653427386322e13, 3.190050714478752e12, -5.732181029828195e12, 5.128172678169889e12, -2.9866649838152583e12, 1.646762355726429e13, 1.08148198792512e13, -1.950916864883442e12  …  -8.056223336936267e12, -5.708294163123032e12, -1.7145024869188217e11, -6.4776831302855e11, -3.3198320033962886e12, -1.314032554112901e12, 2.6037936347618207e11, -1.7023153937138574e12, -9.653660834081237e11, -8.357515822381224e12]

Translating this to a GPU-based solve of the ODE simply requires moving the arrays for the initial condition, parameters, and caches to the GPU. This looks like:

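using OrdinaryDiffEq, CUDA, LinearAlgebra
u0 = cu(rand(1000))             # cu copies to the GPU and converts Float64 to Float32
A = cu(randn(1000, 1000))       # the operator must live on the GPU as well
f(du, u, p, t) = mul!(du, A, u) # in-place du = A * u, now executed on the GPU
prob = ODEProblem(f, u0, (0.0f0, 1.0f0)) # Float32 is better on GPUs!
sol = solve(prob, Tsit5())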
retcode: Success
Interpolation: specialized 4th order "free" interpolation
t: 50-element Vector{Float32}:
 0.0
 0.0012238872
 0.0035123606
 0.007574204
 0.0141317025
 0.022441624
 0.032786697
 0.044370756
 0.057975534
 0.072406396
 ⋮
 0.8123368
 0.83820784
 0.8635944
 0.88990456
 0.914737
 0.9414715
 0.9677188
 0.9943564
 1.0
u: 50-element Vector{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}:
 Float32[0.07975651, 0.79987717, 0.15339984, 0.44715297, 0.057252217, 0.017083757, 0.397226, 0.37008578, 0.61970097, 0.31732622  …  0.7707357, 0.51477575, 0.9488768, 0.8984697, 0.6454726, 0.23771797, 0.015754933, 0.8835519, 0.14337541, 0.8914004]
 Float32[0.091294326, 0.79488426, 0.16077894, 0.46531612, 0.042521283, 0.022484621, 0.41546035, 0.35169297, 0.6409096, 0.29587573  …  0.7650729, 0.5296531, 0.95839334, 0.89492947, 0.6625467, 0.20936961, 0.026015624, 0.89660525, 0.1544463, 0.92302203]
 Float32[0.114186615, 0.78224206, 0.17290637, 0.5025619, 0.011553634, 0.03309998, 0.44805008, 0.3156221, 0.68108934, 0.25510186  …  0.7512275, 0.56094134, 0.97495365, 0.8904496, 0.69106406, 0.15726943, 0.04654314, 0.9167866, 0.17067073, 0.983254]
 Float32[0.15831761, 0.7495799, 0.18813594, 0.5789597, -0.054026656, 0.053368658, 0.5013973, 0.24553536, 0.7551918, 0.18121126  …  0.716143, 0.6275383, 1.0009681, 0.8888348, 0.731154, 0.06826759, 0.087696426, 0.9380107, 0.18462086, 1.0938436]
 Float32[0.23556055, 0.670697, 0.19117078, 0.72898, -0.1863768, 0.08889933, 0.57666373, 0.11282618, 0.88755006, 0.06080608  …  0.63031274, 0.76434344, 1.036005, 0.9011437, 0.7695284, -0.06339563, 0.16844058, 0.92819405, 0.16427559, 1.2824914]
 Float32[0.3349104, 0.5268481, 0.14399192, 0.9637455, -0.39342663, 0.13628027, 0.6551657, -0.09844001, 1.0917639, -0.08507974  …  0.4653609, 0.9873055, 1.0716101, 0.9383143, 0.77739954, -0.20181294, 0.29899853, 0.82583314, 0.05298618, 1.5380032]
 Float32[0.44189632, 0.28092495, -0.023511056, 1.3204851, -0.6906421, 0.19388293, 0.7288427, -0.44552433, 1.4329518, -0.23460549  …  0.15584946, 1.3353711, 1.1035643, 1.0097601, 0.7374969, -0.31466138, 0.50858957, 0.5353016, -0.24510631, 1.8727479]
 Float32[0.5123232, -0.087102264, -0.40441057, 1.7962462, -1.0271978, 0.24979705, 0.7777053, -0.96750784, 1.9790694, -0.31548086  …  -0.370383, 1.8042051, 1.1133202, 1.1107806, 0.6521072, -0.3396446, 0.8058779, -0.036394976, -0.8423278, 2.2423425]
 Float32[0.4969156, -0.66848576, -1.1984875, 2.440554, -1.3213347, 0.29623276, 0.7730812, -1.7838187, 2.9263222, -0.1956353  …  -1.3362389, 2.4380078, 1.0525147, 1.2359463, 0.54094386, -0.20112997, 1.2393956, -1.08271, -2.0068703, 2.5991454]
 Float32[0.35627759, -1.528888, -2.5688074, 3.1885526, -1.3208227, 0.31716573, 0.64322233, -2.894179, 4.417735, 0.37391642  …  -2.9813879, 3.1795447, 0.8228836, 1.3337202, 0.4759468, 0.16606079, 1.8130995, -2.6827354, -3.964981, 2.7478895]
 ⋮
 Float32[-2.9611823f10, -7.206831f9, 4.877548f9, -7.6340726f9, -1.9562246f10, -3.07769f10, 4.3311944f10, 1.0876393f10, -1.0741676f10, -8.717373f9  …  6.465729f9, 3.471865f10, 1.8261494f9, 1.8211381f10, -1.0422555f10, -9.655502f9, 8.361296f9, -1.8834127f10, 9.820999f9, -2.0596848f8]
 Float32[-6.8715753f10, -1.0922034f10, 1.5217509f10, -2.1514357f10, -4.1637315f10, -6.724589f10, 8.992277f10, 2.061411f10, -3.2678451f10, -1.4891718f10  …  1.1013754f10, 8.7965204f10, 6.329102f9, 3.817686f10, -2.9388624f10, -1.6565554f10, 1.774597f10, -4.7151383f10, 2.9557037f10, -5.2396816f8]
 Float32[-1.5629435f11, -1.0514547f10, 4.1878036f10, -5.335225f10, -8.608189f10, -1.4399508f11, 1.7814594f11, 3.7791662f10, -8.996295f10, -2.375732f10  …  1.6613748f10, 2.1484424f11, 1.583286f10, 7.604724f10, -7.630109f10, -2.6657223f10, 3.8433223f10, -1.1061219f11, 8.115341f10, 2.1229366f9]
 Float32[-3.6276217f11, 1.1681056f10, 1.1311899f11, -1.2594249f11, -1.7937478f11, -3.1575586f11, 3.4853306f11, 6.8063035f10, -2.4367129f11, -3.683709f10  …  2.068339f10, 5.3148172f11, 3.2616524f10, 1.4841466f11, -1.9763138f11, -4.1359827f10, 8.93784f10, -2.5506721f11, 2.1882511f11, 2.126189f10]
 Float32[-7.9303154f11, 1.0146961f11, 2.8045456f11, -2.6570852f11, -3.5108264f11, -6.6031295f11, 6.308077f11, 1.1076454f11, -6.02342f11, -5.629563f10  …  1.4491229f10, 1.2276578f12, 4.8370266f10, 2.6485626f11, -4.7574303f11, -6.0631642f10, 2.0548513f11, -5.365526f11, 5.359557f11, 9.654057f10]
 Float32[-1.8102049f12, 4.2070386f11, 7.303765f11, -5.5667196f11, -7.021809f11, -1.4542008f12, 1.1302492f12, 1.6000944f11, -1.5495299f12, -1.0247669f11  …  -2.6377503f10, 2.9688503f12, 2.076084f10, 4.583979f11, -1.210582f12, -9.411356f10, 5.1658555f11, -1.1354007f12, 1.3561716f12, 3.820381f11]
 Float32[-3.992539f12, 1.3466187f12, 1.8426713f12, -1.0775456f12, -1.3322446f12, -3.1341203f12, 1.8455957f12, 1.3984404f11, -3.8253725f12, -2.4099963f11  …  -1.5258001f11, 6.943643f12, -2.2043084f11, 6.995851f11, -3.009561f12, -1.680115f11, 1.2900035f12, -2.2383829f12, 3.2726083f12, 1.2777775f12]
 Float32[-8.7161653f12, 3.910039f12, 4.660013f12, -1.9545968f12, -2.4085036f12, -6.7578825f12, 2.604322f12, -2.2778308f11, -9.379219f12, -7.0784877f11  …  -4.7313817f11, 1.617926f13, -1.2377624f12, 8.485303f11, -7.55691f12, -3.8321023f11, 3.2537594f12, -4.1568368f12, 7.794969f12, 3.95339f12]
 Float32[-1.0252236f13, 4.8512473f12, 5.6647347f12, -2.1924683f12, -2.7029576f12, -7.937989f12, 2.701903f12, -4.1457546f11, -1.1315447f13, -8.9959485f11  …  -5.818322f11, 1.9316185f13, -1.6720943f12, 8.297132f11, -9.181175f12, -4.683762f11, 3.9525669f12, -4.6874996f12, 9.340705f12, 4.9736974f12]

Notice that the solution values sol[i] are CUDA-based arrays, which can be moved back to the CPU using Array(sol[i]).
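
A minimal sketch of that transfer, assuming a CUDA-capable GPU and repeating the setup so the block is self-contained. It reaches the saved states through `sol.u`, the vector shown in the printed output. The names `u_end_cpu` and `u_all_cpu` are illustrative.

```julia
using OrdinaryDiffEq, CUDA, LinearAlgebra

u0 = cu(rand(1000))
A = cu(randn(1000, 1000))
f(du, u, p, t) = mul!(du, A, u)
prob = ODEProblem(f, u0, (0.0f0, 1.0f0))
sol = solve(prob, Tsit5())

# Copy only the final state from GPU memory into a regular CPU Vector.
u_end_cpu = Array(sol.u[end])

# Copy every saved state; each element becomes a CPU Vector{Float32}.
u_all_cpu = [Array(u) for u in sol.u]
```

Each call to `Array` copies data from the device to the host, so when only a few time points are needed it is usually cheaper to convert just those.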

More details on effective use of within-method GPU parallelism can be found in the within-method GPU parallelism tutorial.

Example of Parameter-Parallelism with GPU Ensemble Methods

On the other side of the spectrum, what if we want to solve tons of small ODEs? For this use case, we would use the ensemble methods to solve the same ODE many times with different parameters. This looks like:

using DiffEqGPU, OrdinaryDiffEq, StaticArrays, CUDA

# Out-of-place Lorenz right-hand side: takes u and returns the derivative as a new SVector
function lorenz(u, p, t)
    σ = p[1]
    ρ = p[2]
    β = p[3]
    du1 = σ * (u[2] - u[1])
    du2 = u[1] * (ρ - u[3]) - u[2]
    du3 = u[1] * u[2] - β * u[3]
    return SVector{3}(du1, du2, du3)
end

u0 = @SVector [1.0f0; 0.0f0; 0.0f0]
tspan = (0.0f0, 10.0f0)
p = @SVector [10.0f0, 28.0f0, 8 / 3.0f0]
prob = ODEProblem{false}(lorenz, u0, tspan, p)
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()),
    trajectories = 10_000)
EnsembleSolution Solution of length 10000 with uType:
SciMLBase.ODESolution{Float32, 2, SubArray{StaticArraysCore.SVector{3, Float32}, 1, Matrix{StaticArraysCore.SVector{3, Float32}}, Tuple{UnitRange{Int64}, Int64}, true}, Nothing, Nothing, SubArray{Float32, 1, Matrix{Float32}, Tuple{UnitRange{Int64}, Int64}, true}, Nothing, DiffEqGPU.ImmutableODEProblem{StaticArraysCore.SVector{3, Float32}, Tuple{Float32, Float32}, false, StaticArraysCore.SVector{3, Float32}, SciMLBase.ODEFunction{false, SciMLBase.AutoSpecialize, typeof(Main.lorenz), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing, Nothing}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, SciMLBase.StandardODEProblem}, GPUTsit5, SciMLBase.LinearInterpolation{SubArray{Float32, 1, Matrix{Float32}, Tuple{UnitRange{Int64}, Int64}, true}, SubArray{StaticArraysCore.SVector{3, Float32}, 1, Matrix{StaticArraysCore.SVector{3, Float32}}, Tuple{UnitRange{Int64}, Int64}, true}}, Nothing, Nothing}
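
The returned ensemble solution holds one ODE solution per trajectory. The following sketch, which assumes a CUDA-capable GPU and repeats the setup so that it runs on its own, shows one way to pull individual results out of it through the `u` fields. The names `first_traj`, `final_state`, and `final_states` are illustrative.

```julia
using DiffEqGPU, OrdinaryDiffEq, StaticArrays, CUDA

function lorenz(u, p, t)
    σ, ρ, β = p[1], p[2], p[3]
    du1 = σ * (u[2] - u[1])
    du2 = u[1] * (ρ - u[3]) - u[2]
    du3 = u[1] * u[2] - β * u[3]
    return SVector{3}(du1, du2, du3)
end

u0 = @SVector [1.0f0; 0.0f0; 0.0f0]
tspan = (0.0f0, 10.0f0)
p = @SVector [10.0f0, 28.0f0, 8 / 3.0f0]
prob = ODEProblem{false}(lorenz, u0, tspan, p)
prob_func = (prob, i, repeat) -> remake(prob, p = (@SVector rand(Float32, 3)) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

sol = solve(monteprob, GPUTsit5(), EnsembleGPUKernel(CUDA.CUDABackend()),
    trajectories = 10_000)

# sol.u is a vector with one ODE solution per trajectory.
first_traj = sol.u[1]

# Within a trajectory, .t holds the saved times and .u the saved states.
final_state = first_traj.u[end]   # an SVector{3, Float32}

# Gather the final state of every trajectory.
final_states = [traj.u[end] for traj in sol.u]
```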

To dig more into this example, see the ensemble GPU solving tutorial.
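
For comparison, the other ensemble method named in the table at the top, EnsembleGPUArray, can be sketched as follows. This is an illustration under assumptions rather than a tuned setup: it assumes a CUDA-capable GPU, writes the Lorenz system in in-place form with ordinary arrays, uses the standard Tsit5() solver, and picks a `saveat` spacing of 1.0f0 arbitrarily.

```julia
using DiffEqGPU, OrdinaryDiffEq, CUDA

# In-place Lorenz right-hand side operating on ordinary arrays.
function lorenz!(du, u, p, t)
    du[1] = p[1] * (u[2] - u[1])
    du[2] = u[1] * (p[2] - u[3]) - u[2]
    du[3] = u[1] * u[2] - p[3] * u[3]
    return nothing
end

u0 = Float32[1.0, 0.0, 0.0]
tspan = (0.0f0, 10.0f0)
p = Float32[10.0, 28.0, 8 / 3]
prob = ODEProblem(lorenz!, u0, tspan, p)

# Each trajectory gets its own randomly scaled copy of the parameters.
prob_func = (prob, i, repeat) -> remake(prob, p = rand(Float32, 3) .* p)
monteprob = EnsembleProblem(prob, prob_func = prob_func, safetycopy = false)

# Same call shape as before; only the solver and ensemble algorithm differ.
sol = solve(monteprob, Tsit5(), EnsembleGPUArray(CUDA.CUDABackend()),
    trajectories = 10_000, saveat = 1.0f0)
```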