Skip to content

Commit e633092

Browse files
authored
Merge 0cfbe83 into 4c5c991
2 parents 4c5c991 + 0cfbe83 commit e633092

File tree

208 files changed

+1656
-271
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

208 files changed

+1656
-271
lines changed

NDTensors/src/NDTensors.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,33 @@ using Strided
1919
using TimerOutputs
2020
using TupleTools
2121

22-
# TODO: Define an `AlgorithmSelection` module
2322
# TODO: List types, macros, and functions being used.
24-
include("algorithm.jl")
25-
include("SetParameters/src/SetParameters.jl")
23+
include("lib/AlgorithmSelection/src/AlgorithmSelection.jl")
24+
using .AlgorithmSelection: AlgorithmSelection
25+
include("lib/BaseExtensions/src/BaseExtensions.jl")
26+
using .BaseExtensions: BaseExtensions
27+
include("lib/SetParameters/src/SetParameters.jl")
2628
using .SetParameters
27-
include("TensorAlgebra/src/TensorAlgebra.jl")
29+
include("lib/BroadcastMapConversion/src/BroadcastMapConversion.jl")
30+
using .BroadcastMapConversion: BroadcastMapConversion
31+
include("lib/Unwrap/src/Unwrap.jl")
32+
using .Unwrap
33+
include("lib/RankFactorization/src/RankFactorization.jl")
34+
using .RankFactorization: RankFactorization
35+
include("lib/TensorAlgebra/src/TensorAlgebra.jl")
2836
using .TensorAlgebra: TensorAlgebra
29-
include("DiagonalArrays/src/DiagonalArrays.jl")
37+
include("lib/DiagonalArrays/src/DiagonalArrays.jl")
3038
using .DiagonalArrays
31-
include("BlockSparseArrays/src/BlockSparseArrays.jl")
39+
include("lib/BlockSparseArrays/src/BlockSparseArrays.jl")
3240
using .BlockSparseArrays
33-
include("NamedDimsArrays/src/NamedDimsArrays.jl")
41+
include("lib/NamedDimsArrays/src/NamedDimsArrays.jl")
3442
using .NamedDimsArrays: NamedDimsArrays
35-
include("SmallVectors/src/SmallVectors.jl")
43+
include("lib/SmallVectors/src/SmallVectors.jl")
3644
using .SmallVectors
37-
include("SortedSets/src/SortedSets.jl")
45+
include("lib/SortedSets/src/SortedSets.jl")
3846
using .SortedSets
39-
include("TagSets/src/TagSets.jl")
47+
include("lib/TagSets/src/TagSets.jl")
4048
using .TagSets
41-
include("Unwrap/src/Unwrap.jl")
42-
using .Unwrap
4349

4450
using Base: @propagate_inbounds, ReshapedArray, DimOrInd, OneTo
4551

NDTensors/src/NamedDimsArrays/ext/NamedDimsArraysTensorAlgebraExt/src/NamedDimsArraysTensorAlgebraExt.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

NDTensors/src/NamedDimsArrays/ext/NamedDimsArraysTensorAlgebraExt/test/runtests.jl

Lines changed: 0 additions & 13 deletions
This file was deleted.

NDTensors/src/NamedDimsArrays/src/abstractnamedunitrange.jl

Lines changed: 0 additions & 14 deletions
This file was deleted.

NDTensors/src/NamedDimsArrays/test/test_basic.jl

Lines changed: 0 additions & 38 deletions
This file was deleted.

NDTensors/src/TensorAlgebra/src/fusedims.jl

Lines changed: 0 additions & 23 deletions
This file was deleted.

NDTensors/src/algorithm.jl

Lines changed: 0 additions & 29 deletions
This file was deleted.

