Skip to content

Commit b7743bb

Browse files
authored
Merge a836b89 into 9608ffc
2 parents 9608ffc + a836b89 commit b7743bb

File tree

31 files changed

+384
-230
lines changed

31 files changed

+384
-230
lines changed

NDTensors/ext/NDTensorsCUDAExt/NDTensorsCUDAExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using NDTensors
44
using NDTensors.SetParameters
55
using Adapt
66
using Functors
7-
using LinearAlgebra: BlasFloat
7+
using LinearAlgebra
88

99
if isdefined(Base, :get_extension)
1010
using CUDA
@@ -18,6 +18,7 @@ end
1818

1919
include("imports.jl")
2020
include("set_types.jl")
21+
include("iscu.jl")
2122
include("adapt.jl")
2223
include("linearalgebra.jl")
2324
end

NDTensors/ext/NDTensorsCUDAExt/adapt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ buffertype(::NDTensorCuArrayAdaptor{B}) where {B} = B
1616
function Adapt.adapt_storage(adaptor::NDTensorCuArrayAdaptor, xs::AbstractArray)
1717
ElT = eltype(xs)
1818
BufT = buffertype(adaptor)
19-
return isbits(xs) ? xs : CuArray{ElT,1,BufT}(xs)
19+
return isbits(xs) ? xs : adapt(CuArray{ElT,1,BufT}, xs)
2020
end
2121

