Skip to content

Commit 78a0bad

Browse files
authored
Merge b537e6a into 690b219
2 parents 690b219 + b537e6a commit 78a0bad

File tree

10 files changed

+273
-12
lines changed

10 files changed

+273
-12
lines changed

NDTensors/src/BlockSparseArrays/src/blocksparsearray.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ function BlockArrays.viewblock(block_arr::BlockSparseArray, block)
6565
# TODO: Make this `Zeros`?
6666
## zero = zeros(eltype(block_arr), block_size)
6767
return block_arr.blocks[blks...] # Fails because zero isn't defined
68-
## return get_nonzero(block_arr.blocks, blks, zero)
6968
end
7069

7170
function Base.getindex(block_arr::BlockSparseArray{T,N}, bi::BlockIndex{N}) where {T,N}

NDTensors/src/BlockSparseArrays/src/sparsearray.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
# TODO: Define a constructor with a default `zero`.
12
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
23
data::Dictionary{CartesianIndex{N},T}
3-
dims::NTuple{N,Int64}
4+
dims::NTuple{N,Int}
45
zero::Zero
56
end
67

@@ -20,13 +21,3 @@ end
2021
function Base.getindex(a::SparseArray{T,N}, I::Vararg{Int,N}) where {T,N}
2122
return getindex(a, CartesianIndex(I))
2223
end
23-
24-
## # `getindex` but uses a default if the value is
25-
## # structurally zero.
26-
## function get_nonzero(a::SparseArray{T,N}, I::CartesianIndex{N}, zero) where {T,N}
27-
## @boundscheck checkbounds(a, I)
28-
## return get(a.data, I, zero)
29-
## end
30-
## function get_nonzero(a::SparseArray{T,N}, I::NTuple{N,Int}, zero) where {T,N}
31-
## return get_nonzero(a, CartesianIndex(I), zero)
32-
## end
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# DiagonalArrays.jl
2+
3+
A Julia `DiagonalArray` type.
4+
5+
````julia
6+
using NDTensors.DiagonalArrays:
7+
DiagonalArray,
8+
densearray,
9+
diagview,
10+
diaglength,
11+
getdiagindex,
12+
setdiagindex!,
13+
setdiag!,
14+
diagcopyto!
15+
16+
d = DiagonalArray([1., 2, 3], 3, 4, 5)
17+
@show d[1, 1, 1] == 1
18+
@show d[2, 2, 2] == 2
19+
@show d[1, 2, 1] == 0
20+
21+
d[2, 2, 2] = 22
22+
@show d[2, 2, 2] == 22
23+
24+
@show diaglength(d) == 3
25+
@show densearray(d) == d
26+
@show getdiagindex(d, 2) == d[2, 2, 2]
27+
28+
setdiagindex!(d, 222, 2)
29+
@show d[2, 2, 2] == 222
30+
31+
a = randn(3, 4, 5)
32+
new_diag = randn(3)
33+
setdiag!(a, new_diag)
34+
diagcopyto!(d, a)
35+
36+
@show diagview(a) == new_diag
37+
@show diagview(d) == new_diag
38+
````
39+
40+
You can generate this README with:
41+
```julia
42+
using Literate
43+
Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
44+
```
45+
46+
---
47+
48+
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
49+
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# # DiagonalArrays.jl
2+
#
3+
# A Julia `DiagonalArray` type.
4+
5+
using NDTensors.DiagonalArrays:
6+
DiagonalArray,
7+
densearray,
8+
diagview,
9+
diaglength,
10+
getdiagindex,
11+
setdiagindex!,
12+
setdiag!,
13+
diagcopyto!
14+
15+
d = DiagonalArray([1.0, 2, 3], 3, 4, 5)
16+
@show d[1, 1, 1] == 1
17+
@show d[2, 2, 2] == 2
18+
@show d[1, 2, 1] == 0
19+
20+
d[2, 2, 2] = 22
21+
@show d[2, 2, 2] == 22
22+
23+
@show diaglength(d) == 3
24+
@show densearray(d) == d
25+
@show getdiagindex(d, 2) == d[2, 2, 2]
26+
27+
setdiagindex!(d, 222, 2)
28+
@show d[2, 2, 2] == 222
29+
30+
a = randn(3, 4, 5)
31+
new_diag = randn(3)
32+
setdiag!(a, new_diag)
33+
diagcopyto!(d, a)
34+
35+
@show diagview(a) == new_diag
36+
@show diagview(d) == new_diag
37+
38+
# You can generate this README with:
39+
# ```julia
40+
# using Literate
41+
# Literate.markdown("examples/README.jl", "."; flavor=Literate.CommonMarkFlavor())
42+
# ```
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
module DiagonalArrays
2+
3+
using LinearAlgebra
4+
5+
export DiagonalArray
6+
7+
include("diagview.jl")
8+
9+
struct DefaultZero end
10+
11+
function (::DefaultZero)(eltype::Type, I::CartesianIndex)
12+
return zero(eltype)
13+
end
14+
15+
struct DiagonalArray{T,N,Diag<:AbstractVector{T},Zero} <: AbstractArray{T,N}
16+
diag::Diag
17+
dims::NTuple{N,Int}
18+
zero::Zero
19+
end
20+
21+
function DiagonalArray{T,N}(
22+
diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
23+
) where {T,N}
24+
return DiagonalArray{T,N,typeof(diag),typeof(zero)}(diag, d, zero)
25+
end
26+
27+
function DiagonalArray{T,N}(
28+
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
29+
) where {T,N}
30+
return DiagonalArray{T,N}(T.(diag), d, zero)
31+
end
32+
33+
function DiagonalArray{T,N}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
34+
return DiagonalArray{T,N}(diag, d)
35+
end
36+
37+
function DiagonalArray{T}(
38+
diag::AbstractVector, d::Tuple{Vararg{Int,N}}, zero=DefaultZero()
39+
) where {T,N}
40+
return DiagonalArray{T,N}(diag, d, zero)
41+
end
42+
43+
function DiagonalArray{T}(diag::AbstractVector, d::Vararg{Int,N}) where {T,N}
44+
return DiagonalArray{T,N}(diag, d)
45+
end
46+
47+
function DiagonalArray(diag::AbstractVector{T}, d::Tuple{Vararg{Int,N}}) where {T,N}
48+
return DiagonalArray{T,N}(diag, d)
49+
end
50+
51+
function DiagonalArray(diag::AbstractVector{T}, d::Vararg{Int,N}) where {T,N}
52+
return DiagonalArray{T,N}(diag, d)
53+
end
54+
55+
# undef
56+
function DiagonalArray{T,N}(::UndefInitializer, d::Tuple{Vararg{Int,N}}) where {T,N}
57+
return DiagonalArray{T,N}(Vector{T}(undef, minimum(d)), d)
58+
end
59+
60+
function DiagonalArray{T,N}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
61+
return DiagonalArray{T,N}(undef, d)
62+
end
63+
64+
function DiagonalArray{T}(::UndefInitializer, d::Tuple{Vararg{Int,N}}) where {T,N}
65+
return DiagonalArray{T,N}(undef, d)
66+
end
67+
68+
function DiagonalArray{T}(::UndefInitializer, d::Vararg{Int,N}) where {T,N}
69+
return DiagonalArray{T,N}(undef, d)
70+
end
71+
72+
Base.size(a::DiagonalArray) = a.dims
73+
74+
diagview(a::DiagonalArray) = a.diag
75+
LinearAlgebra.diag(a::DiagonalArray) = copy(diagview(a))
76+
77+
function Base.getindex(a::DiagonalArray{T,N}, I::CartesianIndex{N}) where {T,N}
78+
i = diagindex(a, I)
79+
isnothing(i) && return a.zero(T, I)
80+
return getdiagindex(a, i)
81+
end
82+
83+
function Base.getindex(a::DiagonalArray{T,N}, I::Vararg{Int,N}) where {T,N}
84+
return getindex(a, CartesianIndex(I))
85+
end
86+
87+
function Base.setindex!(a::DiagonalArray{T,N}, v, I::CartesianIndex{N}) where {T,N}
88+
i = diagindex(a, I)
89+
isnothing(i) && return error("Can't set off-diagonal element of DiagonalArray")
90+
setdiagindex!(a, v, i)
91+
return a
92+
end
93+
94+
function Base.setindex!(a::DiagonalArray{T,N}, v, I::Vararg{Int,N}) where {T,N}
95+
a[CartesianIndex(I)] = v
96+
return a
97+
end
98+
99+
# Make dense.
100+
function densearray(a::DiagonalArray)
101+
# TODO: Check this works on GPU.
102+
# TODO: Make use of `a.zero`?
103+
d = similar(diagview(a), size(a))
104+
fill!(d, zero(eltype(a)))
105+
diagcopyto!(d, a)
106+
return d
107+
end
108+
109+
end
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Convert to an offset along the diagonal.
2+
# Otherwise, return `nothing`.
3+
function diagindex(a::AbstractArray{T,N}, I::CartesianIndex{N}) where {T,N}
4+
!allequal(Tuple(I)) && return nothing
5+
return first(Tuple(I))
6+
end
7+
8+
function diagindex(a::AbstractArray{T,N}, I::Vararg{Int,N}) where {T,N}
9+
return diagindex(a, CartesianIndex(I))
10+
end
11+
12+
function getdiagindex(a::AbstractArray, i::Integer)
13+
return diagview(a)[i]
14+
end
15+
16+
function setdiagindex!(a::AbstractArray, v, i::Integer)
17+
diagview(a)[i] = v
18+
return a
19+
end
20+
21+
function setdiag!(a::AbstractArray, v)
22+
copyto!(diagview(a), v)
23+
return a
24+
end
25+
26+
function diaglength(a::AbstractArray)
27+
# length(diagview(a))
28+
return minimum(size(a))
29+
end
30+
31+
function diagstride(A::AbstractArray)
32+
s = 1
33+
p = 1
34+
for i in 1:(ndims(A) - 1)
35+
p *= size(A, i)
36+
s += p
37+
end
38+
return s
39+
end
40+
41+
function diagindices(A::AbstractArray)
42+
diaglength = minimum(size(A))
43+
maxdiag = LinearIndices(A)[CartesianIndex(ntuple(Returns(diaglength), ndims(A)))]
44+
return 1:diagstride(A):maxdiag
45+
end
46+
47+
function diagview(A::AbstractArray)
48+
return @view A[diagindices(A)]
49+
end
50+
51+
function diagcopyto!(dest::AbstractArray, src::AbstractArray)
52+
copyto!(diagview(dest), diagview(src))
53+
return dest
54+
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Test
2+
using NDTensors.DiagonalArrays
3+
4+
@testset "Test NDTensors.DiagonalArrays" begin
5+
@testset "README" begin
6+
@test include(
7+
joinpath(pkgdir(DiagonalArrays), "src", "DiagonalArrays", "examples", "README.jl")
8+
) isa Any
9+
end
10+
end

NDTensors/src/NDTensors.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ using TupleTools
2020

2121
include("SetParameters/src/SetParameters.jl")
2222
using .SetParameters
23+
include("DiagonalArrays/src/DiagonalArrays.jl")
24+
using .DiagonalArrays
2325
include("BlockSparseArrays/src/BlockSparseArrays.jl")
2426
using .BlockSparseArrays
2527
include("SmallVectors/src/SmallVectors.jl")

NDTensors/test/DiagonalArrays.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
using Test
2+
using NDTensors
3+
4+
include(joinpath(pkgdir(NDTensors), "src", "DiagonalArrays", "test", "runtests.jl"))

NDTensors/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ end
2020
@safetestset "NDTensors" begin
2121
@testset "$filename" for filename in [
2222
"BlockSparseArrays.jl",
23+
"DiagonalArrays.jl",
2324
"SetParameters.jl",
2425
"SmallVectors.jl",
2526
"SortedSets.jl",

0 commit comments

Comments
 (0)