Skip to content

WIP: Supporting functions for array symbolics #123

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/abstractalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,16 @@ let
@rule(zero(~x) => 0)
@rule(one(~x) => 1)]

simterm(x, f, args;metadata=nothing) = similarterm(x,f,args, symtype(x); metadata=metadata)
mpoly_rules = [@rule(~x::ismpoly - ~y::ismpoly => ~x + -1 * (~y))
@rule(-(~x) => -1 * ~x)
@acrule(~x::ismpoly + ~y::ismpoly => ~x + ~y)
@rule(+(~x) => ~x)
@acrule(~x::ismpoly * ~y::ismpoly => ~x * ~y)
@rule(*(~x) => ~x)
@rule((~x::ismpoly)^(~a::isnonnegint) => (~x)^(~a))]
global const MPOLY_CLEANUP = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_preprocess))))
MPOLY_MAKER = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_rules))))
global const MPOLY_CLEANUP = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_preprocess)), similarterm=simterm))
MPOLY_MAKER = Fixpoint(Postwalk(PassThrough(RestartedChain(mpoly_rules)), similarterm=simterm))

global to_mpoly
function to_mpoly(t, dicts=_dicts())
Expand Down
2 changes: 1 addition & 1 deletion src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function substitute(expr, dict; fold=true)
else
args = map(x->substitute(x, dict), arguments(expr))
end
similarterm(expr, operation(expr), args, metadata=metadata(expr))
similarterm(expr, operation(expr), args, symtype(expr), metadata=metadata(expr))
else
expr
end
Expand Down
5 changes: 5 additions & 0 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,8 @@ promote_symtype(::typeof(ifelse), _, ::Type{T}, ::Type{S}) where {T,S} = Union{T
# Specially handle inv and literal pow
Base.inv(x::Symbolic{<:Number}) = Base.:^(x, -1)
Base.literal_pow(::typeof(^), x::Symbolic{<:Number}, ::Val{p}) where {p} = Base.:^(x, p)

# Array-like operations
Base.size(x::Symbolic{<:Number}) = ()
Base.length(x::Symbolic{<:Number}) = 1
Base.ndims(x::Symbolic{<:Number}) = 0
8 changes: 5 additions & 3 deletions src/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,22 @@ end

<ₑ(a::Sym, b::Sym) = a.name < b.name

<ₑ(a::Function, b::Function) = nameof(a) <ₑ nameof(b)

function cmp_term_term(a, b)
la = arglength(a)
lb = arglength(b)

if la == 0 && lb == 0
return nameof(operation(a)) <ₑ nameof(operation(b))
return operation(a) <ₑ operation(b)
elseif la === 0
return operation(a) <ₑ b
elseif lb === 0
return a <ₑ operation(b)
end

na = nameof(operation(a))
nb = nameof(operation(b))
na = operation(a)
nb = operation(b)

if 0 < arglength(a) <= 2 && 0 < arglength(b) <= 2
# e.g. a < sin(a) < b ^ 2 < b
Expand Down
6 changes: 4 additions & 2 deletions src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ function makepattern(expr, keys)
else
:(term($(map(x->makepattern(x, keys), expr.args)...); type=Any))
end
elseif expr.head === :ref
:(term(getindex, $(map(x->makepattern(x, keys), expr.args)...); type=Any))
elseif expr.head === :$
return esc(expr.args[1])
else
error("Unsupported Expr of type $(expr.head) found in pattern")
Expr(expr.head, makepattern.(expr.args, (keys,))...)
end
else
# treat as a literal
Expand Down Expand Up @@ -327,7 +329,7 @@ function (acr::ACRule)(term)
if !isnothing(result)
# Assumption: inds are unique
length(args) == length(inds) && return result
return similarterm(term, f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...])
return similarterm(term, f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], symtype(term))
end
end
end
Expand Down
43 changes: 34 additions & 9 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ Base.show(io::IO, v::Sym) = Base.show_unquoted(io, v.name)
# Maybe don't even need a new type, can just use Sym{FnType}
struct FnType{X<:Tuple,Y} end

