Skip to content

Commit a8fd705

Browse files
authored
Merge f9d2e7b into 3f1afb8
2 parents 3f1afb8 + f9d2e7b commit a8fd705

File tree

7 files changed

+94
-21
lines changed

7 files changed

+94
-21
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module LinearAlgebraExtensions
2+
using LinearAlgebra: LinearAlgebra, qr
3+
using ..TensorAlgebra:
4+
TensorAlgebra,
5+
BipartitionedPermutation,
6+
bipartition,
7+
bipartitioned_permutations,
8+
matricize,
9+
unmatricize
10+
11+
include("qr.jl")
12+
end
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
function LinearAlgebra.qr(a::AbstractArray, labels_a, labels_q, labels_r)
2+
return qr(a, bipartitioned_permutations(qr, labels_a, labels_q, labels_r)...)
3+
end
4+
5+
function LinearAlgebra.qr(a::AbstractArray, biperm::BipartitionedPermutation)
6+
# TODO: Use a thin QR, define `qr_thin`.
7+
a_matricized = matricize(a, biperm)
8+
q_matricized, r_matricized = qr(a_matricized)
9+
q_matricized_thin = typeof(a_matricized)(q_matricized)
10+
axes_codomain, axes_domain = bipartition(axes(a), biperm)
11+
q = unmatricize(q_matricized_thin, axes_codomain, (axes(q_matricized_thin, 2),))
12+
r = unmatricize(r_matricized, (axes(r_matricized, 1),), axes_domain)
13+
return q, r
14+
end
15+
16+
function TensorAlgebra.bipartitioned_permutations(qr, labels_a, labels_q, labels_r)
17+
# TODO: Use something like `findall`?
18+
pos_q = map(l -> findfirst(isequal(l), labels_a), labels_q)
19+
pos_r = map(l -> findfirst(isequal(l), labels_a), labels_r)
20+
return (BipartitionedPermutation(pos_q, pos_r),)
21+
end

NDTensors/src/TensorAlgebra/src/TensorAlgebra.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ include("contract/contract.jl")
88
include("contract/output_labels.jl")
99
include("contract/allocate_output.jl")
1010
include("contract/contract_matricize/contract.jl")
11+
include("LinearAlgebraExtensions/LinearAlgebraExtensions.jl")
1112
end

NDTensors/src/TensorAlgebra/src/bipartitionedpermutation.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,12 @@ end
1515
function flatten(biperm::BipartitionedPermutation)
1616
return (biperm[1]..., biperm[2]...)
1717
end
18+
19+
# Bipartition a vector according to the
20+
# bipartitioned permutation.
21+
function bipartition(v, biperm::BipartitionedPermutation)
22+
# TODO: Use `TupleTools.getindices`.
23+
v1 = map(i -> v[i], biperm[1])
24+
v2 = map(i -> v[i], biperm[2])
25+
return v1, v2
26+
end