2222
function NDTensors.adapt_storagetype(
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import NDTensors: cu, set_ndims, set_eltype, set_eltype_if_unspecified, similartype
22
import NDTensors:
3-
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!
3+
ContractionProperties, _contract!, GemmBackend, auto_select_backend, _gemm!, iscu
44
import NDTensors.SetParameters: nparameters, get_parameter, set_parameter, default_parameter
55

66
import .CUDA: CuArrayAdaptor
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
iscu(::Type{<:CuArray}) = true

NDTensors/src/NDTensors.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,12 @@ include("abstractarray/set_types.jl")
5050
include("abstractarray/to_shape.jl")
5151
include("abstractarray/similar.jl")
5252
include("abstractarray/ndims.jl")
53+
include("abstractarray/permutedims.jl")
5354
include("abstractarray/fill.jl")
55+
include("abstractarray/mul.jl")
5456
include("array/set_types.jl")
57+
include("array/permutedims.jl")
58+
include("array/mul.jl")
5559
include("tupletools.jl")
5660
include("emptynumber.jl")
5761
include("nodata.jl")
@@ -63,9 +67,11 @@ include("tensor/tensor.jl")
6367
include("dims.jl")
6468
include("tensor/set_types.jl")
6569
include("tensor/similar.jl")
70+
include("tensor/permutedims.jl")
6671
include("adapt.jl")
6772
include("tensoralgebra/generic_tensor_operations.jl")
6873
include("tensoralgebra/contraction_logic.jl")
74+
include("abstractarray/tensoralgebra/contract.jl")
6975

7076
#####################################
7177
# DenseTensor and DiagTensor

NDTensors/src/abstractarray/fill.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
function generic_randn(arraytype::Type{<:AbstractArray}, dim::Integer=0)
1+
function generic_randn(
2+
arraytype::Type{<:AbstractArray}, dim::Integer=0; rng=Random.default_rng()
3+
)
24
arraytype_specified = set_unspecified_parameters(
35
leaf_parenttype(arraytype), DefaultParameters()
46
)
57
data = similar(arraytype_specified, dim)
6-
ElT = eltype(data)
7-
for i in 1:length(data)
8-
data[i] = randn(ElT)
9-
end
10-
return data
8+
return randn!(rng, data)
119
end
1210

1311
function generic_zeros(arraytype::Type{<:AbstractArray}, dim::Integer=0)

NDTensors/src/abstractarray/mul.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
function mul!!(CM::AbstractArray, AM::AbstractArray, BM::AbstractArray, α, β)
2+
return mul!!(
3+
leaf_parenttype(CM), CM, leaf_parenttype(AM), AM, leaf_parenttype(BM), BM, α, β
4+
)
5+
return CM
6+
end
7+
8+
function mul!!(
9+
::Type{<:AbstractArray},
10+
CM,
11+
::Type{<:AbstractArray},
12+
AM,
13+
::Type{<:AbstractArray},
14+
BM,
15+
α,
16+
β,
17+
)
18+
mul!(CM, AM, BM, α, β)
19+
return CM
20+
end
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
## NOTICE!!: Here we are not importing Base.permutedims or Base.permutedims! but
2+
## are writing our own implementation. This allows us to
3+
# NDTensors.permutedims
4+
function permutedims(M::AbstractArray, perm)
5+
return permutedims(leaf_parenttype(M), M, perm)
6+
end
7+
8+
# NDTensors.permutedims
9+
function permutedims(::Type{<:AbstractArray}, M, perm)
10+
return Base.permutedims(M, perm)
11+
end
12+
13+
# NDTensors.permutedims!
14+
function permutedims!(Mdest::AbstractArray, M::AbstractArray, perm)
15+
return permutedims!(leaf_parenttype(Mdest), Mdest, leaf_parenttype(M), M, perm)
16+
end
17+
18+
# NDTensors.permutedims!
19+
function permutedims!(::Type{<:AbstractArray}, Mdest, ::Type{<:AbstractArray}, M, perm)
20+
return Base.permutedims!(Mdest, M, perm)
21+
end
22+
23+
function permutedims!!(B::AbstractArray, A::AbstractArray, perm, f)
24+
return permutedims!!(leaf_parenttype(B), B, leaf_parenttype(A), A, perm, f)
25+
end
26+
27+
function permutedims!!(
28+
Bleaftype::Type{<:AbstractArray}, B, Aleaftype::Type{<:AbstractArray}, A, perm, f
29+
)
30+
permutedims!(Bleaftype, B, Aleaftype, A, perm, f)
31+
return B
32+
end
33+
34+
function permutedims!(::Type{<:AbstractArray}, B, ::Type{<:AbstractArray}, A, perm, f)
35+
B .= f.(B, Base.permutedims(A, perm))
36+
return B
37+
end

NDTensors/src/abstractarray/similar.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ function similar(arraytype::Type{<:AbstractArray}, dims::Tuple)
5656
return similartype(arraytype, shape)(undef, NDTensors.to_shape(arraytype, shape))
5757
end
5858

59+
# For when there are CUArray specific issues inline
60+
iscu(A::AbstractArray) = iscu(typeof(A))
61+
function iscu(A::Type{<:AbstractArray})
62+
return (leaf_parenttype(A) == A ? false : iscu(leaf_parenttype(A)))
63+
end
5964
# This function actually allocates the data.
6065
# Catches conversions of dimensions specified by ranges
6166
# dimensions specified by integers with `Base.to_shape`.
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
using LinearAlgebra: BlasFloat
2+
export backend_auto, backend_blas, backend_generic
3+
4+
@eval struct GemmBackend{T}
5+
(f::Type{<:GemmBackend})() = $(Expr(:new, :f))
6+
end
7+
GemmBackend(s) = GemmBackend{Symbol(s)}()
8+
macro GemmBackend_str(s)
9+
return :(GemmBackend{$(Expr(:quote, Symbol(s)))})
10+
end
11+
12+
const gemm_backend = Ref(:Auto)
13+
function backend_auto()
14+
return gemm_backend[] = :Auto
15+
end
16+
function backend_blas()
17+
return gemm_backend[] = :BLAS
18+
end
19+
function backend_generic()
20+
return gemm_backend[] = :Generic
21+
end
22+
23+
@inline function auto_select_backend(
24+
::Type{<:StridedVecOrMat{<:BlasFloat}},
25+
::Type{<:StridedVecOrMat{<:BlasFloat}},
26+
::Type{<:StridedVecOrMat{<:BlasFloat}},
27+
)
28+
return GemmBackend(:BLAS)
29+
end
30+
31+
@inline function auto_select_backend(
32+
::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}, ::Type{<:AbstractVecOrMat}
33+
)
34+
return GemmBackend(:Generic)
35+
end
36+
37+
function _gemm!(
38+
tA, tB, alpha, A::TA, B::TB, beta, C::TC
39+
) where {TA<:AbstractVecOrMat,TB<:AbstractVecOrMat,TC<:AbstractVecOrMat}
40+
if gemm_backend[] == :Auto
41+
_gemm!(auto_select_backend(TA, TB, TC), tA, tB, alpha, A, B, beta, C)
42+
else
43+
_gemm!(GemmBackend(gemm_backend[]), tA, tB, alpha, A, B, beta, C)
44+
end
45+
end
46+
47+
# BLAS matmul
48+
function _gemm!(
49+
::GemmBackend{:BLAS},
50+
tA,
51+
tB,
52+
alpha,
53+
A::AbstractVecOrMat,
54+
B::AbstractVecOrMat,
55+
beta,
56+
C::AbstractVecOrMat,
57+
)
58+
#@timeit_debug timer "BLAS.gemm!" begin
59+
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
60+
#end # @timeit
61+
end
62+
63+
# generic matmul
64+
function _gemm!(
65+
::GemmBackend{:Generic},
66+
tA,
67+
tB,
68+
alpha::AT,
69+
A::AbstractVecOrMat,
70+
B::AbstractVecOrMat,
71+
beta::BT,
72+
C::AbstractVecOrMat,
73+
) where {AT,BT}
74+
mul!(C, tA == 'T' ? transpose(A) : A, tB == 'T' ? transpose(B) : B, alpha, beta)
75+
return C
76+
end
77+
78+
# Non-trivial permutation
79+
function _contract_scalar_perm!(
80+
Rᵃ::AbstractArray{ElR}, Tᵃ::AbstractArray, perm, α, β=zero(ElR)
81+
) where {ElR}
82+
if iszero(β)
83+
if iszero(α)
84+
fill!(Rᵃ, 0)
85+
else
86+
Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> α * t)
87+
end
88+
elseif isone(β)
89+
if iszero(α)
90+
# Rᵃ .= Rᵃ
91+
# No-op
92+
else
93+
Rᵃ = permutedims!!(Rᵃ, Tᵃ, perm, (r, t) -> r + α * t)
94+
end
95+
else
96+
if iszero(α)
97+
# Rᵃ .= β .* Rᵃ
98+
LinearAlgebra.scal!(length(Rᵃ), β, Rᵃ, 1)
99+
else
100+
Rᵃ .= α .* permutedims(Tᵃ, perm) .+ β .* Rᵃ
101+
end
102+
end
103+
return Rᵃ
104+
end
105+
106+
function _contract!(
107+
CT::AbstractArray{El,NC},
108+
AT::AbstractArray{El,NA},
109+
BT::AbstractArray{El,NB},
110+
props::ContractionProperties,
111+
α::Number=one(El),
112+
β::Number=zero(El),
113+
) where {El,NC,NA,NB}
114+
tA = 'N'
115+
if props.permuteA
116+
#@timeit_debug timer "_contract!: permutedims A" begin
117+
Ap = permutedims(AT, props.PA)
118+
#end # @timeit
119+
AM = transpose(reshape(Ap, (props.dmid, props.dleft)))
120+
else
121+
#A doesn't have to be permuted
122+
if Atrans(props)
123+
AM = transpose(reshape(AT, (props.dmid, props.dleft)))
124+
else
125+
AM = reshape(AT, (props.dleft, props.dmid))
126+
end
127+
end
128+
129+
tB = 'N'
130+
if props.permuteB
131+
#@timeit_debug timer "_contract!: permutedims B" begin
132+
Bp = permutedims(BT, props.PB)
133+
#end # @timeit
134+
BM = reshape(Bp, (props.dmid, props.dright))
135+
else
136+
if Btrans(props)
137+
BM = transpose(reshape(BT, (props.dright, props.dmid)))
138+
else
139+
BM = reshape(BT, (props.dmid, props.dright))
140+
end
141+
end
142+
143+
# TODO: this logic may be wrong
144+
if props.permuteC
145+
# if we are computing C = α * A B + β * C
146+
# we need to make sure C is permuted to the same
147+
# ordering as A B which is the inverse of props.PC
148+
if β 0
149+
CM = reshape(permutedims(CT, invperm(props.PC)), (props.dleft, props.dright))
150+
else
151+
# Need to copy here since we will be permuting
152+
# into C later
153+
CM = reshape(copy(CT), (props.dleft, props.dright))
154+
end
155+
else
156+
if Ctrans(props)
157+
CM = transpose(reshape(CT, (props.dright, props.dleft)))
158+
else
159+
CM = reshape(CT, (props.dleft, props.dright))
160+
end
161+
end
162+
163+
#tC = similar(CM)
164+
#_gemm!(tA, tB, El(α), AM, BM, El(β), CM)
165+
CM = mul!!(CM, AM, BM, El(α), El(β))
166+
167+
if props.permuteC
168+
Cr = reshape(CM, props.newCrange)
169+
# TODO: use invperm(pC) here?
170+
#@timeit_debug timer "_contract!: permutedims C" begin
171+
CT .= permutedims(Cr, props.PC)
172+
#end # @timeit
173+
end
174+
175+
return CT
176+
end

