-
Notifications
You must be signed in to change notification settings - Fork 44
Use Adapt.jl #57
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use Adapt.jl #57
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine, but needs tests.
Hm tests are difficult without depending on CuArrays. |
I'm very interested in using OffsetArrays for an ocean model running on GPUs. Is there anything I could do to help get this PR merged in? Sounds like it's just tests. Depending on CuArrays doesn't sound like a good idea though... Maybe the best approach is to just import OffsetArrays and adapt it ourselves. |
Yes, it's just tests. I haven't looked at Adapt, but that's not really the functionality I'm wondering about. It's really that broadcasting change that needs a test, and that's completely independent of Adapt. It should not be that hard to come up with something, even if you have to create a small custom array type with weird broadcasting behavior. |
I should really finish this PR. Code like this is proliferating https://github.com/CliMA/Oceananigans.jl/blob/15aa5861651e229a949c287c7bafad27ee55d078/src/Utils/adapt_structure.jl#L5 |
On our current release, the following works though it is not ideal: (@v1.6) pkg> st OffsetArrays
Status `~/.julia/environments/v1.6/Project.toml`
[6fe1bfb0] OffsetArrays v1.3.1
julia> using StaticArrays, OffsetArrays
julia> a = @SMatrix [1 2; 3 4]
2×2 SMatrix{2, 2, Int64, 4} with indices SOneTo(2)×SOneTo(2):
1 2
3 4
julia> o = OffsetArray(a, 0:1, 10:11)
2×2 OffsetArray(::SMatrix{2, 2, Int64, 4}, 0:1, 10:11) with eltype Int64 with indices 0:1×10:11:
1 2
3 4
julia> a .+ 1
2×2 SMatrix{2, 2, Int64, 4} with indices SOneTo(2)×SOneTo(2):
2 3
4 5
julia> o .+ 1
2×2 OffsetArray(::Matrix{Int64}, 0:1, 10:11) with eltype Int64 with indices 0:1×10:11:
2 3
4 5 But on this branch, the last operation yields julia> o .+ 1
ERROR: MethodError: no method matching getindex(::Tuple{DataType, DataType}, ::Nothing)
Closest candidates are:
getindex(::Tuple, ::Int64) at tuple.jl:29
getindex(::Tuple, ::Real) at tuple.jl:30
getindex(::Tuple, ::Colon) at tuple.jl:33
...
Stacktrace:
[1] #s155#245
@ ~/.julia/packages/StaticArrays/l7lu2/src/broadcast.jl:100 [inlined]
[2] var"#s155#245"(newsize::Any, ::Any, f::Any, #unused#::Any, s::Any, a::Any)
@ StaticArrays ./none:0
[3] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any, N} where N)
@ Core ./boot.jl:571
[4] copy
@ ~/.julia/packages/StaticArrays/l7lu2/src/broadcast.jl:26 [inlined]
[5] materialize(bc::Base.Broadcast.Broadcasted{StaticArrays.StaticArrayStyle{2}, Nothing, typeof(+), Tuple{OffsetMatrix{Int64, SMatrix{2, 2, Int64, 4}}, Int64}})
@ Base.Broadcast ./broadcast.jl:837
[6] top-level scope
@ REPL[6]:1 Might merit a bit of investigation. One other issue (pointed out by @KristofferC) is a case of piracy: julia> which(eltype, Tuple{Type{SubArray{UInt8, 1, Vector{UInt8}, Tuple{UnitRange{Int64}}, true}}})
eltype(::Type{var"#s73"} where var"#s73"<:(AbstractArray{E, N} where N)) where E in Base at abstractarray.jl:152
julia> using Adapt
julia> which(eltype, Tuple{Type{SubArray{UInt8, 1, Vector{UInt8}, Tuple{UnitRange{Int64}}, true}}})
eltype(::Type{var"#s15"} where var"#s15"<:(Union{Base.LogicalIndex{T, var"#s5"} where var"#s5"<:Src, Base.ReinterpretArray{T, N, var"#s1", var"#s2", IsReshaped} where IsReshaped where var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where L where I where var"#s4" where var"#s3", var"#s13"} where var"#s1" where var"#s13"<:Src, Base.ReshapedArray{T, N, var"#s4", MI} where MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N} where var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where IsReshaped where var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where L where I where var"#s4" where var"#s3", var"#s14"} where var"#s11" where var"#s5" where var"#s1", SubArray{var"#s3", var"#s2", var"#s14", I, L} where L where I where var"#s2" where var"#s3", var"#s14"} where var"#s14"<:Src, SubArray{T, N, var"#s5", I, L} where L where I where var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where IsReshaped where var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where L where I where var"#s4" where var"#s3", var"#s15"} where var"#s11" where var"#s1" where var"#s2", Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N} where var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where IsReshaped where var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where L where I where var"#s4" where var"#s3", var"#s15"} where var"#s11" where var"#s5" where var"#s1", SubArray{var"#s3", var"#s2", var"#s15", I, L} where L where I where var"#s2" where var"#s3", var"#s15"} where var"#s3" where var"#s4", var"#s15"} where var"#s15"<:Src, LinearAlgebra.Adjoint{T, var"#s1"} where var"#s1"<:Dst, LinearAlgebra.Diagonal{T, var"#s11"} where var"#s11"<:Dst, LinearAlgebra.LowerTriangular{T, var"#s7"} where var"#s7"<:Dst, LinearAlgebra.Transpose{T, var"#s6"} where var"#s6"<:Dst, LinearAlgebra.Tridiagonal{T, var"#s12"} where var"#s12"<:Dst, LinearAlgebra.UnitLowerTriangular{T, var"#s8"} where var"#s8"<:Dst, LinearAlgebra.UnitUpperTriangular{T, var"#s10"} where var"#s10"<:Dst, LinearAlgebra.UpperTriangular{T, var"#s9"} where var"#s9"<:Dst, PermutedDimsArray{T, N, var"#s4", var"#s3", var"#s2"} where var"#s2"<:Src where var"#s3" where var"#s4"} where Dst where Src where N)) where T in Adapt at /home/tim/.julia/packages/Adapt/8kQMV/src/wrappers.jl:108 |
test/runtests.jl
Outdated
@test arr == adapt(Array, s_arr) | ||
|
||
# Check that broadcast respects parent | ||
@test Base.Broadcast.BroadcastStyle(typeof(arr)) == StaticArrays.StaticArrayStyle{2}() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be good to do a little more than just test the style, as illustrated in the comment thread.
Thanks for the review! I originally had the test for actually performing the broadcast, but as you noted that doesn't currently work. StaticArrays broadcasting implementation has a hard-coded list of wrappers it supports https://github.com/JuliaArrays/StaticArrays.jl/blob/b95d07e5f3cf731ff89bfe2c6eebc73ed8480cdd/src/broadcast.jl#L100 |
@timholy What is your preferred way forward? I can disable the |
Codecov Report
@@ Coverage Diff @@
## master #57 +/- ##
=======================================
Coverage 99.21% 99.22%
=======================================
Files 4 4
Lines 256 257 +1
=======================================
+ Hits 254 255 +1
Misses 2 2
Continue to review full report at Codecov.
|
Pretty much the same boat as in #57 (comment). We have to have a functional test of broadcasting that shows that this is a step forward for an array type that obeys the standard interface. If it's a regression specifically for StaticArrays, there are reasons to argue it's a problem that StaticArrays has to learn to deal with. I could merge it under those circumstances. But it seems silly to allow messing with something as fundamental as broadcasting without a reasonable demonstration that old stuff mostly continues to work while allowing improvements. |
You might be able to steal definitions from julia's |
An excellent strategy. I look forward to getting the broadcasting done once we have the infrastructure to support it. |
Thanks Tim! |
@ali-ramadhan do note that broadcasting OffsetArrays with CuArrays won't execute on the GPU yet. |
Adapt is a lightweight dependency that allows wrapper packages like OffsetArray to be
converted between different underlying arrays. The primary use-case is to be able to
convert a CPU OffsetArray to a OffsetArray whose memory is a
CuArray
.