NDTensors/src/TensorAlgebra/src/contract/contract_matricize/contract.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function contract!(
1616
perm_dest = flatten(biperm_dest)
1717
# TODO: Create a function `unmatricize` or `unfusedims`.
1818
# unmatricize!(a_dest, a_dest_matricized, axes(a_dest), perm_dest)
19-
a_dest_copy = reshape(a_dest_matricized, axes(a_dest))
20-
permutedims!(a_dest, a_dest_copy, perm_dest)
19+
a_dest_copy = reshape(a_dest_matricized, map(i -> axes(a_dest, i), perm_dest))
20+
permutedims!(a_dest, a_dest_copy, invperm(perm_dest))
2121
return a_dest
2222
end

NDTensors/src/TensorAlgebra/src/fusedims.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ fuse(a...) = foldl(fuse, a)
33

44
matricize(a::AbstractArray, biperm) = matricize(a, BipartitionedPermutation(biperm...))
55

6+
# TODO: Make this more generic, i.e. for `BlockSparseArray`.
67
function matricize(a::AbstractArray, biperm::BipartitionedPermutation)
78
# Permute and fuse the axes
89
axes_src = axes(a)
@@ -15,3 +16,8 @@ function matricize(a::AbstractArray, biperm::BipartitionedPermutation)
1516
a_permuted = permutedims(a, perm)
1617
return reshape(a_permuted, (axis_codomain_fused, axis_domain_fused))
1718
end
19+
20+
# TODO: Make this more generic, i.e. for `BlockSparseArray`.
21+
function unmatricize(a::AbstractArray, axes_codomain, axes_domain)
22+
return reshape(a, (axes_codomain..., axes_domain...))
23+
end
Lines changed: 43 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,51 @@
11
using Combinatorics: permutations
2+
using LinearAlgebra: qr
23
using NDTensors.TensorAlgebra: TensorAlgebra
34
using TensorOperations: TensorOperations
4-
using Test: @test, @testset
5+
using Test: @test, @test_broken, @testset
56

67
@testset "TensorAlgebra" begin
7-
dims = (2, 3, 4, 5)
8-
labels = (:a, :b, :c, :d)
9-
for (d1s, d2s) in (((1, 2), (2, 3)), ((1, 2, 3), (2, 3, 4)), ((1, 2, 3), (3, 4)))
10-
a1 = randn(map(i -> dims[i], d1s))
11-
labels1 = map(i -> labels[i], d1s)
12-
a2 = randn(map(i -> dims[i], d2s))
13-
labels2 = map(i -> labels[i], d2s)
14-
for perm1 in permutations(1:ndims(a1)), perm2 in permutations(1:ndims(a2))
15-
a1′ = permutedims(a1, perm1)
16-
a2′ = permutedims(a2, perm2)
17-
labels1′ = map(i -> labels1[i], perm1)
18-
labels2′ = map(i -> labels2[i], perm2)
19-
a_dest, labels_dest = TensorAlgebra.contract(a1′, labels1′, a2′, labels2′)
20-
@test labels_dest == symdiff(labels1′, labels2′)
21-
a_dest_tensoroperations = TensorOperations.tensorcontract(
22-
labels_dest, a1′, labels1′, a2′, labels2′
23-
)
24-
@test a_dest a_dest_tensoroperations
8+
elts = (Float32, ComplexF32, Float64, ComplexF64)
9+
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
10+
dims = (2, 3, 4, 5)
11+
labels = (:a, :b, :c, :d)
12+
for (d1s, d2s) in (((1, 2), (2, 3)), ((1, 2, 3), (2, 3, 4)), ((1, 2, 3), (3, 4)))
13+
a1 = randn(elt1, map(i -> dims[i], d1s))
14+
labels1 = map(i -> labels[i], d1s)
15+
a2 = randn(elt2, map(i -> dims[i], d2s))
16+
labels2 = map(i -> labels[i], d2s)
17+
for perm1 in permutations(1:ndims(a1)), perm2 in permutations(1:ndims(a2))
18+
a1′ = permutedims(a1, perm1)
19+
a2′ = permutedims(a2, perm2)
20+
labels1′ = map(i -> labels1[i], perm1)
21+
labels2′ = map(i -> labels2[i], perm2)
22+
a_dest, labels_dest = TensorAlgebra.contract(a1′, labels1′, a2′, labels2′)
23+
@test labels_dest == symdiff(labels1′, labels2′)
24+
a_dest_tensoroperations = TensorOperations.tensorcontract(
25+
labels_dest, a1′, labels1′, a2′, labels2′
26+
)
27+
@test a_dest a_dest_tensoroperations
28+
end
2529
end
2630
end
31+
@testset "contract broken" begin
32+
a1 = randn(3, 5, 8)
33+
a2 = randn(8, 2, 4)
34+
labels_dest = (:a, :b, :c, :d)
35+
labels1 = (:c, :a, :x)
36+
labels2 = (:x, :d, :b)
37+
@test_broken a′ = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
38+
end
39+
@testset "qr" begin
40+
a = randn(5, 4, 3, 2)
41+
labels_a = (:a, :b, :c, :d)
42+
labels_q = (:b, :a)
43+
labels_r = (:d, :c)
44+
q, r = qr(a, labels_a, labels_q, labels_r)
45+
label_qr = :qr
46+
a′ = TensorAlgebra.contract(
47+
labels_a, q, (labels_q..., label_qr), r, (label_qr, labels_r...)
48+
)
49+
@test a a′
50+
end
2751
end

0 commit comments

Comments
 (0)