NDTensors/src/array/mul.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
function mul!!(::Type{<:Array}, CM, ::Type{<:Array}, AM, ::Type{<:Array}, BM, α, β)
2+
@strided CM = mul!(CM, AM, BM, α, β)
3+
return CM
4+
end

NDTensors/src/array/permutedims.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# NDTensors.permutedims
2+
function permutedims(::Type{<:Array}, M, perm)
3+
return @strided Mdest = Base.permutedims(M, perm)
4+
end
5+
6+
# NDTensors.permutedims!
7+
function permutedims!(::Type{<:Array}, Mdest, ::Type{<:Array}, M, perm)
8+
return @strided Mdest .= Base.permutedims(M, perm)
9+
end
10+
11+
function permutedims!(::Type{<:Array}, B, ::Type{<:Array}, A, perm, f)
12+
@strided B .= f.(B, Base.permutedims(A, perm))
13+
return B
14+
end

NDTensors/src/arraytensor/array.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ end
6161
function permutedims!(
6262
output_array::MatrixOrArrayStorage, array::MatrixOrArrayStorage, perm, f::Function
6363
)
64-
@strided output_array .= f.(output_array, permutedims(array, perm))
64+
output_array = permutedims!!(
65+
leaf_parenttype(output_array), output_array, leaf_parenttype(array), array, perm, f
66+
)
6567
return output_array
6668
end

