
Commit 1864f14

Merge d9b7d97 into ea0f602
2 parents: ea0f602 + d9b7d97

File tree

13 files changed (+412, -41 lines)


NDTensors/Project.toml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ version = "0.2.11"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
 FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
# BlockSparseArrays.jl

A Julia `BlockSparseArray` type based on the `BlockArrays.jl` interface.

It wraps an elementwise `SparseArray` type that uses a dictionary-of-keys
format to store nonzero values, specifically a `Dictionary` from `Dictionaries.jl`.
`BlockArrays` reinterprets the `SparseArray` as a blocked data structure.

````julia
using NDTensors.BlockSparseArrays
using BlockArrays

# Block dimensions
i1 = [2, 3]
i2 = [2, 3]

i_axes = (blockedrange(i1), blockedrange(i2))

function block_size(axes, block)
  return length.(getindex.(axes, Block.(block.n)))
end

# Data
nz_blocks = [Block(1, 1), Block(2, 2)]
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
nz_block_lengths = prod.(nz_block_sizes)

# Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)

# Blocks with contiguous underlying data
# d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths)
# d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)]

B = BlockSparseArray(nz_blocks, d_blocks, i_axes)

# Access a block
B[Block(1, 1)]

# Access an unstored block, returns a zero matrix
B[Block(1, 2)]

# Set a zero block
B[Block(1, 2)] = randn(2, 3)

# Matrix multiplication (not optimized for sparsity yet)
B * B
````

You can generate this README with:

```julia
using Literate
Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
```

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
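As a rough sketch of the dictionary-of-keys storage described above (illustrative only, not part of this commit): the stored blocks live in a dictionary keyed by their block position, so only nonzero blocks occupy memory and every other block is materialized as zeros on demand.

```julia
using Dictionaries: Dictionary

# Dictionary-of-keys block storage: only Block(1, 1) and Block(2, 2) are stored;
# every other block position is treated as structurally zero.
stored_blocks = Dictionary(
  [CartesianIndex(1, 1), CartesianIndex(2, 2)],
  [randn(2, 2), randn(3, 3)],
)
CartesianIndex(1, 1) in keys(stored_blocks)  # true: this block is stored
CartesianIndex(1, 2) in keys(stored_blocks)  # false: structurally zero block
```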
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
# # BlockSparseArrays.jl
#
# A Julia `BlockSparseArray` type based on the `BlockArrays.jl` interface.
#
# It wraps an elementwise `SparseArray` type that uses a dictionary-of-keys
# format to store nonzero values, specifically a `Dictionary` from `Dictionaries.jl`.
# `BlockArrays` reinterprets the `SparseArray` as a blocked data structure.

using NDTensors.BlockSparseArrays
using BlockArrays

## Block dimensions
i1 = [2, 3]
i2 = [2, 3]

i_axes = (blockedrange(i1), blockedrange(i2))

function block_size(axes, block)
  return length.(getindex.(axes, Block.(block.n)))
end

## Data
nz_blocks = [Block(1, 1), Block(2, 2)]
nz_block_sizes = [block_size(i_axes, nz_block) for nz_block in nz_blocks]
nz_block_lengths = prod.(nz_block_sizes)

## Blocks with discontiguous underlying data
d_blocks = randn.(nz_block_sizes)

## Blocks with contiguous underlying data
## d_data = PseudoBlockVector(randn(sum(nz_block_lengths)), nz_block_lengths)
## d_blocks = [reshape(@view(d_data[Block(i)]), block_size(i_axes, nz_blocks[i])) for i in 1:length(nz_blocks)]

B = BlockSparseArray(nz_blocks, d_blocks, i_axes)

## Access a block
B[Block(1, 1)]

## Access an unstored block, returns a zero matrix
B[Block(1, 2)]

## Set a zero block
B[Block(1, 2)] = randn(2, 3)

## Matrix multiplication (not optimized for sparsity yet)
B * B

# You can generate this README with:
# ```julia
# using Literate
# Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
# ```
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
module BlockSparseArrays
using BlockArrays
using Dictionaries

using BlockArrays: block

export BlockSparseArray, SparseArray

include("sparsearray.jl")
include("blocksparsearray.jl")

end
Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
using BlockArrays: block

# Also add a version with contiguous underlying data.
struct BlockSparseArray{
  T,N,Blocks<:SparseArray{<:AbstractArray{T,N},N},Axes<:NTuple{N,AbstractUnitRange{Int}}
} <: AbstractBlockArray{T,N}
  blocks::Blocks
  axes::Axes
end

# The size of a block
function block_size(axes::Tuple, block::Block)
  return length.(getindex.(axes, Block.(block.n)))
end

struct BlockZero{Axes}
  axes::Axes
end

function (f::BlockZero)(T::Type, I::CartesianIndex)
  return fill!(T(undef, block_size(f.axes, Block(Tuple(I)))), false)
end

function BlockSparseArray(
  blocks::AbstractVector{<:Block{N}}, blockdata::AbstractVector, axes::NTuple{N}
) where {N}
  return BlockSparseArray(Dictionary(blocks, blockdata), axes)
end

function BlockSparseArray(
  blockdata::Dictionary{<:Block{N}}, axes::NTuple{N,AbstractUnitRange{Int}}
) where {N}
  blocks = keys(blockdata)
  cartesianblocks = map(block -> CartesianIndex(block.n), blocks)
  cartesiandata = Dictionary(cartesianblocks, blockdata)
  block_storage = SparseArray(cartesiandata, blocklength.(axes), BlockZero(axes))
  return BlockSparseArray(block_storage, axes)
end

function BlockSparseArray(
  blockdata::Dictionary{<:Block{N}}, blockinds::NTuple{N,AbstractVector}
) where {N}
  return BlockSparseArray(blockdata, blockedrange.(blockinds))
end

Base.axes(block_arr::BlockSparseArray) = block_arr.axes

function Base.copy(block_arr::BlockSparseArray)
  return BlockSparseArray(deepcopy(block_arr.blocks), copy.(block_arr.axes))
end

function BlockArrays.viewblock(block_arr::BlockSparseArray, block)
  blks = block.n
  @boundscheck blockcheckbounds(block_arr, blks...)
  ## block_size = length.(getindex.(axes(block_arr), Block.(blks)))
  # TODO: Make this `Zeros`?
  ## zero = zeros(eltype(block_arr), block_size)
  return block_arr.blocks[blks...] # Fails because zero isn't defined
  ## return get_nonzero(block_arr.blocks, blks, zero)
end

function Base.getindex(block_arr::BlockSparseArray{T,N}, bi::BlockIndex{N}) where {T,N}
  @boundscheck blockcheckbounds(block_arr, Block(bi.I))
  bl = view(block_arr, block(bi))
  inds = bi.α
  @boundscheck checkbounds(bl, inds...)
  v = bl[inds...]
  return v
end

function Base.setindex!(
  block_arr::BlockSparseArray{T,N}, v, i::Vararg{Integer,N}
) where {T,N}
  @boundscheck checkbounds(block_arr, i...)
  block_indices = findblockindex.(axes(block_arr), i)
  block = map(block_index -> Block(block_index.I), block_indices)
  offsets = map(block_index -> only(block_index.α), block_indices)
  block_view = @view block_arr[block...]
  block_view[offsets...] = v
  block_arr[block...] = block_view
  return block_arr
end

function BlockArrays._check_setblock!(
  block_arr::BlockSparseArray{T,N}, v, block::NTuple{N,Integer}
) where {T,N}
  for i in 1:N
    bsz = length(axes(block_arr, i)[Block(block[i])])
    if size(v, i) != bsz
      throw(
        DimensionMismatch(
          string(
            "tried to assign $(size(v)) array to ",
            length.(getindex.(axes(block_arr), block)),
            " block",
          ),
        ),
      )
    end
  end
end

function Base.setindex!(
  block_arr::BlockSparseArray{T,N}, v, block::Vararg{Block{1},N}
) where {T,N}
  blks = Int.(block)
  @boundscheck blockcheckbounds(block_arr, blks...)
  @boundscheck BlockArrays._check_setblock!(block_arr, v, blks)
  # This fails since it tries to replace the element
  block_arr.blocks[blks...] = v
  # Use .= here to overwrite data.
  ## block_view = @view block_arr[Block(blks)]
  ## block_view .= v
  return block_arr
end

function Base.getindex(block_arr::BlockSparseArray{T,N}, i::Vararg{Integer,N}) where {T,N}
  @boundscheck checkbounds(block_arr, i...)
  v = block_arr[findblockindex.(axes(block_arr), i)...]
  return v
end
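A short usage sketch of the constructors and indexing paths defined above (illustrative only, not part of the commit):

```julia
using NDTensors.BlockSparseArrays
using BlockArrays: Block
using Dictionaries: Dictionary

# Construct directly from a Dictionary of blocks, using the
# `BlockSparseArray(blockdata::Dictionary, blockinds)` constructor above.
blockdata = Dictionary([Block(1, 1), Block(2, 2)], [randn(2, 2), randn(3, 3)])
B = BlockSparseArray(blockdata, ([2, 3], [2, 3]))

B[Block(1, 1)]               # stored block: returns its data
B[Block(2, 1)]               # unstored block: materialized as zeros by `BlockZero`
B[1, 1]                      # scalar indexing resolves the block via `findblockindex`
B[Block(2, 1)] = randn(3, 2) # assigning a block goes through `_check_setblock!`
```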
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
  data::Dictionary{CartesianIndex{N},T}
  dims::NTuple{N,Int64}
  zero::Zero
end

Base.size(a::SparseArray) = a.dims

function Base.setindex!(a::SparseArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
  set!(a.data, I, v)
  return a
end
function Base.setindex!(a::SparseArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
  return setindex!(a, v, CartesianIndex(I))
end

function Base.getindex(a::SparseArray{T,N}, I::CartesianIndex{N}) where {T,N}
  return get(a.data, I, a.zero(T, I))
end
function Base.getindex(a::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
  return getindex(a, CartesianIndex(I))
end

## # `getindex` but uses a default if the value is
## # structurally zero.
## function get_nonzero(a::SparseArray{T,N}, I::CartesianIndex{N}, zero) where {T,N}
##   @boundscheck checkbounds(a, I)
##   return get(a.data, I, zero)
## end
## function get_nonzero(a::SparseArray{T,N}, I::NTuple{N,Int}, zero) where {T,N}
##   return get_nonzero(a, CartesianIndex(I), zero)
## end
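A standalone sketch of how this `SparseArray` behaves (illustrative only; `elt_zero` is a made-up name, the `zero` field can be any callable of the form `zero(T, I)`):

```julia
using NDTensors.BlockSparseArrays: SparseArray
using Dictionaries: Dictionary

# The `zero` field is called as `a.zero(T, I)` for any index that is not stored.
elt_zero(T, I) = zero(T)

a = SparseArray(Dictionary([CartesianIndex(1, 1)], [1.0]), (2, 2), elt_zero)
a[1, 1]       # 1.0, a stored entry
a[2, 2]       # 0.0, produced by `elt_zero`; nothing is stored for this index
a[2, 1] = 3.0 # stores a new entry via `set!`
```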
Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
using Test
using NDTensors.BlockSparseArrays

@testset "Test NDTensors.BlockSparseArrays" begin
  @testset "README" begin
    @test include(
      joinpath(
        pkgdir(BlockSparseArrays), "src", "BlockSparseArrays", "examples", "README.jl"
      ),
    ) isa Any
  end
end

NDTensors/src/NDTensors.jl

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,8 @@ using TupleTools
 
 include("SetParameters/src/SetParameters.jl")
 using .SetParameters
+include("BlockSparseArrays/src/BlockSparseArrays.jl")
+using .BlockSparseArrays
 include("SmallVectors/src/SmallVectors.jl")
 using .SmallVectors
 
@@ -117,6 +119,7 @@ include("empty/adapt.jl")
 #
 include("arraytensor/arraytensor.jl")
 include("arraytensor/array.jl")
+include("arraytensor/blocksparsearray.jl")
 
 #####################################
 # Deprecations

NDTensors/src/arraytensor/arraytensor.jl

Lines changed: 7 additions & 1 deletion
@@ -1,7 +1,12 @@
 # Used for dispatch to distinguish from Tensors wrapping TensorStorage.
 # Remove once TensorStorage is removed.
 const ArrayStorage{T,N} = Union{
-  Array{T,N},ReshapedArray{T,N},SubArray{T,N},PermutedDimsArray{T,N},StridedView{T,N}
+  Array{T,N},
+  ReshapedArray{T,N},
+  SubArray{T,N},
+  PermutedDimsArray{T,N},
+  StridedView{T,N},
+  BlockSparseArray{T,N},
 }
 const MatrixStorage{T} = Union{
   ArrayStorage{T,2},

@@ -41,6 +46,7 @@ function setindex!(tensor::MatrixOrArrayStorageTensor, v, I::Integer...)
   return tensor
 end
 
+# TODO: Just call `contraction_output(storage(tensor1), storage(tensor2), indsR)`
 function contraction_output(
   tensor1::MatrixOrArrayStorageTensor, tensor2::MatrixOrArrayStorageTensor, indsR
 )
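A small check of what this union change enables (a sketch, not part of the commit; `ArrayStorage` is an internal NDTensors alias):

```julia
using NDTensors
using NDTensors.BlockSparseArrays
using BlockArrays: Block

b = BlockSparseArray([Block(1, 1)], [randn(2, 2)], ([2, 3], [2, 3]))
# After this change, block sparse arrays count as array storage, so the
# generic array-backed Tensor code paths can dispatch on them.
b isa NDTensors.ArrayStorage
```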
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
# TODO: Implement.
function contraction_output(tensor1::BlockSparseArray, tensor2::BlockSparseArray, indsR)
  return error("Not implemented")
end

# TODO: Implement.
function contract!(
  tensorR::BlockSparseArray,
  labelsR,
  tensor1::BlockSparseArray,
  labels1,
  tensor2::BlockSparseArray,
  labels2,
)
  return error("Not implemented")
end
