Skip to content

Commit a0ddeb5

Browse files
Merge pull request #209 from SciML/arraypartition_cuarray
Fix ArrayPartition CuArray transformation
2 parents 2e17617 + 3866e68 commit a0ddeb5

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "2.29.0"
4+
version = "2.29.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/array_partition.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A)
6565

6666
## Array
6767

68+
Base.Array(A::ArrayPartition) = ArrayPartition(Array.(A.x))
6869
Base.Array(VA::AbstractVectorOfArray{T,N,A}) where {T,N,A <: AbstractVector{<:ArrayPartition}} = reduce(hcat,Array.(VA.u))
6970

7071
## ones

test/gpu/ode_gpu.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using CUDA, LinearAlgebra, OrdinaryDiffEq
2+
3+
u0 = cu(rand(100))
4+
5+
A = cu(randn(100,100))
6+
7+
f(du,u,p,t) = mul!(du,A,u)
8+
9+
prob = ODEProblem(f,u0,(0.0f0,1.0f0))
10+
11+
sol = solve(prob,Tsit5())
12+
13+
Array(sol)
14+
15+
# https://discourse.julialang.org/t/results-of-secondorderodeproblem-give-error-this-object-is-not-a-gpu-array/82100
16+
17+
u0 = cu(rand(100))
18+
19+
du0 = cu(rand(100))
20+
21+
A = cu(randn(100,100))
22+
23+
f(ddu,du,u,p,t) = mul!(ddu,A,u)
24+
25+
prob = SecondOrderODEProblem(f,du0,u0,(0.0f0,1.0f0))
26+
27+
sol = solve(prob,Tsit5())
28+
29+
CuArray(sol)

0 commit comments

Comments
 (0)