NDTensors/src/blocksparse/contract.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using .AlgorithmSelection: Algorithm, @Algorithm_str
2+
13
function contract(
24
tensor1::BlockSparseTensor,
35
labelstensor1,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module AlgorithmSelection
2+
include("algorithm.jl")
3+
end
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
Algorithm
3+
4+
A type representing an algorithm backend for a function.
5+
6+
For example, a function might have multiple backend algorithm
7+
implementations, which internally are selected with an `Algorithm` type.
8+
9+
This allows users to extend functionality with a new algorithm but
10+
use the same interface.
11+
"""
12+
struct Algorithm{Alg,Kwargs<:NamedTuple}
13+
kwargs::Kwargs
14+
end
15+
16+
Algorithm{Alg}(kwargs::NamedTuple) where {Alg} = Algorithm{Alg,typeof(kwargs)}(kwargs)
17+
Algorithm{Alg}(; kwargs...) where {Alg} = Algorithm{Alg}(NamedTuple(kwargs))
18+
Algorithm(s; kwargs...) = Algorithm{Symbol(s)}(NamedTuple(kwargs))
19+
20+
Algorithm(alg::Algorithm) = alg
21+
22+
# TODO: Use `SetParameters`.
23+
algorithm_string(::Algorithm{Alg}) where {Alg} = string(Alg)
24+
25+
function Base.show(io::IO, alg::Algorithm)
26+
return print(io, "Algorithm type ", algorithm_string(alg), ", ", alg.kwargs)
27+
end
28+
Base.print(io::IO, alg::Algorithm) = print(io, algorithm_string(alg), ", ", alg.kwargs)
29+
30+
"""
31+
@Algorithm_str
32+
33+
A convenience macro for writing [`Algorithm`](@ref) types, typically used when
34+
adding methods to a function that supports multiple algorithm
35+
backends.
36+
"""
37+
macro Algorithm_str(s)
38+
return :(Algorithm{$(Expr(:quote, Symbol(s)))})
39+
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
3+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
@eval module $(gensym())
2+
using Test: @test, @testset
3+
using NDTensors.AlgorithmSelection: Algorithm, @Algorithm_str
4+
@testset "AlgorithmSelection" begin
5+
@test Algorithm"alg"() isa Algorithm{:alg}
6+
@test Algorithm("alg") isa Algorithm{:alg}
7+
@test Algorithm(:alg) isa Algorithm{:alg}
8+
alg = Algorithm"alg"(; x=2, y=3)
9+
@test alg isa Algorithm{:alg}
10+
@test alg.kwargs == (; x=2, y=3)
11+
end
12+
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module BaseExtensions
2+
include("replace.jl")
3+
end
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
replace(collection, replacements::Pair...) = Base.replace(collection, replacements...)
2+
@static if VERSION < v"1.7.0-DEV.15"
3+
# https://github.com/JuliaLang/julia/pull/38216
4+
# TODO: Add to `Compat.jl` or delete when we drop Julia 1.6 support.
5+
# `replace` for Tuples.
6+
function _replace(f::Base.Callable, t::Tuple, count::Int)
7+
return if count == 0 || isempty(t)
8+
t
9+
else
10+
x = f(t[1])
11+
(x, _replace(f, Base.tail(t), count - !==(x, t[1]))...)
12+
end
13+
end
14+
15+
function replace(f::Base.Callable, t::Tuple; count::Integer=typemax(Int))
16+
return _replace(f, t, Base.check_count(count))
17+
end
18+
19+
function _replace(t::Tuple, count::Int, old_new::Tuple{Vararg{Pair}})
20+
return _replace(t, count) do x
21+
Base.@_inline_meta
22+
for o_n in old_new
23+
isequal(first(o_n), x) && return last(o_n)
24+
end
25+
return x
26+
end
27+
end
28+
29+
function replace(t::Tuple, old_new::Pair...; count::Integer=typemax(Int))
30+
return _replace(t, Base.check_count(count), old_new)
31+
end
32+
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
3+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using SafeTestsets: @safetestset
2+
3+
@safetestset "BaseExtensions" begin
4+
using NDTensors.BaseExtensions: BaseExtensions
5+
using Test: @test, @testset
6+
@testset "replace $(typeof(collection))" for collection in
7+
(["a", "b", "c"], ("a", "b", "c"))
8+
r1 = BaseExtensions.replace(collection, "b" => "d")
9+
@test r1 == typeof(collection)(["a", "d", "c"])
10+
@test typeof(r1) === typeof(collection)
11+
r2 = BaseExtensions.replace(collection, "b" => "d", "a" => "e")
12+
@test r2 == typeof(collection)(["e", "d", "c"])
13+
@test typeof(r2) === typeof(collection)
14+
end
15+
end

NDTensors/src/BlockSparseArrays/src/BlockSparseArrays.jl renamed to NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
module BlockSparseArrays
2+
using ..AlgorithmSelection: Algorithm, @Algorithm_str
23
using BlockArrays:
34
AbstractBlockArray,
45
BlockArrays,

NDTensors/src/BlockSparseArrays/src/LinearAlgebraExt/LinearAlgebraExt.jl renamed to NDTensors/src/lib/BlockSparseArrays/src/LinearAlgebraExt/LinearAlgebraExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module LinearAlgebraExt
2+
using ...AlgorithmSelection: Algorithm, @Algorithm_str
23
using BlockArrays: BlockArrays, blockedrange, blocks
34
using ..BlockSparseArrays: SparseArray, nonzero_keys # TODO: Move to `SparseArraysExtensions` module, rename `SparseArrayDOK`.
45
using ..BlockSparseArrays: BlockSparseArrays, BlockSparseArray, nonzero_blockkeys
56
using LinearAlgebra: LinearAlgebra, Hermitian, Transpose, I, eigen, qr
6-
using ...NDTensors: Algorithm, @Algorithm_str # TODO: Move to `AlgorithmSelector` module.
77
using SparseArrays: SparseArrays, SparseMatrixCSC, spzeros, sparse
88

99
# TODO: Move to `SparseArraysExtensions`.

NDTensors/src/BlockSparseArrays/test/runtests.jl renamed to NDTensors/src/lib/BlockSparseArrays/test/runtests.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ include("TestBlockSparseArraysUtils.jl")
1313
@testset "README" begin
1414
@test include(
1515
joinpath(
16-
pkgdir(BlockSparseArrays), "src", "BlockSparseArrays", "examples", "README.jl"
16+
pkgdir(BlockSparseArrays),
17+
"src",
18+
"lib",
19+
"BlockSparseArrays",
20+
"examples",
21+
"README.jl",
1722
),
1823
) isa Any
1924
end
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
module BroadcastMapConversion
2+
# Convert broadcast call to map call by capturing array arguments
3+
# with `map_args` and creating a map function with `map_function`.
4+
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl.
5+
6+
using Base.Broadcast: Broadcasted
7+
8+
const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}
9+
10+
function map_args(bc::Broadcasted, rest...)
11+
return (map_args(bc.args...)..., map_args(rest...)...)
12+
end
13+
map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...)
14+
map_args(a, rest...) = map_args(rest...)
15+
map_args() = ()
16+
17+
struct MapFunction{F,Args<:Tuple}
18+
f::F
19+
args::Args
20+
end
21+
struct Arg end
22+
23+
# construct MapFunction
24+
function map_function(bc::Broadcasted)
25+
args = map_function_tuple(bc.args)
26+
return MapFunction(bc.f, args)
27+
end
28+
map_function_tuple(t::Tuple{}) = t
29+
map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...)
30+
map_function(a::WrappedScalarArgs) = a[]
31+
map_function(a::AbstractArray) = Arg()
32+
map_function(a) = a
33+
34+
# Evaluate MapFunction
35+
(f::MapFunction)(args...) = apply(f, args)[1]
36+
function apply(f::MapFunction, args)
37+
args, newargs = apply_tuple(f.args, args)
38+
return f.f(args...), newargs
39+
end
40+
apply(a::Arg, args::Tuple) = args[1], Base.tail(args)
41+
apply(a, args) = a, args
42+
apply_tuple(t::Tuple{}, args) = t, args
43+
function apply_tuple(t::Tuple, args)
44+
t1, newargs1 = apply(t[1], args)
45+
ttail, newargs = apply_tuple(Base.tail(t), newargs1)
46+
return (t1, ttail...), newargs
47+
end
48+
end
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
@eval module $(gensym())
2+
using Test: @test, @testset
3+
using NDTensors.BroadcastMapConversion: map_function, map_args
4+
@testset "BroadcastMapConversion" begin
5+
using Base.Broadcast: Broadcasted
6+
c = 2.2
7+
a = randn(2, 3)
8+
b = randn(2, 3)
9+
bc = Broadcasted(*, (c, a))
10+
@test copy(bc) c * a map(map_function(bc), map_args(bc)...)
11+
bc = Broadcasted(+, (a, b))
12+
@test copy(bc) a + b map(map_function(bc), map_args(bc)...)
13+
end
14+
end

0 commit comments

Comments
 (0)