Skip to content

Commit 489b534

Browse files
authored
Merge 28775ef into c47eb7c
2 parents c47eb7c + 28775ef commit 489b534

20 files changed

+483
-0
lines changed

NDTensors/src/NDTensors.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ include("DiagonalArrays/src/DiagonalArrays.jl")
3030
using .DiagonalArrays
3131
include("BlockSparseArrays/src/BlockSparseArrays.jl")
3232
using .BlockSparseArrays
33+
include("NamedDimsArrays/src/NamedDimsArrays.jl")
34+
using .NamedDimsArrays: NamedDimsArrays
3335
include("SmallVectors/src/SmallVectors.jl")
3436
using .SmallVectors
3537
include("SortedSets/src/SortedSets.jl")
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
style = "blue"
2+
indent = 2
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# NamedDimsArrays.jl
2+
3+
````julia
4+
using NDTensors.NamedDimsArrays: align, dimnames, named, unname
5+
using NDTensors.TensorAlgebra: TensorAlgebra
6+
7+
# Named dimensions
8+
i = named(2, "i")
9+
j = named(2, "j")
10+
k = named(2, "k")
11+
12+
# Arrays with named dimensions
13+
na1 = randn(i, j)
14+
na2 = randn(j, k)
15+
16+
@show dimnames(na1) == ("i", "j")
17+
18+
# Indexing
19+
@show na1[j => 2, i => 1] == na1[1, 2]
20+
21+
# Tensor contraction
22+
na_dest = TensorAlgebra.contract(na1, na2)
23+
24+
@show issetequal(dimnames(na_dest), ("i", "k"))
25+
# `unname` removes the names and returns an `Array`
26+
@show unname(na_dest, (i, k)) unname(na1) * unname(na2)
27+
28+
# Permute dimensions (like `ITensors.permute`)
29+
na1 = align(na1, (j, i))
30+
@show na1[i => 1, j => 2] == na1[2, 1]
31+
````
32+
33+
---
34+
35+
*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
36+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# # NamedDimsArrays.jl
2+
3+
using NDTensors.NamedDimsArrays: align, dimnames, named, unname
4+
using NDTensors.TensorAlgebra: TensorAlgebra
5+
6+
## Named dimensions
7+
i = named(2, "i")
8+
j = named(2, "j")
9+
k = named(2, "k")
10+
11+
## Arrays with named dimensions
12+
na1 = randn(i, j)
13+
na2 = randn(j, k)
14+
15+
@show dimnames(na1) == ("i", "j")
16+
17+
## Indexing
18+
@show na1[j => 2, i => 1] == na1[1, 2]
19+
20+
## Tensor contraction
21+
na_dest = TensorAlgebra.contract(na1, na2)
22+
23+
@show issetequal(dimnames(na_dest), ("i", "k"))
24+
## `unname` removes the names and returns an `Array`
25+
@show unname(na_dest, (i, k)) unname(na1) * unname(na2)
26+
27+
## Permute dimensions (like `ITensors.permute`)
28+
na1 = align(na1, (j, i))
29+
@show na1[i => 1, j => 2] == na1[2, 1]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
module NamedDimsArraysTensorAlgebraExt
2+
using ..NamedDimsArrays: NamedDimsArrays
3+
using ...NDTensors.TensorAlgebra: TensorAlgebra
4+
5+
include("contract.jl")
6+
end
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
using NDTensors.NamedDimsArrays: AbstractNamedDimsArray, dimnames, named, unname
2+
using NDTensors.TensorAlgebra: contract
3+
4+
function TensorAlgebra.contract(
5+
na1::AbstractNamedDimsArray, na2::AbstractNamedDimsArray, α, β; kwargs...
6+
)
7+
a_dest, names_dest = contract(
8+
unname(na1), dimnames(na1), unname(na2), dimnames(na2), α, β; kwargs...
9+
)
10+
# TODO: Automate `Tuple` conversion of names?
11+
return named(a_dest, Tuple(names_dest))
12+
end
13+
14+
function TensorAlgebra.contract(
15+
na1::AbstractNamedDimsArray, na2::AbstractNamedDimsArray; kwargs...
16+
)
17+
return contract(na1, na2, true, false; kwargs...)
18+
end
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using Test: @test, @testset
2+
using NDTensors.NamedDimsArrays: named, unname
3+
using NDTensors.TensorAlgebra: TensorAlgebra
4+
5+
@testset "NamedDimsArraysTensorAlgebraExt" begin
6+
i = named(2, "i")
7+
j = named(2, "j")
8+
k = named(2, "k")
9+
na1 = randn(i, j)
10+
na2 = randn(j, k)
11+
na_dest = TensorAlgebra.contract(na1, na2)
12+
@test unname(na_dest, (i, k)) unname(na1) * unname(na2)
13+
end
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using Literate
2+
using NDTensors.NamedDimsArrays: NamedDimsArrays
3+
Literate.markdown(
4+
joinpath(
5+
pkgdir(NamedDimsArrays), "src", "NamedDimsArrays", "examples", "example_readme.jl"
6+
),
7+
joinpath(pkgdir(NamedDimsArrays), "src", "NamedDimsArrays");
8+
flavor=Literate.CommonMarkFlavor(),
9+
name="README",
10+
)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module NamedDimsArrays
2+
include("traits.jl")
3+
include("abstractnamedint.jl")
4+
include("abstractnamedunitrange.jl")
5+
include("abstractnameddimsarray.jl")
6+
include("namedint.jl")
7+
include("namedunitrange.jl")
8+
include("nameddimsarray.jl")
9+
10+
# Extensions
11+
include("../ext/NamedDimsArraysTensorAlgebraExt/src/NamedDimsArraysTensorAlgebraExt.jl")
12+
end
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# https://github.com/invenia/NamedDims.jl
2+
# https://github.com/mcabbott/NamedPlus.jl
3+
4+
abstract type AbstractNamedDimsArray{T,N,Names} <: AbstractArray{T,N} end
5+
6+
# Required interface
7+
8+
# Output the names.
9+
dimnames(a::AbstractNamedDimsArray) = error("Not implemented")
10+
11+
# Unwrapping the names
12+
Base.parent(::AbstractNamedDimsArray) = error("Not implemented")
13+
14+
# Set the names of an unnamed AbstractArray
15+
# `ndims(a) == length(names)`
16+
# This is a constructor
17+
## named(a::AbstractArray, names) = error("Not implemented")
18+
19+
# Traits
20+
isnamed(::AbstractNamedDimsArray) = true
21+
22+
# AbstractArray interface
23+
# TODO: Use `unname` instead of `parent`?
24+
25+
# Helper function, move to `utils.jl`.
26+
named_tuple(t::Tuple, names) = ntuple(i -> named(t[i], names[i]), length(t))
27+
28+
# TODO: Use the proper type, `namedaxistype(a)`.
29+
Base.axes(a::AbstractNamedDimsArray) = named_tuple(axes(unname(a)), dimnames(a))
30+
# TODO: Use the proper type, `namedlengthtype(a)`.
31+
Base.size(a::AbstractNamedDimsArray) = length.(axes(a))
32+
Base.getindex(a::AbstractNamedDimsArray, I...) = unname(a)[I...]
33+
function Base.setindex!(a::AbstractNamedDimsArray, x, I...)
34+
unname(a)[I...] = x
35+
return a
36+
end
37+
38+
# Derived interface
39+
40+
# Output the names.
41+
dimname(a::AbstractNamedDimsArray, i) = dimnames(a)[i]
42+
43+
# Renaming
44+
# Unname and set new naems
45+
rename(a::AbstractNamedDimsArray, names) = named(unname(a), names)
46+
47+
# replacenames(a, :i => :a, :j => :b)
48+
# `rename` in `NamedPlus.jl`.
49+
replacenames(a::AbstractNamedDimsArray, names::Pair) = error("Not implemented yet")
50+
51+
# Either define new names or replace names
52+
setnames(a::AbstractArray, names) = named(a, names)
53+
setnames(a::AbstractNamedDimsArray, names) = rename(a, names)
54+
55+
function getperm(x, y)
56+
return map(xᵢ -> findfirst(isequal(xᵢ), y), x)
57+
end
58+
59+
function get_name_perm(a::AbstractNamedDimsArray, names::Tuple)
60+
return getperm(dimnames(a), names)
61+
end
62+
63+
function get_name_perm(
64+
a::AbstractNamedDimsArray, namedints::Tuple{Vararg{AbstractNamedInt}}
65+
)
66+
return getperm(size(a), namedints)
67+
end
68+
69+
function get_name_perm(
70+
a::AbstractNamedDimsArray, namedaxes::Tuple{Vararg{AbstractNamedUnitRange}}
71+
)
72+
return getperm(axes(a), namedaxes)
73+
end
74+
75+
# Indexing
76+
# a[:i => 2, :j => 3]
77+
# TODO: Write a generic version using `dim`.
78+
# TODO: Define a `NamedIndex` type for indexing?
79+
function Base.getindex(a::AbstractNamedDimsArray, I::Pair...)
80+
perm = get_name_perm(a, first.(I))
81+
i = last.(I)
82+
return unname(a)[map(p -> i[p], perm)...]
83+
end
84+
85+
# a[:i => 2, :j => 3] = 12
86+
# TODO: Write a generic version using `dim`.
87+
function Base.setindex!(a::AbstractNamedDimsArray, value, I::Pair...)
88+
perm = get_name_perm(a, first.(I))
89+
i = last.(I)
90+
unname(a)[map(p -> i[p], perm)...] = value
91+
return a
92+
end
93+
94+
# Output the dimension of the specified name.
95+
dim(a::AbstractNamedDimsArray, name) = findfirst(==(name), dimnames(a))
96+
97+
# Output the dimensions of the specified names.
98+
dims(a::AbstractNamedDimsArray, names) = map(name -> dim(a, name), names)
99+
100+
# Unwrapping the names
101+
unname(a::AbstractNamedDimsArray) = parent(a)
102+
unname(a::AbstractArray) = a
103+
104+
# Permute into a certain order.
105+
# align(a, (:j, :k, :i))
106+
# Like `named(nameless(a, names), names)`
107+
function align(a::AbstractNamedDimsArray, names)
108+
perm = get_name_perm(a, names)
109+
# TODO: Avoid permutation if it is a trivial permutation?
110+
return typeof(a)(permutedims(unname(a), perm), names)
111+
end
112+
113+
# Unwrapping names and permuting
114+
# nameless(a, (:j, :i))
115+
# Could just call `unname`?
116+
## nameless(a::AbstractNamedDimsArray, names) = unname(align(a, names))
117+
unname(a::AbstractNamedDimsArray, names) = unname(align(a, names))
118+
119+
# In `TensorAlgebra` this this `fuse` and `unfuse`,
120+
# in `NDTensors`/`ITensors` this is `combine` and `uncombine`.
121+
# t = split(g, :n => (j=4, k=5))
122+
# join(t, (:i, :k) => :χ)
123+
124+
# TensorAlgebra
125+
# contract, fusedims, unfusedims, qr, eigen, svd, add, etc.
126+
# Some of these can simply wrap `TensorAlgebra.jl` functions.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
abstract type AbstractNamedInt{Value,Name} <: Integer end
2+
3+
# Interface
4+
unname(i::AbstractNamedInt) = error("Not implemented")
5+
name(i::AbstractNamedInt) = error("Not implemented")
6+
7+
# Derived
8+
unname(::Type{<:AbstractNamedInt{Value}}) where {Value} = Value
9+
10+
# Integer interface
11+
# TODO: Should this make a random name, or require defining a way
12+
# to combine names?
13+
Base.:*(i1::AbstractNamedInt, i2::AbstractNamedInt) = unname(i1) * unname(i2)
14+
Base.:-(i::AbstractNamedInt) = typeof(i)(-unname(i), name(i))
15+
16+
# TODO: Define for `NamedInt`, `NamedUnitRange` fallback?
17+
# Base.OneTo(stop::AbstractNamedInt) = namedoneto(stop)
18+
## nameduniterange_type(::Type{<:AbstractNamedInt}) = error("Not implemented")
19+
20+
# TODO: Use conversion from `AbstractNamedInt` to `AbstractNamedUnitRange`
21+
# instead of general `named`.
22+
# Base.OneTo(stop::AbstractNamedInt) = namedoneto(stop)
23+
Base.OneTo(stop::AbstractNamedInt) = named(Base.OneTo(unname(stop)), name(stop))
24+
25+
# TODO: Is this needed?
26+
# Include the name as well?
27+
Base.:<(i1::AbstractNamedInt, i2::AbstractNamedInt) = unname(i1) < unname(i2)
28+
## Base.zero(type::Type{<:AbstractNamedInt}) = zero(unname(type))
29+
30+
function Base.promote_rule(type1::Type{<:AbstractNamedInt}, type2::Type{<:Integer})
31+
return promote_type(unname(type1), type2)
32+
end
33+
(type::Type{<:Integer})(i::AbstractNamedInt) = type(unname(i))
34+
# TODO: Use conversion from `AbstractNamedInt` to `AbstractNamedUnitRange`
35+
# instead of general `named`.
36+
function Base.oftype(i1::AbstractNamedInt, i2::Integer)
37+
return named(convert(typeof(unname(i1)), i2), name(i1))
38+
end
39+
40+
# Traits
41+
isnamed(::AbstractNamedInt) = true
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
abstract type AbstractNamedUnitRange{T,Value<:AbstractUnitRange{T},Name} <:
2+
AbstractUnitRange{T} end
3+
4+
# Required interface
5+
unname(::AbstractNamedUnitRange) = error("Not implemented")
6+
name(::AbstractNamedUnitRange) = error("Not implemented")
7+
8+
# Traits
9+
isnamed(::AbstractNamedUnitRange) = true
10+
11+
# Unit range
12+
Base.first(i::AbstractNamedUnitRange) = first(unname(i))
13+
Base.last(i::AbstractNamedUnitRange) = last(unname(i))
14+
Base.length(i::AbstractNamedUnitRange) = named(length(unname(i)), name(i))

0 commit comments

Comments
 (0)