Skip to content

Commit 3f1afb8

Browse files
authored
[NDTensors] Start TensorAlgebra module, new TTGT implementation (#1265)
1 parent 408516d commit 3f1afb8

24 files changed

+393
-2
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2

NDTensors/src/NDTensors.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
module NDTensors
2+
# TODO: List types, macros, and functions being used.
23
using Adapt
34
using Base.Threads
45
using Compat
@@ -19,9 +20,12 @@ using TimerOutputs
1920
using TupleTools
2021

2122
# TODO: Define an `AlgorithmSelection` module
23+
# TODO: List types, macros, and functions being used.
2224
include("algorithm.jl")
2325
include("SetParameters/src/SetParameters.jl")
2426
using .SetParameters
27+
include("TensorAlgebra/src/TensorAlgebra.jl")
28+
using .TensorAlgebra: TensorAlgebra
2529
include("DiagonalArrays/src/DiagonalArrays.jl")
2630
using .DiagonalArrays
2731
include("BlockSparseArrays/src/BlockSparseArrays.jl")
@@ -76,8 +80,8 @@ include("dims.jl")
7680
include("tensor/set_types.jl")
7781
include("tensor/similar.jl")
7882
include("adapt.jl")
79-
include("tensoralgebra/generic_tensor_operations.jl")
80-
include("tensoralgebra/contraction_logic.jl")
83+
include("tensoroperations/generic_tensor_operations.jl")
84+
include("tensoroperations/contraction_logic.jl")
8185
include("abstractarray/tensoralgebra/contract.jl")
8286

8387
#####################################
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module TensorAlgebra
2+
using LinearAlgebra: mul!
3+
using ..NDTensors: Algorithm, @Algorithm_str
4+
5+
include("bipartitionedpermutation.jl")
6+
include("fusedims.jl")
7+
include("contract/contract.jl")
8+
include("contract/output_labels.jl")
9+
include("contract/allocate_output.jl")
10+
include("contract/contract_matricize/contract.jl")
11+
end
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
struct BipartitionedPermutation{P1,P2}
2+
partition1::P1
3+
partition2::P2
4+
end
5+
6+
function Base.getindex(biperm::BipartitionedPermutation, i)
7+
if i == 1
8+
return biperm.partition1
9+
elseif i == 2
10+
return biperm.partition2
11+
end
12+
return error("Only 2 partitions")
13+
end
14+
15+
function flatten(biperm::BipartitionedPermutation)
16+
return (biperm[1]..., biperm[2]...)
17+
end
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
function allocate_output(
2+
::typeof(contract),
3+
alg::Algorithm,
4+
labels_dest,
5+
a1::AbstractArray,
6+
labels1,
7+
a2::AbstractArray,
8+
labels2,
9+
α,
10+
β,
11+
)
12+
axes_dest = output_axes(contract, alg, labels_dest, axes(a1), labels1, axes(a2), labels2)
13+
# TODO: Define `output_type(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`.
14+
# TODO: Define `output_structure(contract, alg, labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β)`.
15+
# TODO: Define `allocate(type, structure)`.
16+
return Array{promote_type(eltype(a1), eltype(a2))}(undef, length.(axes_dest))
17+
end
18+
19+
# TODO: Generalize to `output_structure`.
20+
function output_axes(
21+
f::typeof(contract), alg::Algorithm, labels_dest, axes1, labels1, axes2, labels2
22+
)
23+
biperm_dest, biperm1, biperm2 = bipartitioned_permutations(
24+
f, labels_dest, labels1, labels2
25+
)
26+
return output_axes(f, alg, biperm_dest, axes1, biperm1, axes2, biperm2)
27+
end
28+
29+
# TODO: Generalize to `output_structure`.
30+
function output_axes(
31+
f::typeof(contract),
32+
alg::Algorithm,
33+
biperm_dest::BipartitionedPermutation,
34+
axes1,
35+
biperm1::BipartitionedPermutation,
36+
axes2,
37+
biperm2::BipartitionedPermutation,
38+
)
39+
perm_dest = flatten(biperm_dest)
40+
nuncontracted1 = length(biperm1[1])
41+
axes_dest = map(perm_dest) do i
42+
return if i <= nuncontracted1
43+
axes1[biperm1[1][i]]
44+
else
45+
axes2[biperm2[2][i - nuncontracted1]]
46+
end
47+
end
48+
return axes_dest
49+
end
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
struct BipartitionedPermutation{P1,P2}
2+
partition1::P1
3+
partition2::P2
4+
end
5+
6+
function Base.getindex(biperm::BipartitionedPermutation, i)
7+
if i == 1
8+
return biperm.partition1
9+
elseif i == 2
10+
return biperm.partition2
11+
end
12+
return error("Only 2 partitions")
13+
end
14+
15+
function flatten(biperm::BipartitionedPermutation)
16+
return (biperm[1]..., biperm[2]...)
17+
end
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# TODO: Add `contract!!` definitions as pass-throughs to `contract!`.
2+
3+
default_contract_alg() = Algorithm"matricize"()
4+
5+
function contract(a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...)
6+
return contract(a1, labels1, a2, labels2, true, false; kwargs...)
7+
end
8+
9+
function contract(
10+
a1::AbstractArray,
11+
labels1,
12+
a2::AbstractArray,
13+
labels2,
14+
α,
15+
β;
16+
alg=default_contract_alg(),
17+
kwargs...,
18+
)
19+
return contract(Algorithm(alg), a1, labels1, a2, labels2, α, β; kwargs...)
20+
end
21+
22+
function contract(
23+
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
24+
)
25+
return contract(alg, a1, labels1, a2, labels2, true, false; kwargs...)
26+
end
27+
28+
function contract(
29+
alg::Algorithm, a1::AbstractArray, labels1, a2::AbstractArray, labels2, α, β; kwargs...
30+
)
31+
labels_dest = output_labels(contract, alg, a1, labels1, a2, labels2, α, β; kwargs...)
32+
return contract(alg, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...), labels_dest
33+
end
34+
35+
function contract(
36+
labels_dest, a1::AbstractArray, labels1, a2::AbstractArray, labels2; kwargs...
37+
)
38+
return contract(
39+
labels_dest,
40+
a1::AbstractArray,
41+
labels1,
42+
a2::AbstractArray,
43+
labels2,
44+
true,
45+
false;
46+
kwargs...,
47+
)
48+
end
49+
50+
function contract(
51+
labels_dest,
52+
a1::AbstractArray,
53+
labels1,
54+
a2::AbstractArray,
55+
labels2,
56+
α,
57+
β;
58+
alg=default_contract_alg(),
59+
kwargs...,
60+
)
61+
return contract(Algorithm(alg), labels_dest, a1, labels1, a2, labels2, α, β; kwargs...)
62+
end
63+
64+
function contract(
65+
alg::Algorithm,
66+
labels_dest,
67+
a1::AbstractArray,
68+
labels1,
69+
a2::AbstractArray,
70+
labels2;
71+
kwargs...,
72+
)
73+
return contract(alg, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
74+
end
75+
76+
function contract(
77+
alg::Algorithm,
78+
labels_dest,
79+
a1::AbstractArray,
80+
labels1,
81+
a2::AbstractArray,
82+
labels2,
83+
α,
84+
β;
85+
kwargs...,
86+
)
87+
biperm_dest, biperm1, biperm2 = bipartitioned_permutations(
88+
contract, alg, labels_dest, labels1, labels2
89+
)
90+
a_dest = allocate_output(
91+
contract, alg, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...
92+
)
93+
contract!(alg, a_dest, biperm_dest, a1, biperm1, a2, biperm2, α, β; kwargs...)
94+
return a_dest
95+
end
96+
97+
function contract!(
98+
a_dest::AbstractArray,
99+
labels_dest,
100+
a1::AbstractArray,
101+
labels1,
102+
a2::AbstractArray,
103+
labels2;
104+
kwargs...,
105+
)
106+
contract!(a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
107+
return a_dest
108+
end
109+
110+
function contract!(
111+
a_dest::AbstractArray,
112+
labels_dest,
113+
a1::AbstractArray,
114+
labels1,
115+
a2::AbstractArray,
116+
labels2,
117+
α,
118+
β;
119+
alg=default_contract_alg(),
120+
kwargs...,
121+
)
122+
contract!(Algorithm(alg), a_dest, labels_dest, a1, labels1, a2, labels2, α, β; kwargs...)
123+
return a_dest
124+
end
125+
126+
function contract!(
127+
alg::Algorithm,
128+
a_dest::AbstractArray,
129+
labels_dest,
130+
a1::AbstractArray,
131+
labels1,
132+
a2::AbstractArray,
133+
labels2;
134+
kwargs...,
135+
)
136+
contract!(alg, a_dest, labels_dest, a1, labels1, a2, labels2, true, false; kwargs...)
137+
return a_dest
138+
end
139+
140+
function contract!(
141+
alg::Algorithm,
142+
a_dest::AbstractArray,
143+
labels_dest,
144+
a1::AbstractArray,
145+
labels1,
146+
a2::AbstractArray,
147+
labels2,
148+
α,
149+
β;
150+
kwargs...,
151+
)
152+
return error("Not implemented")
153+
end
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
function contract!(
2+
alg::Algorithm"matricize",
3+
a_dest::AbstractArray,
4+
biperm_dest::BipartitionedPermutation,
5+
a1::AbstractArray,
6+
biperm1::BipartitionedPermutation,
7+
a2::AbstractArray,
8+
biperm2::BipartitionedPermutation,
9+
α,
10+
β,
11+
)
12+
a_dest_matricized = matricize(a_dest, biperm_dest)
13+
a1_matricized = matricize(a1, biperm1)
14+
a2_matricized = matricize(a2, biperm2)
15+
mul!(a_dest_matricized, a1_matricized, a2_matricized, α, β)
16+
perm_dest = flatten(biperm_dest)
17+
# TODO: Create a function `unmatricize` or `unfusedims`.
18+
# unmatricize!(a_dest, a_dest_matricized, axes(a_dest), perm_dest)
19+
a_dest_copy = reshape(a_dest_matricized, axes(a_dest))
20+
permutedims!(a_dest, a_dest_copy, perm_dest)
21+
return a_dest
22+
end
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
function output_labels(
2+
f::typeof(contract),
3+
alg::Algorithm,
4+
a1::AbstractArray,
5+
labels1,
6+
a2::AbstractArray,
7+
labels2,
8+
α,
9+
β,
10+
)
11+
return output_labels(f, alg, labels1, labels2)
12+
end
13+
14+
function output_labels(f::typeof(contract), alg::Algorithm, labels1, labels2)
15+
return output_labels(f, labels1, labels2)
16+
end
17+
18+
function output_labels(::typeof(contract), labels1, labels2)
19+
return symdiff(labels1, labels2)
20+
end
21+
22+
function bipartitioned_permutations(
23+
f::typeof(contract), alg::Algorithm, labels_dest, labels1, labels2
24+
)
25+
return bipartitioned_permutations(f, labels_dest, labels1, labels2)
26+
end
27+
28+
function bipartitioned_permutations(::typeof(contract), labels_dest, labels1, labels2)
29+
labels12 = (labels1..., labels2...)
30+
if isodd(length(labels12) - length(labels_dest))
31+
error("Can't contract $labels1 and $labels2 to $labels_dest")
32+
end
33+
labels_contracted = unique(setdiff(labels12, labels_dest))
34+
labels1_uncontracted = setdiff(labels1, labels_contracted)
35+
labels2_uncontracted = setdiff(labels2, labels_contracted)
36+
# Positions of labels.
37+
pos_dest_1 = map(l -> findfirst(isequal(l), labels_dest), labels1_uncontracted)
38+
pos_dest_2 = map(l -> findfirst(isequal(l), labels_dest), labels2_uncontracted)
39+
pos1_contracted = map(l -> findfirst(isequal(l), labels1), labels_contracted)
40+
pos2_contracted = map(l -> findfirst(isequal(l), labels2), labels_contracted)
41+
pos1_uncontracted = map(l -> findfirst(isequal(l), labels1), labels1_uncontracted)
42+
pos2_uncontracted = map(l -> findfirst(isequal(l), labels2), labels2_uncontracted)
43+
# Bipartitioned permutations.
44+
biperm_dest = BipartitionedPermutation(pos_dest_1, pos_dest_2)
45+
biperm1 = BipartitionedPermutation(pos1_uncontracted, pos1_contracted)
46+
biperm2 = BipartitionedPermutation(pos2_contracted, pos2_uncontracted)
47+
return biperm_dest, biperm1, biperm2
48+
end
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
fuse(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) * length(a2))
2+
fuse(a...) = foldl(fuse, a)
3+
4+
matricize(a::AbstractArray, biperm) = matricize(a, BipartitionedPermutation(biperm...))
5+
6+
function matricize(a::AbstractArray, biperm::BipartitionedPermutation)
7+
# Permute and fuse the axes
8+
axes_src = axes(a)
9+
axes_codomain = map(i -> axes_src[i], biperm[1])
10+
axes_domain = map(i -> axes_src[i], biperm[2])
11+
axis_codomain_fused = fuse(axes_codomain...)
12+
axis_domain_fused = fuse(axes_domain...)
13+
# Permute the array
14+
perm = flatten(biperm)
15+
a_permuted = permutedims(a, perm)
16+
return reshape(a_permuted, (axis_codomain_fused, axis_domain_fused))
17+
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
3+
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"

0 commit comments

Comments
 (0)