(f::Sym{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, [args...])
(f::Symbolic{<:FnType})(args...) = Term{promote_symtype(f, symtype.(args)...)}(f, [args...])

function (f::Sym)(args...)
function (f::Symbolic)(args...)
error("Sym $f is not callable. " *
"Use @syms $f(var1, var2,...) to create it as a callable. " *
"See ?@fun for more options")
Expand All @@ -210,7 +210,7 @@ end
The output symtype of applying variable `f` to arugments of symtype `arg_symtypes...`.
if the arguments are of the wrong type then this function will error.
"""
function promote_symtype(f::Sym{FnType{X,Y}}, args...) where {X, Y}
function promote_symtype(f::Symbolic{FnType{X,Y}}, args...) where {X, Y}
if X === Tuple
return Y
end
Expand Down Expand Up @@ -250,8 +250,10 @@ macro syms(xs...)
defs = map(xs) do x
n, t = _name_type(x)
:($(esc(n)) = Sym{$(esc(t))}($(Expr(:quote, n))))
nt = _name_type(x)
n, t = nt.name, nt.type
:($(esc(n)) = Sym{$(esc(t))}($(Expr(:quote, n))))
end

Expr(:block, defs...,
:(tuple($(map(x->esc(_name_type(x).name), xs)...))))
end
Expand All @@ -275,14 +277,20 @@ function _name_type(x)
else
return (name=lhs, type=rhs)
end
elseif x isa Expr && x.head === :ref
ntype = _name_type(x.args[1]) # a::Number
N = length(x.args)-1
return (name=ntype.name,
type=:(Array{$(ntype.type), $N}),
array_metadata=:(Base.Slice.(($(x.args[2:end]...),))))
elseif x isa Expr && x.head === :call
return _name_type(:($x::Number))
else
syms_syntax_error()
end
end

function Base.show(io::IO, f::Sym{<:FnType{X,Y}}) where {X,Y}
function Base.show(io::IO, f::Symbolic{<:FnType{X,Y}}) where {X,Y}
print(io, f.name)
# Use `Base.unwrap_unionall` to handle `Tuple{T} where T`. This is not the
# best printing, but it's better than erroring.
Expand Down Expand Up @@ -433,7 +441,7 @@ setargs(t, args) = Term{symtype(t)}(operation(t), args)
cdrargs(args) = setargs(t, cdr(args))

print_arg(io, x::Union{Complex, Rational}; paren=true) = print(io, "(", x, ")")
isbinop(f) = istree(f) && Base.isbinaryoperator(nameof(operation(f)))
isbinop(f) = istree(f) && !istree(operation(f)) && Base.isbinaryoperator(nameof(operation(f)))
function print_arg(io, x; paren=false)
if paren && isbinop(x)
print(io, "(", x, ")")
Expand Down Expand Up @@ -506,8 +514,23 @@ function show_mul(io, args)
end
end

function show_ref(io, f, args)
x = args[1]
idx = args[2:end]

istree(x) && print(io, "(")
print(io, x)
istree(x) && print(io, ")")
print(io, "[")
for i=1:length(idx)
print_arg(io, idx[i])
i != length(idx) && print(io, ", ")
end
print(io, "]")
end

function show_call(io, f, args)
fname = nameof(f)
fname = istree(f) ? Symbol(repr(f)) : nameof(f)
binary = Base.isbinaryoperator(fname)
if binary
for (i, t) in enumerate(args)
Expand Down Expand Up @@ -543,6 +566,8 @@ function show_term(io::IO, t)
show_mul(io, args)
elseif f === (^)
show_pow(io, args)
elseif f === (getindex)
show_ref(io, f, args)
else
show_call(io, f, args)
end
Expand Down Expand Up @@ -573,7 +598,7 @@ where `coeff` and the vals are `<:Number` and keys are symbolic.
- `arguments(::Add)` -- returns a totally ordered vector of arguments. i.e.
`[coeff, keyM*valM, keyN*valN...]`
"""
struct Add{X, T<:Number, D, M} <: Symbolic{X}
struct Add{X<:Number, T<:Number, D, M} <: Symbolic{X}
coeff::T
dict::D
sorted_args_cache::Ref{Any}
Expand Down Expand Up @@ -699,7 +724,7 @@ where `coeff` and the vals are `<:Number` and keys are symbolic.
- `arguments(::Mul)` -- returns a totally ordered vector of arguments. i.e.
`[coeff, keyM^valM, keyN^valN...]`
"""
struct Mul{X, T<:Number, D, M} <: Symbolic{X}
struct Mul{X<:Number, T<:Number, D, M} <: Symbolic{X}
coeff::T
dict::D
sorted_args_cache::Ref{Any}
Expand Down