Skip to content

Drop expr body in RGF #2166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ MacroTools = "0.5"
NaNMath = "0.3, 1"
RecursiveArrayTools = "2.3"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.4.3, 0.5"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "1.76.1"
Setfield = "0.7, 0.8, 1"
SimpleNonlinearSolve = "0.1.0"
Expand Down
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ using JumpProcesses
using DataStructures
using SpecialFunctions, NaNMath
using RuntimeGeneratedFunctions
using RuntimeGeneratedFunctions: drop_expr
using Base.Threads
using DiffEqCallbacks
using Graphs
Expand Down
3 changes: 2 additions & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ using ModelingToolkit.SystemStructures: algeqs, EquationsView

using ModelingToolkit.DiffEqBase
using ModelingToolkit.StaticArrays
using ModelingToolkit: @RuntimeGeneratedFunction, RuntimeGeneratedFunctions
using RuntimeGeneratedFunctions: @RuntimeGeneratedFunction, RuntimeGeneratedFunctions,
drop_expr

RuntimeGeneratedFunctions.init(@__MODULE__)

Expand Down
6 changes: 3 additions & 3 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ function gen_nlsolve!(is_not_prepended_assignment, eqs, vars, u0map::AbstractDic
end

nlsolve_expr = Assignment[preassignments
fname ← @RuntimeGeneratedFunction(f)
fname ← drop_expr(@RuntimeGeneratedFunction(f))
DestructuredArgs(vars, inbounds = !checkbounds) ← solver_call]

nlsolve_expr
Expand Down Expand Up @@ -345,7 +345,7 @@ function build_torn_function(sys;
end
end

ODEFunction{true, SciMLBase.AutoSpecialize}(@RuntimeGeneratedFunction(expr),
ODEFunction{true, SciMLBase.AutoSpecialize}(drop_expr(@RuntimeGeneratedFunction(expr)),
sparsity = jacobian_sparsity ?
torn_system_with_nlsolve_jacobian_sparsity(state,
var_eq_matching,
Expand Down Expand Up @@ -501,7 +501,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
isscalar ? ts[1] : MakeArray(ts, output_type),
false))), sol_states)

expression ? ex : @RuntimeGeneratedFunction(ex)
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
end

"""
Expand Down
3 changes: 2 additions & 1 deletion src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,8 @@ function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; outputidxs = nothin
# applied user-provided function to the generated expression
if postprocess_affect_expr! !== nothing
postprocess_affect_expr!(rf_ip, integ)
(expression == Val{false}) && (return @RuntimeGeneratedFunction(rf_ip))
(expression == Val{false}) &&
(return drop_expr(@RuntimeGeneratedFunction(rf_ip)))
end
rf_ip
end
Expand Down
2 changes: 1 addition & 1 deletion src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
end
if eval_expression
affects = map(affect_funs) do a
@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a)))
drop_expr(@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a))))
end
else
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
Expand Down
12 changes: 7 additions & 5 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_expression ?
(@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)

Expand All @@ -299,7 +300,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
expression_module = eval_module,
checkbounds = checkbounds, kwargs...)
tgrad_oop, tgrad_iip = eval_expression ?
(@RuntimeGeneratedFunction(eval_module, ex) for ex in tgrad_gen) :
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in tgrad_gen) :
tgrad_gen
_tgrad(u, p, t) = tgrad_oop(u, p, t)
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
Expand All @@ -314,7 +315,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
expression_module = eval_module,
checkbounds = checkbounds, kwargs...)
jac_oop, jac_iip = eval_expression ?
(@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) :
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
jac_gen
_jac(u, p, t) = jac_oop(u, p, t)
_jac(J, u, p, t) = jac_iip(J, u, p, t)
Expand Down Expand Up @@ -423,7 +424,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_expression ?
(@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f(du, u, p, t) = f_oop(du, u, p, t)
f(out, du, u, p, t) = f_iip(out, du, u, p, t)

Expand All @@ -434,7 +436,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
expression_module = eval_module,
checkbounds = checkbounds, kwargs...)
jac_oop, jac_iip = eval_expression ?
(@RuntimeGeneratedFunction(eval_module, ex) for ex in jac_gen) :
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in jac_gen) :
jac_gen
_jac(du, u, p, ˍ₋gamma, t) = jac_oop(du, u, p, ˍ₋gamma, t)

Expand Down
2 changes: 1 addition & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ function build_explicit_observed_function(sys, ts;
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
false))) |> toexpr
expression ? ex : @RuntimeGeneratedFunction(ex)
expression ? ex : drop_expr(@RuntimeGeneratedFunction(ex))
end

function _eq_unordered(a, b)
Expand Down
15 changes: 9 additions & 6 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -385,10 +385,12 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
ps = scalarize.(ps)

f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...)
f_oop, f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in f_gen) : f_gen
g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{eval_expression},
kwargs...)
g_oop, g_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in g_gen) : g_gen
g_oop, g_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in g_gen) : g_gen

f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)
Expand All @@ -399,7 +401,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
tgrad_gen = generate_tgrad(sys, dvs, ps; expression = Val{eval_expression},
kwargs...)
tgrad_oop, tgrad_iip = eval_expression ?
(@RuntimeGeneratedFunction(ex) for ex in tgrad_gen) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tgrad_gen) :
tgrad_gen
_tgrad(u, p, t) = tgrad_oop(u, p, t)
_tgrad(J, u, p, t) = tgrad_iip(J, u, p, t)
Expand All @@ -411,7 +413,8 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
jac_gen = generate_jacobian(sys, dvs, ps; expression = Val{eval_expression},
sparse = sparse, kwargs...)
jac_oop, jac_iip = eval_expression ?
(@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in jac_gen) :
jac_gen
_jac(u, p, t) = jac_oop(u, p, t)
_jac(J, u, p, t) = jac_iip(J, u, p, t)
else
Expand All @@ -422,10 +425,10 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys),
tmp_Wfact, tmp_Wfact_t = generate_factorized_W(sys, dvs, ps, true;
expression = Val{true}, kwargs...)
Wfact_oop, Wfact_iip = eval_expression ?
(@RuntimeGeneratedFunction(ex) for ex in tmp_Wfact) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact) :
tmp_Wfact
Wfact_oop_t, Wfact_iip_t = eval_expression ?
(@RuntimeGeneratedFunction(ex) for ex in tmp_Wfact_t) :
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in tmp_Wfact_t) :
tmp_Wfact_t
_Wfact(u, p, dtgamma, t) = Wfact_oop(u, p, dtgamma, t)
_Wfact(W, u, p, dtgamma, t) = Wfact_iip(W, u, p, dtgamma, t)
Expand Down
5 changes: 3 additions & 2 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ function SciMLBase.DiscreteProblem(sys::DiscreteSystem, u0map = [], tspan = get_

f_gen = generate_function(sys; expression = Val{eval_expression},
expression_module = eval_module)
f_oop, _ = (@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen)
f_oop, _ = (drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen)
f(u, p, iv) = f_oop(u, p, iv)
fd = DiscreteFunction(f; syms = Symbol.(dvs), indepsym = Symbol(iv),
paramsyms = Symbol.(ps), sys = sys)
Expand Down Expand Up @@ -339,7 +339,8 @@ function SciMLBase.DiscreteFunction{iip, specialize}(sys::DiscreteSystem,
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
expression_module = eval_module, kwargs...)
f_oop, f_iip = eval_expression ?
(@RuntimeGeneratedFunction(eval_module, ex) for ex in f_gen) : f_gen
(drop_expr(@RuntimeGeneratedFunction(eval_module, ex)) for ex in f_gen) :
f_gen
f(u, p, t) = f_oop(u, p, t)
f(du, u, p, t) = f_iip(du, u, p, t)

Expand Down
12 changes: 6 additions & 6 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,11 @@ function generate_affect_function(js::JumpSystem, affect, outputidxs)
end

function assemble_vrj(js, vrj, statetoid)
rate = @RuntimeGeneratedFunction(generate_rate_function(js, vrj.rate))
rate = drop_expr(@RuntimeGeneratedFunction(generate_rate_function(js, vrj.rate)))
outputvars = (value(affect.lhs) for affect in vrj.affect!)
outputidxs = [statetoid[var] for var in outputvars]
affect = @RuntimeGeneratedFunction(generate_affect_function(js, vrj.affect!,
outputidxs))
affect = drop_expr(@RuntimeGeneratedFunction(generate_affect_function(js, vrj.affect!,
outputidxs)))
VariableRateJump(rate, affect)
end

Expand All @@ -217,11 +217,11 @@ function assemble_vrj_expr(js, vrj, statetoid)
end

function assemble_crj(js, crj, statetoid)
rate = @RuntimeGeneratedFunction(generate_rate_function(js, crj.rate))
rate = drop_expr(@RuntimeGeneratedFunction(generate_rate_function(js, crj.rate)))
outputvars = (value(affect.lhs) for affect in crj.affect!)
outputidxs = [statetoid[var] for var in outputvars]
affect = @RuntimeGeneratedFunction(generate_affect_function(js, crj.affect!,
outputidxs))
affect = drop_expr(@RuntimeGeneratedFunction(generate_affect_function(js, crj.affect!,
outputidxs)))
ConstantRateJump(rate, affect)
end

Expand Down
6 changes: 4 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys
sparse = false, simplify = false,
kwargs...) where {iip}
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, kwargs...)
f_oop, f_iip = eval_expression ? (@RuntimeGeneratedFunction(ex) for ex in f_gen) : f_gen
f_oop, f_iip = eval_expression ?
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in f_gen) : f_gen
f(u, p) = f_oop(u, p)
f(du, u, p) = f_iip(du, u, p)

Expand All @@ -237,7 +238,8 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys
simplify = simplify, sparse = sparse,
expression = Val{eval_expression}, kwargs...)
jac_oop, jac_iip = eval_expression ?
(@RuntimeGeneratedFunction(ex) for ex in jac_gen) : jac_gen
(drop_expr(@RuntimeGeneratedFunction(ex)) for ex in jac_gen) :
jac_gen
_jac(u, p) = jac_oop(u, p)
_jac(J, u, p) = jac_iip(J, u, p)
else
Expand Down
2 changes: 1 addition & 1 deletion src/systems/pde/pdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ struct PDESystem <: ModelingToolkit.AbstractMultivariateSystem
p = ps isa SciMLBase.NullParameters ? [] : map(a -> a.first, ps)
args = vcat(DestructuredArgs(p), args)
ex = Func(args, [], eq.rhs) |> toexpr
eq.lhs => @RuntimeGeneratedFunction(ex)
eq.lhs => drop_expr(@RuntimeGeneratedFunction(ex))
end
end
end
Expand Down