Skip to content

Commit 0cfbe83

Browse files
committed
Define AlgorithmSelection and BroadcatMapConversion submodules
1 parent 31cd2b5 commit 0cfbe83

26 files changed

+88
-84
lines changed

NDTensors/src/NDTensors.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ using Strided
1919
using TimerOutputs
2020
using TupleTools
2121

22-
# TODO: Define an `AlgorithmSelection` module
2322
# TODO: List types, macros, and functions being used.
24-
include("lib/algorithm.jl")
23+
include("lib/AlgorithmSelection/src/AlgorithmSelection.jl")
24+
using .AlgorithmSelection: AlgorithmSelection
2525
include("lib/BaseExtensions/src/BaseExtensions.jl")
2626
using .BaseExtensions: BaseExtensions
2727
include("lib/SetParameters/src/SetParameters.jl")
2828
using .SetParameters
29+
include("lib/BroadcastMapConversion/src/BroadcastMapConversion.jl")
30+
using .BroadcastMapConversion: BroadcastMapConversion
2931
include("lib/Unwrap/src/Unwrap.jl")
3032
using .Unwrap
3133
include("lib/RankFactorization/src/RankFactorization.jl")

NDTensors/src/blocksparse/contract.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using .AlgorithmSelection: Algorithm, @Algorithm_str
2+
13
function contract(
24
tensor1::BlockSparseTensor,
35
labelstensor1,
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module AlgorithmSelection
2+
include("algorithm.jl")
3+
end
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
Algorithm
3+
4+
A type representing an algorithm backend for a function.
5+
6+
For example, a function might have multiple backend algorithm
7+
implementations, which internally are selected with an `Algorithm` type.
8+
9+
This allows users to extend functionality with a new algorithm but
10+
use the same interface.
11+
"""
12+
struct Algorithm{Alg,Kwargs<:NamedTuple}
13+
kwargs::Kwargs
14+
end
15+
16+
Algorithm{Alg}(kwargs::NamedTuple) where {Alg} = Algorithm{Alg,typeof(kwargs)}(kwargs)
17+
Algorithm{Alg}(; kwargs...) where {Alg} = Algorithm{Alg}(NamedTuple(kwargs))
18+
Algorithm(s; kwargs...) = Algorithm{Symbol(s)}(NamedTuple(kwargs))
19+
20+
Algorithm(alg::Algorithm) = alg
21+
22+
# TODO: Use `SetParameters`.
23+
algorithm_string(::Algorithm{Alg}) where {Alg} = string(Alg)
24+
25+
function Base.show(io::IO, alg::Algorithm)
26+
return print(io, "Algorithm type ", algorithm_string(alg), ", ", alg.kwargs)
27+
end
28+
Base.print(io::IO, alg::Algorithm) = print(io, algorithm_string(alg), ", ", alg.kwargs)
29+
30+
"""
31+
@Algorithm_str
32+
33+
A convenience macro for writing [`Algorithm`](@ref) types, typically used when
34+
adding methods to a function that supports multiple algorithm
35+
backends.
36+
"""
37+
macro Algorithm_str(s)
38+
return :(Algorithm{$(Expr(:quote, Symbol(s)))})
39+
end
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[deps]
2+
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
3+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
@eval module $(gensym())
2+
using Test: @test, @testset
3+
using NDTensors.AlgorithmSelection: Algorithm, @Algorithm_str
4+
@testset "AlgorithmSelection" begin
5+
@test Algorithm"alg"() isa Algorithm{:alg}
6+
@test Algorithm("alg") isa Algorithm{:alg}
7+
@test Algorithm(:alg) isa Algorithm{:alg}
8+
alg = Algorithm"alg"(; x=2, y=3)
9+
@test alg isa Algorithm{:alg}
10+
@test alg.kwargs == (; x=2, y=3)
11+
end
12+
end

NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
module BlockSparseArrays
2+
using ..AlgorithmSelection: Algorithm, @Algorithm_str
23
using BlockArrays:
34
AbstractBlockArray,
45
BlockArrays,

NDTensors/src/lib/BlockSparseArrays/src/LinearAlgebraExt/LinearAlgebraExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
module LinearAlgebraExt
2+
using ...AlgorithmSelection: Algorithm, @Algorithm_str
23
using BlockArrays: BlockArrays, blockedrange, blocks
34
using ..BlockSparseArrays: SparseArray, nonzero_keys # TODO: Move to `SparseArraysExtensions` module, rename `SparseArrayDOK`.
45
using ..BlockSparseArrays: BlockSparseArrays, BlockSparseArray, nonzero_blockkeys
56
using LinearAlgebra: LinearAlgebra, Hermitian, Transpose, I, eigen, qr
6-
using ...NDTensors: Algorithm, @Algorithm_str # TODO: Move to `AlgorithmSelector` module.
77
using SparseArrays: SparseArrays, SparseMatrixCSC, spzeros, sparse
88

99
# TODO: Move to `SparseArraysExtensions`.

NDTensors/src/lib/NamedDimsArrays/src/broadcastmapconversion.jl renamed to NDTensors/src/lib/BroadcastMapConversion/src/BroadcastMapConversion.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
module BroadcastMapConversion
12
# Convert broadcast call to map call by capturing array arguments
23
# with `map_args` and creating a map function with `map_function`.
34
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl.
4-
# TODO: Move to a `BroadcastMapConversion` module.
55

66
using Base.Broadcast: Broadcasted
77

@@ -45,3 +45,4 @@ function apply_tuple(t::Tuple, args)
4545
ttail, newargs = apply_tuple(Base.tail(t), newargs1)
4646
return (t1, ttail...), newargs
4747
end
48+
end
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
@eval module $(gensym())
2+
using Test: @test, @testset
3+
using NDTensors.BroadcastMapConversion: map_function, map_args
4+
@testset "BroadcastMapConversion" begin
5+
using Base.Broadcast: Broadcasted
6+
c = 2.2
7+
a = randn(2, 3)
8+
b = randn(2, 3)
9+
bc = Broadcasted(*, (c, a))
10+
@test copy(bc) c * a map(map_function(bc), map_args(bc)...)
11+
bc = Broadcasted(+, (a, b))
12+
@test copy(bc) a + b map(map_function(bc), map_args(bc)...)
13+
end
14+
end

NDTensors/src/lib/NamedDimsArrays/src/NamedDimsArrays.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
11
module NamedDimsArrays
2-
using ..BaseExtensions: BaseExtensions
3-
4-
# TODO: Move to a `BroadcastMapConversion` module.
5-
include("broadcastmapconversion.jl")
6-
72
include("traits.jl")
83
include("randname.jl")
94
include("abstractnamedint.jl")

NDTensors/src/lib/NamedDimsArrays/src/abstractnameddimsarray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using ..BaseExtensions: BaseExtensions
2+
3+
# Some of the interface is inspired by:
14
# https://github.com/invenia/NamedDims.jl
25
# https://github.com/mcabbott/NamedPlus.jl
36

NDTensors/src/lib/NamedDimsArrays/src/broadcast.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
2+
using ..BroadcastMapConversion: map_function, map_args
23

34
struct NamedDimsArrayStyle{N} <: AbstractArrayStyle{N} end
45

NDTensors/src/lib/TensorAlgebra/src/TensorAlgebra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module TensorAlgebra
2+
using ..AlgorithmSelection: Algorithm, @Algorithm_str
23
using LinearAlgebra: mul!
3-
using ..NDTensors: Algorithm, @Algorithm_str
44

55
include("bipartitionedpermutation.jl")
66
include("fusedims.jl")

NDTensors/src/lib/algorithm.jl

Lines changed: 0 additions & 29 deletions
This file was deleted.

NDTensors/test/lib/BaseExtensions.jl

Lines changed: 0 additions & 3 deletions
This file was deleted.

NDTensors/test/lib/DiagonalArrays.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/NamedDimsArrays.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/SetParameters.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/SmallVectors.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/SortedSets.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/TagSets.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/TensorAlgebra.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/Unwrap.jl

Lines changed: 0 additions & 4 deletions
This file was deleted.

NDTensors/test/lib/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
@eval module $(gensym())
22
using Test: @testset
33
@testset "Test NDTensors lib $lib" for lib in [
4+
"AlgorithmSelection",
45
"BaseExtensions",
56
"BlockSparseArrays",
7+
"BroadcastMapConversion",
68
"DiagonalArrays",
79
"NamedDimsArrays",
810
"SetParameters",

NDTensors/test/runtests.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,6 @@ end
2020
@safetestset "NDTensors" begin
2121
@testset "$filename" for filename in [
2222
"lib/runtests.jl",
23-
## "lib/BaseExtensions.jl",
24-
## "lib/BlockSparseArrays.jl",
25-
## "lib/DiagonalArrays.jl",
26-
## "lib/NamedDimsArrays.jl",
27-
## "lib/SetParameters.jl",
28-
## "lib/SmallVectors.jl",
29-
## "lib/SortedSets.jl",
30-
## "lib/TagSets.jl",
31-
## "lib/TensorAlgebra.jl",
32-
## "lib/Unwrap.jl",
3323
"linearalgebra.jl",
3424
"dense.jl",
3525
"blocksparse.jl",

0 commit comments

Comments
 (0)