Skip to content

Commit 938a168

Browse files
authored
Merge a9c355c into a00ef70
2 parents a00ef70 + a9c355c commit 938a168

File tree

5 files changed

+40
-10
lines changed

5 files changed

+40
-10
lines changed

NDTensors/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ SparseArrays = "1.6"
6464
SplitApplyCombine = "1.2.2"
6565
StaticArrays = "0.12, 1.0"
6666
Strided = "2"
67-
StridedViews = "0.2"
67+
StridedViews = "0.2.2"
6868
TimerOutputs = "0.5.5"
6969
TupleTools = "1.2.0"
7070
VectorInterface = "0.4.2"

NDTensors/src/array/permutedims.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@ function permutedims(E::Exposed{<:Array}, perm)
1010
end
1111

1212
function permutedims!(Edest::Exposed{<:Array}, Esrc::Exposed{<:Array}, perm)
13-
@strided unexpose(Edest) .= permutedims(Esrc, perm)
14-
return unexpose(Edest)
13+
a_dest = unexpose(Edest)
14+
a_src = unexpose(Esrc)
15+
@strided a_dest .= permutedims(a_src, perm)
16+
return a_dest
1517
end
1618

1719
function permutedims!(Edest::Exposed{<:Array}, Esrc::Exposed{<:Array}, perm, f)
18-
@strided unexpose(Edest) .= f.(unexpose(Edest), permutedims(Esrc, perm))
19-
return unexpose(Edest)
20+
a_dest = unexpose(Edest)
21+
a_src = unexpose(Esrc)
22+
@strided a_dest .= f.(a_dest, permutedims(a_src, perm))
23+
return a_dest
2024
end

NDTensors/src/dense/densetensor.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,13 +193,20 @@ end
193193
# Maybe allocate output data.
194194
# TODO: Remove this in favor of `map!`
195195
# applied to `PermutedDimsArray`.
196-
function permutedims!!(R::DenseTensor, T::DenseTensor, perm, f::Function=(r, t) -> t)
196+
function permutedims!!(R::DenseTensor, T::DenseTensor, perm, f::Function)
197197
Base.checkdims_perm(R, T, perm)
198198
RR = convert(promote_type(typeof(R), typeof(T)), R)
199199
permutedims!(RR, T, perm, f)
200200
return RR
201201
end
202202

203+
function permutedims!!(R::DenseTensor, T::DenseTensor, perm)
204+
Base.checkdims_perm(R, T, perm)
205+
RR = convert(promote_type(typeof(R), typeof(T)), R)
206+
permutedims!(RR, T, perm)
207+
return RR
208+
end
209+
203210
# TODO: call permutedims!(R,T,perm,(r,t)->t)?
204211
function permutedims!(
205212
R::DenseTensor{<:Number,N,StoreT}, T::DenseTensor{<:Number,N,StoreT}, perm::NTuple{N,Int}
@@ -216,7 +223,7 @@ function permutedims!(
216223
) where {N}
217224
RA = array(R)
218225
TA = array(T)
219-
RA .= permutedims(expose(TA), perm)
226+
permutedims!(expose(RA), expose(TA), perm)
220227
return R
221228
end
222229

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
[deps]
2+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
3+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
4+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

NDTensors/src/tensoroperations/generic_tensor_operations.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,31 @@ end
99
# and return the result of the permutation.
1010
# Similar to `BangBang.jl` notation:
1111
# https://juliafolds.github.io/BangBang.jl/stable/.
12-
function permutedims!!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)
12+
function permutedims!!(output_tensor::Tensor, tensor::Tensor, perm, f::Function)
1313
Base.checkdims_perm(output_tensor, tensor, perm)
1414
permutedims!(output_tensor, tensor, perm, f)
1515
return output_tensor
1616
end
1717

18-
function permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)
18+
# Equivalent to `permutedims!!(output_tensor, tensor, perm, (r, t) -> t)`
19+
function permutedims!!(output_tensor::Tensor, tensor::Tensor, perm)
20+
Base.checkdims_perm(output_tensor, tensor, perm)
21+
permutedims!(output_tensor, tensor, perm)
22+
return output_tensor
23+
end
24+
25+
function permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function)
26+
Base.checkdims_perm(output_tensor, tensor, perm)
27+
error(
28+
"`permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, `perm = $perm`, and `f = $f`.",
29+
)
30+
return output_tensor
31+
end
32+
33+
function permutedims!(output_tensor::Tensor, tensor::Tensor, perm)
1934
Base.checkdims_perm(output_tensor, tensor, perm)
2035
error(
21-
"`permutedims!(output_tensor::Tensor, tensor::Tensor, perm, f::Function=(r, t) -> t)` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, `perm = $perm`, and `f = $f`.",
36+
"`permutedims!(output_tensor::Tensor, tensor::Tensor, perm` not implemented for `typeof(output_tensor) = $(typeof(output_tensor))`, `typeof(tensor) = $(typeof(tensor))`, and `perm = $perm`.",
2237
)
2338
return output_tensor
2439
end

0 commit comments

Comments
 (0)