
Commit 7a6e342

[GradedAxes] [BlockSparseArrays] Upgrade to BlockArrays.jl v1 (#1495)
* [GradedAxes] [BlockSparseArrays] Upgrade to BlockArrays.jl v1
* [NDTensors] Bump to v0.3.28
1 parent: 3e1305d · commit: 7a6e342

35 files changed: +279, -1263 lines

NDTensors/Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "NDTensors"
 uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
 authors = ["Matthew Fishman <[email protected]>"]
-version = "0.3.27"
+version = "0.3.28"
 
 [deps]
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -56,7 +56,7 @@ AMDGPU = "0.9"
 Accessors = "0.1.33"
 Adapt = "3.7, 4"
 ArrayLayouts = "1.4"
-BlockArrays = "0.16"
+BlockArrays = "1"
 CUDA = "5"
 Compat = "4.9"
 cuTENSOR = "2"
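Note on the compat bump: under semver, `BlockArrays = "1"` allows any BlockArrays.jl 1.x release, while `"0.16"` pinned to the 0.16 series. A quick, generic way to confirm what an environment resolves to after this change (a Pkg sketch, not part of the commit):

    using Pkg
    Pkg.update("BlockArrays")   # re-resolve within the new [compat] bounds
    Pkg.status("BlockArrays")   # should now report a 1.x version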

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl

Lines changed: 52 additions & 4 deletions
@@ -1,14 +1,22 @@
 module BlockSparseArraysGradedAxesExt
-using BlockArrays: AbstractBlockVector, Block, BlockedUnitRange, blocks
+using BlockArrays:
+  AbstractBlockVector,
+  AbstractBlockedUnitRange,
+  Block,
+  BlockIndexRange,
+  blockedrange,
+  blocks
 using ..BlockSparseArrays:
   BlockSparseArrays,
   AbstractBlockSparseArray,
   AbstractBlockSparseMatrix,
   BlockSparseArray,
   BlockSparseMatrix,
+  BlockSparseVector,
   block_merge
 using ...GradedAxes:
-  GradedUnitRange,
+  GradedAxes,
+  AbstractGradedUnitRange,
   OneToOne,
   blockmergesortperm,
   blocksortperm,
@@ -23,11 +31,13 @@ using ...TensorAlgebra:
 # TODO: Make a `ReduceWhile` library.
 include("reducewhile.jl")
 
-TensorAlgebra.FusionStyle(::GradedUnitRange) = SectorFusion()
+TensorAlgebra.FusionStyle(::AbstractGradedUnitRange) = SectorFusion()
 
 # TODO: Need to implement this! Will require implementing
 # `block_merge(a::AbstractUnitRange, blockmerger::BlockedUnitRange)`.
-function BlockSparseArrays.block_merge(a::GradedUnitRange, blockmerger::BlockedUnitRange)
+function BlockSparseArrays.block_merge(
+  a::AbstractGradedUnitRange, blockmerger::AbstractBlockedUnitRange
+)
   return a
 end
 
@@ -75,6 +85,44 @@ function Base.axes(a::Adjoint{<:Any,<:AbstractBlockSparseMatrix})
   return dual.(reverse(axes(a')))
 end
 
+# TODO: Delete this definition in favor of the one in
+# GradedAxes once https://github.com/JuliaArrays/BlockArrays.jl/pull/405 is merged.
+# TODO: Make a special definition for `BlockedVector{<:Block{1}}` in order
+# to merge blocks.
+function GradedAxes.blockedunitrange_getindices(
+  a::AbstractBlockedUnitRange, indices::AbstractVector{<:Union{Block{1},BlockIndexRange{1}}}
+)
+  # Without converting `indices` to `Vector`,
+  # mapping `indices` outputs a `BlockVector`
+  # which is harder to reason about.
+  blocks = map(index -> a[index], Vector(indices))
+  # We pass `length.(blocks)` to `mortar` in order
+  # to pass block labels to the axes of the output,
+  # if they exist. This makes it so that
+  # `only(axes(a[indices])) isa GradedUnitRange`
+  # if `a isa GradedUnitRange`, for example.
+  # TODO: Remove `unlabel` once `BlockArray` axes
+  # type is generalized in BlockArrays.jl.
+  # TODO: Support using `BlockSparseVector`, need
+  # to make more `BlockSparseArray` constructors.
+  return BlockSparseArray(blocks, (blockedrange(length.(blocks)),))
+end
+
+# This definition is only needed since calls like
+# `a[[Block(1), Block(2)]]` where `a isa AbstractGradedUnitRange`
+# returns a `BlockSparseVector` instead of a `BlockVector`
+# due to limitations in the `BlockArray` type not allowing
+# axes with non-Int element types.
+# TODO: Remove this once that issue is fixed,
+# see https://github.com/JuliaArrays/BlockArrays.jl/pull/405.
+using BlockArrays: BlockRange
+using NDTensors.LabelledNumbers: label
+function GradedAxes.blocklabels(a::BlockSparseVector)
+  return map(BlockRange(a)) do block
+    return label(blocks(a)[Int(block)])
+  end
+end
+
 # This is a temporary fix for `show` being broken for BlockSparseArrays
 # with mixed dual and non-dual axes. This shouldn't be needed once
 # GradedAxes is rewritten using BlockArrays v1.
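The new `GradedAxes.blockedunitrange_getindices` method above extends block-wise indexing of blocked ranges to vectors of blocks, preserving block structure (and sector labels when the range is graded). A minimal sketch of the plain BlockArrays v1 behavior it builds on, using only public BlockArrays calls:

    using BlockArrays: Block, blockedrange, blocklengths
    r = blockedrange([2, 3])  # a blocked range with blocks 1:2 and 3:5
    r[Block(1)]               # 1:2
    r[Block(2)]               # 3:5
    blocklengths(r)           # [2, 3]; these lengths are what the method above
                              # feeds to `blockedrange` to build the output axis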

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysGradedAxesExt/test/runtests.jl

Lines changed: 24 additions & 5 deletions
@@ -1,10 +1,10 @@
 @eval module $(gensym())
 using Compat: Returns
 using Test: @test, @testset, @test_broken
-using BlockArrays: Block, blockedrange, blocksize
+using BlockArrays: Block, BlockedOneTo, blockedrange, blocklengths, blocksize
 using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
 using NDTensors.GradedAxes:
-  GradedAxes, GradedUnitRange, UnitRangeDual, blocklabels, dual, gradedrange
+  GradedAxes, GradedOneTo, UnitRangeDual, blocklabels, dual, gradedrange
 using NDTensors.LabelledNumbers: label
 using NDTensors.SparseArrayInterface: nstored
 using NDTensors.TensorAlgebra: fusedims, splitdims
@@ -35,15 +35,34 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
   for b in (a + a, 2 * a)
     @test size(b) == (4, 4, 4, 4)
     @test blocksize(b) == (2, 2, 2, 2)
+    @test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
     @test nstored(b) == 32
     @test block_nstored(b) == 2
     # TODO: Have to investigate why this fails
     # on Julia v1.6, or drop support for v1.6.
     for i in 1:ndims(a)
-      @test axes(b, i) isa GradedUnitRange
+      @test axes(b, i) isa GradedOneTo
     end
     @test label(axes(b, 1)[Block(1)]) == U1(0)
     @test label(axes(b, 1)[Block(2)]) == U1(1)
+    @test Array(b) isa Array{elt}
+    @test Array(b) == b
+    @test 2 * Array(a) == b
+  end
+
+  # Test mixing graded axes and dense axes
+  # in addition/broadcasting.
+  for b in (a + Array(a), Array(a) + a)
+    @test size(b) == (4, 4, 4, 4)
+    @test blocksize(b) == (2, 2, 2, 2)
+    @test blocklengths.(axes(b)) == ([2, 2], [2, 2], [2, 2], [2, 2])
+    # TODO: Fix this for `BlockedArray`.
+    @test_broken nstored(b) == 256
+    # TODO: Fix this for `BlockedArray`.
+    @test_broken block_nstored(b) == 16
+    for i in 1:ndims(a)
+      @test axes(b, i) isa BlockedOneTo{Int}
+    end
     @test Array(a) isa Array{elt}
     @test Array(a) == a
     @test 2 * Array(a) == b
@@ -55,7 +74,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
     @test nstored(b) == 2
     @test block_nstored(b) == 2
     for i in 1:ndims(a)
-      @test axes(b, i) isa GradedUnitRange
+      @test axes(b, i) isa GradedOneTo
     end
     @test label(axes(b, 1)[Block(1)]) == U1(0)
     @test label(axes(b, 1)[Block(2)]) == U1(1)
@@ -72,7 +91,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
     # TODO: Once block merging is implemented, this should
     # be the real test.
     for ax in axes(m)
-      @test ax isa GradedUnitRange
+      @test ax isa GradedOneTo
       # TODO: Current `fusedims` doesn't merge
       # common sectors, need to fix.
       @test_broken blocklabels(ax) == [U1(0), U1(1), U1(2)]
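For context, the graded axes exercised by these tests are built with `gradedrange`, which attaches a symmetry label to each block. A minimal sketch (the `U1` label type here is illustrative; the test file defines its own):

    using BlockArrays: Block
    using NDTensors.GradedAxes: blocklabels, gradedrange
    using NDTensors.LabelledNumbers: label

    struct U1
      n::Int
    end

    r = gradedrange([U1(0) => 2, U1(1) => 2])  # two length-2 blocks labelled by U1 charges
    blocklabels(r)      # [U1(0), U1(1)]
    label(r[Block(2)])  # U1(1), the same pattern the tests check via `axes(b, 1)`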

NDTensors/src/lib/BlockSparseArrays/ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 5 additions & 3 deletions
@@ -1,12 +1,14 @@
 module BlockSparseArraysTensorAlgebraExt
-using BlockArrays: BlockedUnitRange
+using BlockArrays: AbstractBlockedUnitRange
 using ..BlockSparseArrays: AbstractBlockSparseArray, block_reshape
 using ...GradedAxes: tensor_product
 using ...TensorAlgebra: TensorAlgebra, FusionStyle, BlockReshapeFusion
 
-TensorAlgebra.:⊗(a1::BlockedUnitRange, a2::BlockedUnitRange) = tensor_product(a1, a2)
+function TensorAlgebra.:⊗(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
+  return tensor_product(a1, a2)
+end
 
-TensorAlgebra.FusionStyle(::BlockedUnitRange) = BlockReshapeFusion()
+TensorAlgebra.FusionStyle(::AbstractBlockedUnitRange) = BlockReshapeFusion()
 
 function TensorAlgebra.fusedims(
   ::BlockReshapeFusion, a::AbstractArray, axes::AbstractUnitRange...
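`FusionStyle` is a trait that selects how `fusedims` combines axes: generic blocked axes use `BlockReshapeFusion` (a block-wise reshape), while graded axes use `SectorFusion` (see the GradedAxes extension above). An illustrative sketch of this trait pattern with stand-in names, not the package's API:

    using BlockArrays: AbstractBlockedUnitRange

    # Stand-in trait types; the real ones live in NDTensors.TensorAlgebra.
    abstract type DemoFusionStyle end
    struct DemoBlockReshapeFusion <: DemoFusionStyle end
    struct DemoSectorFusion <: DemoFusionStyle end

    # Dispatching on the axis type picks the fusion algorithm;
    # a graded axis type would return `DemoSectorFusion()` instead.
    demo_fusion_style(::AbstractBlockedUnitRange) = DemoBlockReshapeFusion()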

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/broadcast.jl

Lines changed: 5 additions & 5 deletions
@@ -1,4 +1,4 @@
-using BlockArrays: BlockedUnitRange, BlockSlice
+using BlockArrays: AbstractBlockedUnitRange, BlockSlice
 using Base.Broadcast: Broadcast
 
 function Broadcast.BroadcastStyle(arraytype::Type{<:BlockSparseArrayLike})
@@ -12,7 +12,7 @@ function Broadcast.BroadcastStyle(
       <:Any,
       <:Any,
       <:AbstractBlockSparseArray,
-      <:Tuple{BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
+      <:Tuple{BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
     },
   },
 )
@@ -25,8 +25,8 @@ function Broadcast.BroadcastStyle(
       <:Any,
       <:AbstractBlockSparseArray,
       <:Tuple{
-        BlockSlice{<:Any,<:BlockedUnitRange},
-        BlockSlice{<:Any,<:BlockedUnitRange},
+        BlockSlice{<:Any,<:AbstractBlockedUnitRange},
+        BlockSlice{<:Any,<:AbstractBlockedUnitRange},
         Vararg{Any},
       },
     },
@@ -40,7 +40,7 @@ function Broadcast.BroadcastStyle(
       <:Any,
       <:Any,
       <:AbstractBlockSparseArray,
-      <:Tuple{Any,BlockSlice{<:Any,<:BlockedUnitRange},Vararg{Any}},
+      <:Tuple{Any,BlockSlice{<:Any,<:AbstractBlockedUnitRange},Vararg{Any}},
     },
   },
 )
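These `BroadcastStyle` methods match `SubArray`s whose index tuples contain `BlockSlice`s over blocked ranges, which is what block-wise views of a blocked array carry. A small illustration using only BlockArrays v1 (hypothetical usage, not this package's code):

    using BlockArrays: Block, BlockSlice, BlockedArray
    a = BlockedArray(randn(5, 5), [2, 3], [2, 3])
    # `Block` indices lower to `BlockSlice`s, so views created with them
    # have `BlockSlice` index types, which the methods above dispatch on.
    inds = Base.to_indices(a, (Block(1), Block(2)))
    map(i -> i isa BlockSlice, inds)  # (true, true)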

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 26 additions & 5 deletions
@@ -1,6 +1,12 @@
 using Adapt: Adapt, WrappedArray
 using BlockArrays:
-  BlockArrays, BlockedUnitRange, BlockIndexRange, BlockRange, blockedrange, mortar, unblock
+  BlockArrays,
+  AbstractBlockedUnitRange,
+  BlockIndexRange,
+  BlockRange,
+  blockedrange,
+  mortar,
+  unblock
 using SplitApplyCombine: groupcount
 
 const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
@@ -208,7 +214,7 @@ end
 # Fixes ambiguity error with `BlockArrays.jl`.
 function Base.similar(
   arraytype::Type{<:BlockSparseArrayLike},
-  axes::Tuple{BlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
+  axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
 )
   return similar(arraytype, eltype(arraytype), axes)
 end
@@ -217,14 +223,27 @@ end
 # Fixes ambiguity error with `BlockArrays.jl`.
 function Base.similar(
   arraytype::Type{<:BlockSparseArrayLike},
-  axes::Tuple{AbstractUnitRange{Int},BlockedUnitRange,Vararg{AbstractUnitRange{Int}}},
+  axes::Tuple{
+    AbstractBlockedUnitRange,AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}
+  },
+)
+  return similar(arraytype, eltype(arraytype), axes)
+end
+
+# Needed by `BlockArrays` matrix multiplication interface
+# Fixes ambiguity error with `BlockArrays.jl`.
+function Base.similar(
+  arraytype::Type{<:BlockSparseArrayLike},
+  axes::Tuple{
+    AbstractUnitRange{Int},AbstractBlockedUnitRange,Vararg{AbstractUnitRange{Int}}
+  },
 )
   return similar(arraytype, eltype(arraytype), axes)
 end
 
 # Needed for disambiguation
 function Base.similar(
-  arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{BlockedUnitRange}}
+  arraytype::Type{<:BlockSparseArrayLike}, axes::Tuple{Vararg{AbstractBlockedUnitRange}}
 )
   return similar(arraytype, eltype(arraytype), axes)
 end
@@ -251,7 +270,9 @@ end
 # TODO: Define a `blocksparse_similar` function.
 # Fixes ambiguity error with `BlockArrays`.
 function Base.similar(
-  a::BlockSparseArrayLike, elt::Type, axes::Tuple{BlockedUnitRange,Vararg{BlockedUnitRange}}
+  a::BlockSparseArrayLike,
+  elt::Type,
+  axes::Tuple{AbstractBlockedUnitRange,Vararg{AbstractBlockedUnitRange}},
 )
   # TODO: Make generic for GPU, maybe using `blocktype`.
   # TODO: For non-block axes this should output `Array`.
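The ambiguities these `similar` methods resolve arise because BlockArrays.jl itself defines `similar` on tuples of blocked axes, so allocation from blocked axes works even for ordinary dense array types. A small illustration of that baseline behavior (return types may differ from what BlockSparseArrays produces):

    using BlockArrays: blockedrange
    ax = (blockedrange([2, 3]), blockedrange([1, 4]))
    b = similar(Array{Float64}, ax)  # a blocked dense array with these axes
    size(b)                          # (5, 5)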

NDTensors/src/lib/BlockSparseArrays/src/backup/BlockSparseArrays.jl

Lines changed: 0 additions & 45 deletions
This file was deleted.

NDTensors/src/lib/BlockSparseArrays/src/backup/LinearAlgebraExt/LinearAlgebraExt.jl

Lines changed: 0 additions & 16 deletions
This file was deleted.

NDTensors/src/lib/BlockSparseArrays/src/backup/LinearAlgebraExt/eigen.jl

Lines changed: 0 additions & 19 deletions
This file was deleted.

NDTensors/src/lib/BlockSparseArrays/src/backup/LinearAlgebraExt/hermitian.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.