NDTensors/src/dense/dense.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#
22
# Dense storage
33
#
4-
using LinearAlgebra: BlasFloat
54

65
struct Dense{ElT,DataT<:AbstractArray} <: TensorStorage{ElT}
76
data::DataT

NDTensors/src/dense/densetensor.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ end
8080
# Single index
8181
#
8282

83+
@propagate_inbounds function getindex(T::DenseTensor{<:Number})
84+
return (iscu(T) ? NDTensors.cpu(data(T))[] : data(T)[])
85+
end
86+
8387
@propagate_inbounds function getindex(T::DenseTensor{<:Number}, I::Integer...)
8488
Base.@_inline_meta
8589
return getindex(data(T), Base._sub2ind(T, I...))
@@ -195,7 +199,7 @@ function permutedims!(
195199
) where {N,StoreT<:StridedArray}
196200
RA = array(R)
197201
TA = array(T)
198-
@strided RA .= permutedims(TA, perm)
202+
RA = permutedims!(RA, TA, perm)
199203
return R
200204
end
201205

@@ -243,8 +247,7 @@ function permutedims!(
243247
end
244248
RA = array(R)
245249
TA = array(T)
246-
@strided RA .= f.(RA, permutedims(TA, perm))
247-
return R
250+
return permutedims!!(RA, TA, perm, f)
248251
end
249252

250253
"""

0 commit comments

Comments
 (0)