Skip to content

Commit d76cb1c

Browse files
authored
Merge branch 'main' into ranged_truncate_mps
2 parents d399674 + 3221d34 commit d76cb1c

File tree

7 files changed

+68
-55
lines changed

7 files changed

+68
-55
lines changed

NDTensors/src/diag/diag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ function contract!(
452452
if all(i -> i < 0, Blabels)
453453
# If all of B is contracted
454454
# TODO: can also check NC+NB==NA
455-
min_dim = minimum(dims(B))
455+
min_dim = min(minimum(dims(A)), minimum(dims(B)))
456456
if length(Clabels) == 0
457457
# all indices are summed over, just add the product of the diagonal
458458
# elements of A and B

src/decomp.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ function svd(A::ITensor, Linds...; kwargs...)
9191
# @warn "Keyword arguments `utags` and `vtags` are deprecated in favor of `leftags` and `righttags`."
9292
#end
9393

94-
Lis = commoninds(A, indices(Linds))
94+
Lis = commoninds(A, indices(Linds...))
9595
Ris = uniqueinds(A, Lis)
9696

9797
Lis_original = Lis
@@ -333,7 +333,7 @@ qr(A::ITensor; kwargs...) = error(noinds_error_message("qr"))
333333
# call qr on the order-2 tensors directly
334334
function qr(A::ITensor, Linds...; kwargs...)
335335
tags::TagSet = get(kwargs, :tags, "Link,qr")
336-
Lis = commoninds(A, indices(Linds))
336+
Lis = commoninds(A, indices(Linds...))
337337
Ris = uniqueinds(A, Lis)
338338

339339
Lis_original = Lis
@@ -388,7 +388,7 @@ function factorize_qr(A::ITensor, Linds...; kwargs...)
388388
if ortho == "left"
389389
L, R, q = qr(A, Linds...; kwargs...)
390390
elseif ortho == "right"
391-
Lis = uniqueinds(A, indices(Linds))
391+
Lis = uniqueinds(A, indices(Linds...))
392392
R, L, q = qr(A, Lis...; kwargs...)
393393
else
394394
error(
@@ -427,9 +427,9 @@ function factorize_eigen(A::ITensor, Linds...; kwargs...)
427427
ortho::String = get(kwargs, :ortho, "left")
428428
delta_A2 = get(kwargs, :eigen_perturbation, nothing)
429429
if ortho == "left"
430-
Lis = commoninds(A, indices(Linds))
430+
Lis = commoninds(A, indices(Linds...))
431431
elseif ortho == "right"
432-
Lis = uniqueinds(A, indices(Linds))
432+
Lis = uniqueinds(A, indices(Linds...))
433433
else
434434
error(
435435
"In factorize using eigen decomposition, ortho keyword $ortho not supported. Supported options are left or right.",
@@ -506,7 +506,8 @@ Note that the default is now `left`, meaning for the results L,R = factorize(A),
506506
# Determines when to use eigen vs. svd (eigen is less precise,
507507
# so eigen should only be used if a larger cutoff is requested)
508508
automatic_cutoff = 1e-12
509-
dL, dR = dim(indices(Linds)), dim(indices(setdiff(inds(A), Linds)))
509+
Lis = indices(Linds...)
510+
dL, dR = dim(Lis), dim(indices(setdiff(inds(A), Lis)))
510511
maxdim = get(kwargs, :maxdim, min(dL, dR))
511512
might_truncate = !isnothing(cutoff) || maxdim < min(dL, dR)
512513

src/indexset.jl

Lines changed: 30 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -23,58 +23,43 @@ const IndexTuple{IndexT<:Index} = Tuple{Vararg{IndexT}}
2323
# Definition to help with generic code
2424
const Indices{IndexT<:Index} = Union{Vector{IndexT},Tuple{Vararg{IndexT}}}
2525

26-
# Flatten combinations of tuples and vectors into a single collection
27-
# of indices
28-
tuple_vcat(t::Tuple) = t
29-
tuple_vcat() = ()
30-
tuple_vcat(t) = (t,)
31-
tuple_vcat(a, args...) = (tuple_vcat(a)..., tuple_vcat(args...)...)
32-
33-
tuple_to_vector(t::Tuple) = collect(t)
34-
tuple_to_vector(t) = t
35-
36-
function _narrow_eltype(v::Vector{T}) where {T}
26+
function _narrow_eltype(v::Vector{T}; default_empty_eltype=T) where {T}
3727
if isempty(v)
38-
return v
28+
return default_empty_eltype[]
3929
end
4030
return convert(Vector{mapreduce(typeof, promote_type, v)}, v)
4131
end
42-
narrow_eltype(v::Vector{T}) where {T} = isconcretetype(T) ? v : _narrow_eltype(v)
43-
44-
push_or_append!(v, x::Union{Vector,Tuple}) = append!(v, x)
45-
push_or_append!(v, x) = push!(v, x)
46-
47-
function _indices(is::Vector)
48-
isempty(is) && return is
49-
is_flat = Index[]
50-
for i in is
51-
push_or_append!(is_flat, i)
52-
end
53-
return narrow_eltype(is_flat)
54-
end
55-
indices(is::Vector{<:Index}) = narrow_eltype(is)
56-
57-
_indices(is::Tuple{Vararg{<:Index}}) = is
58-
_indices(is::Tuple{Vararg{Union{<:Vector,<:Index}}}) = vcat(is...)
59-
_indices(is::Tuple{Vararg{Union{<:Tuple,<:Index}}}) = tuple_vcat(is...)
60-
_indices(is::Tuple{Vararg{Union{<:Tuple,<:Vector,<:Index}}}) = indices(tuple_to_vector.(is))
61-
indices(is::Tuple{Vararg{<:Index}}) = is
62-
function indices(is::Tuple)
63-
inds = _indices(is)
64-
if isempty(inds)
65-
# Otherwise it outputs `Any[]`, which breaks
66-
# some generic code like `dim`.
67-
return Index[]
32+
function narrow_eltype(v::Vector{T}; default_empty_eltype=T) where {T}
33+
if isconcretetype(T)
34+
return v
6835
end
69-
return inds
70-
end
71-
indices(is::Union{<:Tuple,<:Vector,<:Index}...) = indices(is)
72-
function indices(is::Vector)
73-
# This narrows the type. Also handles the empty case.
74-
all(i -> i isa Index, is) && return narrow_eltype(is)
75-
return _indices(is)
36+
return _narrow_eltype(v; default_empty_eltype)
7637
end
7738

39+
_indices() = ()
40+
_indices(x::Index) = (x,)
41+
42+
# Tuples
43+
_indices(x1::Tuple, x2::Tuple) = (x1..., x2...)
44+
_indices(x1::Index, x2::Tuple) = (x1, x2...)
45+
_indices(x1::Tuple, x2::Index) = (x1..., x2)
46+
_indices(x1::Index, x2::Index) = (x1, x2)
47+
48+
# Vectors
49+
_indices(x1::Vector, x2::Vector) = narrow_eltype(vcat(x1, x2); default_empty_eltype=Index)
50+
51+
# Mix vectors and tuples/elements
52+
_indices(x1::Vector, x2) = _indices(x1, [x2])
53+
_indices(x1, x2::Vector) = _indices([x1], x2)
54+
_indices(x1::Vector, x2::Tuple) = _indices(x1, [x2...])
55+
_indices(x1::Tuple, x2::Vector) = _indices([x1...], x2)
56+
57+
indices(x::Vector{Index{S}}) where {S} = x
58+
indices(x::Vector{Index}) = narrow_eltype(x; default_empty_eltype=Index)
59+
indices(x::Tuple) = reduce(_indices, x; init=())
60+
indices(x::Vector) = reduce(_indices, x; init=Index[])
61+
indices(x...) = indices(x)
62+
7863
# To help with backwards compatibility
7964
IndexSet(inds::IndexSet) = inds
8065
IndexSet(inds::Indices) = collect(inds)

src/itensor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,9 +1764,9 @@ T[1, 1, 1] == pT_alias[1, 1, 1]
17641764
```
17651765
"""
17661766
function permute(T::ITensor, new_inds...; kwargs...)
1767-
if !hassameinds(T, indices(new_inds))
1767+
if !hassameinds(T, indices(new_inds...))
17681768
error(
1769-
"In `permute(::ITensor, inds...)`, the input ITensor has indices: \n\n$(inds(T))\n\nbut the desired Index ordering is: \n\n$(indices(new_inds))",
1769+
"In `permute(::ITensor, inds...)`, the input ITensor has indices: \n\n$(inds(T))\n\nbut the desired Index ordering is: \n\n$(indices(new_inds...))",
17701770
)
17711771
end
17721772
allow_alias = deprecated_keyword_argument(

test/ITensorChainRules/test_chainrules.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,19 @@ Random.seed!(1234)
352352
@test f2'(x) 2MPO(s, "I")
353353
@test f3'(x) -MPO(s, "I")
354354
end
355+
356+
@testset "issue 969" begin
357+
i = Index(2)
358+
j = Index(3)
359+
A = randomITensor(i)
360+
B = randomITensor(j)
361+
f = function (x, y)
362+
d = δ(ind(x, 1), ind(y, 1))
363+
return (x * d * y)[]
364+
end
365+
args = (A, B)
366+
test_rrule(ZygoteRuleConfig(), f, args...; rrule_f=rrule_via_ad, check_inferred=false)
367+
end
355368
end
356369

357370
@testset "ChainRules rrules: op" begin

test/ITensorChainRules/test_chainrules_ops.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ using Zygote: ZygoteRuleConfig, gradient
204204
return os
205205
end
206206

207-
if VERSION v"1.7"
207+
if VERSION.minor == 7
208+
# For some reason this is broken in Julia 1.6 and 1.8?
209+
# Seems like a Zygote problem
208210
f = function (x)
209211
return ITensor(exp(1.5 * H(x, x); alg=Trotter{1}(1)), s)[1, 1]
210212
end

test/diagitensor.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,18 @@ using Test
530530
@test D1 * D2 dense(D1) * dense(D2)
531531
@test D2 * D1 dense(D1) * dense(D2)
532532
end
533+
534+
@testset "Rectangular Diag * Dense regression test (#969)" begin
535+
i = Index(3)
536+
j = Index(2)
537+
A = randomITensor(i)
538+
B = delta(i, j)
539+
C = A * B
540+
@test hassameinds(C, j)
541+
for n in 1:dim(j)
542+
@test C[n] == A[n]
543+
end
544+
end
533545
end
534546
end
535547

0 commit comments

Comments
 (0)