Skip to content

Commit cb8d766

Browse files
authored
Merge a3c13c6 into 7a6e342
2 parents 7a6e342 + a3c13c6 commit cb8d766

File tree

4 files changed

+32
-17
lines changed

4 files changed

+32
-17
lines changed

NDTensors/src/diag/diagtensor.jl

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using .DiagonalArrays: diaglength
1+
using .DiagonalArrays: diaglength, diagview
22

33
const DiagTensor{ElT,N,StoreT,IndsT} = Tensor{ElT,N,StoreT,IndsT} where {StoreT<:Diag}
44
const NonuniformDiagTensor{ElT,N,StoreT,IndsT} =
@@ -9,9 +9,7 @@ const UniformDiagTensor{ElT,N,StoreT,IndsT} =
99
function diag(tensor::DiagTensor)
1010
tensor_diag = NDTensors.similar(dense(typeof(tensor)), (diaglength(tensor),))
1111
# TODO: Define `eachdiagindex`.
12-
for j in 1:diaglength(tensor)
13-
tensor_diag[j] = getdiagindex(tensor, j)
14-
end
12+
diagview(tensor_diag) .= diagview(tensor)
1513
return tensor_diag
1614
end
1715

@@ -33,6 +31,19 @@ function Array(T::DiagTensor{ElT,N}) where {ElT,N}
3331
return Array{ElT,N}(T)
3432
end
3533

34+
function DiagonalArrays.diagview(T::NonuniformDiagTensor)
35+
return data(T)
36+
end
37+
38+
function DiagonalArrays.diagview(T::UniformDiagTensor)
39+
return fill(getdiagindex(T, 1), diaglength(T))
40+
end
41+
42+
## Should this go in dense.jl or here since its related to diag?
43+
function DiagonalArrays.diagview(T::DenseTensor)
44+
return diagview(array(T))
45+
end
46+
3647
function zeros(tensortype::Type{<:DiagTensor}, inds)
3748
return tensor(generic_zeros(storagetype(tensortype), mindim(inds)), inds)
3849
end
@@ -145,16 +156,14 @@ function permutedims!(
145156
f::Function=(r, t) -> t,
146157
) where {N}
147158
# TODO: check that inds(R)==permute(inds(T),perm)?
148-
for i in 1:diaglength(R)
149-
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
150-
end
159+
diagview(R) .= f.(diagview(R), diagview(T))
151160
return R
152161
end
153162

154163
function permutedims(
155164
T::DiagTensor{<:Number,N}, perm::NTuple{N,Int}, f::Function=identity
156165
) where {N}
157-
R = NDTensors.similar(T, permute(inds(T), perm))
166+
R = NDTensors.similar(T)
158167
g(r, t) = f(t)
159168
permutedims!(R, T, perm, g)
160169
return R
@@ -193,9 +202,7 @@ end
193202
function permutedims!(
194203
R::DenseTensor{ElR,N}, T::DiagTensor{ElT,N}, perm::NTuple{N,Int}, f::Function=(r, t) -> t
195204
) where {ElR,ElT,N}
196-
for i in 1:diaglength(T)
197-
@inbounds setdiagindex!(R, f(getdiagindex(R, i), getdiagindex(T, i)), i)
198-
end
205+
diagview(array(R)) .= f.(diagview(array(R)), diagview(T))
199206
return R
200207
end
201208

NDTensors/src/linearalgebra/linearalgebra.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ matrix is unique. Returns a tuple (Q,R).
369369
function qr_positive(M::AbstractMatrix)
370370
sparseQ, R = qr(M)
371371
Q = convert(typeof(R), sparseQ)
372-
nc = size(Q, 2)
373372
signs = nonzero_sign.(diag(R))
374373
Q = Q * Diagonal(signs)
375374
R = Diagonal(conj.(signs)) * R

NDTensors/src/tensor/tensor.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,16 +361,15 @@ function getdiagindex(T::Tensor{<:Number,N}, ind::Int) where {N}
361361
return getindex(T, CartesianIndex(ntuple(_ -> ind, Val(N))))
362362
end
363363

364+
using .DiagonalArrays: diagview
364365
# TODO: add support for off-diagonals, return
365366
# block sparse vector instead of dense.
366367
function diag(tensor::Tensor)
367368
## d = NDTensors.similar(T, ElT, (diaglength(T),))
368369
tensordiag = NDTensors.similar(
369370
dense(typeof(tensor)), eltype(tensor), (diaglength(tensor),)
370371
)
371-
for n in 1:diaglength(tensor)
372-
tensordiag[n] = tensor[n, n]
373-
end
372+
data(tensordiag) .= diagview(tensor)
374373
return tensordiag
375374
end
376375

NDTensors/test/test_diag.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,25 @@ using LinearAlgebra: dot
3838
D = dev(tensor(Diag(vr), (d, d)))
3939
Da = Array(D)
4040
Dm = Matrix(D)
41+
Da = permutedims(D, (2, 1))
4142
@allowscalar begin
4243
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)
4344
@test Da == NDTensors.LinearAlgebra.diagm(0 => vr)
4445

45-
## TODO Currently this permutedims requires scalar indexing on GPU.
46-
Da = permutedims(D, (2, 1))
4746
@test Da == D
4847
end
4948

49+
# This if statement corresponds to the reported bug:
50+
# https://github.com/JuliaGPU/Metal.jl/issues/364
51+
if (dev == NDTensors.mtl && elt != ComplexF32)
52+
S = permutedims(dev(D), (1, 2), sqrt)
53+
@allowscalar begin
54+
for i in 1:diaglength(S)
55+
@test S[i, i] == sqrt(D[i, i])
56+
end
57+
end
58+
end
59+
5060
# Regression test for https://github.com/ITensor/ITensors.jl/issues/1199
5161
S = dev(tensor(Diag(randn(elt, 2)), (2, 2)))
5262
## This was creating a `Dense{ReshapedArray{Adjoint{Matrix}}}` which, in mul!, was

0 commit comments

Comments
 (0)