-
-
Notifications
You must be signed in to change notification settings - Fork 226
Function-local safe symbolic Jacobians of numerical functions #155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
@@ Coverage Diff @@
## master #155 +/- ##
==========================================
- Coverage 95.01% 92.36% -2.66%
==========================================
Files 11 11
Lines 401 432 +31
==========================================
+ Hits 381 399 +18
- Misses 20 33 +13
Continue to review full report at Codecov.
|
using ModelingToolkit, BenchmarkTools
f(x) = x^4 - 3x^3 + 46*(x-x-x-x)
@variables x
function clean_expr(O)
if O isa ModelingToolkit.Constant
return O.value
elseif isa(O.op, Variable)
isempty(O.args) && return O.op.name
return Expr(:call, O.op,name, to_expr.(O.args)...)
end
return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end
const to_expr = clean_expr #oops
Base.@pure function differentiate1(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(:($(x.op.name) -> $ex2))
@show first(methods(_f)).min_world
@show first(methods(_f)).max_world
first(methods(_f)).min_world-=10
first(methods(_f)).max_world=typemax(Int)
Base.invokelatest(_f,2)
first(methods(_f)).specializations.min_world-=10
first(methods(_f)).specializations.max_world=typemax(Int)
first(methods(_f)).specializations.func.min_world-=10
first(methods(_f)).specializations.func.max_world=typemax(Int)
first(methods(_f)).specializations.func.def.min_world-=10
first(methods(_f)).specializations.func.def.max_world=typemax(Int)
return _f
end
function differentiate2(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(:($(x.op.name) -> $ex2))
x -> Base.invokelatest(_f,x)
end
function differentiate3(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(quote
$(x.op.name) -> begin
$ex2
end
end)
Base.invokelatest(_f,2)
x -> superdeadlyunsafe_invokelatest(_f,x)
end
@inline function superdeadlyunsafe_invokelatest(f, args...)
_f = @cfunction $f Int (Int,)
return ccall(_f.ptr,Int,(Int,),args...)
end
function failure(f)
_f = differentiate3(f)
_f(5)
end
failure(f)
_df1 = differentiate1(f)
_df2 = differentiate2(f)
_df3 = differentiate3(f)
@btime $_df1($6) # 12
@btime $_df2($6) # 12
@btime $_df3($6) # 12
###############################################################
#### Generic version?
###############################################################
using ModelingToolkit, BenchmarkTools
f(x) = x^4 - 3x^3 + 46*(x-x-x-x)
@variables x
function to_expr(O)
if O isa ModelingToolkit.Constant
return O.value
elseif isa(O.op, Variable)
isempty(O.args) && return O.op.name
return Expr(:call, O.op,name, to_expr.(O.args)...)
end
return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end
function generate(O)
if isa(O.op, Variable)
isempty(O.args) && return O.op.name
return Expr(:call, O.op,name, to_expr.(O.args)...)
end
return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end
Base.@pure function differentiate1(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(:($(x.op.name) -> $ex2))
@show first(methods(_f)).min_world
@show first(methods(_f)).max_world
first(methods(_f)).min_world-=10
first(methods(_f)).max_world=typemax(Int)
Base.invokelatest(_f,2)
first(methods(_f)).specializations.min_world-=10
first(methods(_f)).specializations.max_world=typemax(Int)
first(methods(_f)).specializations.func.min_world-=10
first(methods(_f)).specializations.func.max_world=typemax(Int)
first(methods(_f)).specializations.func.def.min_world-=10
first(methods(_f)).specializations.func.def.max_world=typemax(Int)
return _f
end
Base.@pure function differentiate12(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(:($(x.op.name) -> $ex2))
return _f
end
function differentiate2(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(:($(x.op.name) -> $ex2))
x -> Base.invokelatest(_f,x)
end
function differentiate3(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(quote
$(x.op.name) -> begin
$ex2
end
end)
Base.invokelatest(_f,2)
rt = Base.return_types(_f,(typeof(2),))
let rt=first(rt)
x -> superdeadlyunsafe_invokelatest(_f,rt,x)
end
end
@inline @generated function superdeadlyunsafe_invokelatest(f, ::Type{T}, args...) where T
tupargs = Expr(:tuple,args...)
ex = quote
_f = $(Expr(:cfunction, Base.CFunction, :f, T, :((Core.svec)($args...)), :(:ccall)))
return ccall(_f.ptr,T,$tupargs,args...)
end
end
function failure(f)
_f = differentiate3(f)
_f(5)
end
failure(f)
_df1 = differentiate12(f)
_df2 = differentiate2(f)
_df3 = differentiate3(f)
_df2(6)
_df3(6)
@btime $_df1($6) # 12
@btime $_df2($6) # 12
@btime $_df3($6) # 12
function differentiate4(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(quote
$(x.op.name) -> begin
$ex2
end
end)
Base.invokelatest(_f,2.0)
x -> superdeadlyunsafe_invokelatest2(_f,x)
end
@inline @generated function superdeadlyunsafe_invokelatest2(f, args...)
tupargs = Expr(:tuple,args...)
ex = quote
_f = $(Expr(:cfunction, Base.CFunction, :f, Float64, :((Core.svec)($args...)), :(:ccall)))
return ccall(_f.ptr,Float64,$tupargs,args...)
end
Core.println(ex)
ex
end
_df4 = differentiate4(f)
_df4(6.0)
@btime $_df4($6.0)
Base.@pure super_rt(f,args) = first(Base.return_types(f,args))
function differentiate5(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(quote
$(x.op.name) -> begin
$ex2
end
end)
Base.invokelatest(_f,2.0)
rt = super_rt(_f,(typeof(2.0),))
let rt=rt
x -> superdeadlyunsafe_invokelatest4(_f,rt,x)
end
end
@inline @generated function superdeadlyunsafe_invokelatest4(f, ::Type{rt}, args...) where rt
tupargs = Expr(:tuple,args...)
ex = quote
_f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
return ccall(_f.ptr,$rt,$tupargs,args...)
end
Core.println(ex)
ex
end
_df5 = differentiate5(f)
_df5(6.0)
@btime $_df5($6.0)
function differentiate6(f,::Type{rt}) where rt
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(quote
$(x.op.name) -> begin
$ex2
end
end)
Base.invokelatest(_f,2.0)
x -> superdeadlyunsafe_invokelatest4(_f,rt,x)
end
@inline @generated function superdeadlyunsafe_invokelatest4(f, ::Type{rt}, args...) where rt
tupargs = Expr(:tuple,args...)
ex = quote
_f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
return ccall(_f.ptr,$rt,$tupargs,args...)
end
Core.println(ex)
ex
end
_df6 = differentiate6(f,Float64)
_df6(6.0)
@btime $_df6($6.0)
function differentiate7(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(quote
$(x.op.name) -> begin
$ex2
end
end)
Base.invokelatest(_f,2.0)
(rt,x) -> superdeadlyunsafe_invokelatest7(_f,rt,x)
end
@inline @generated function superdeadlyunsafe_invokelatest7(f, ::Type{rt}, args...) where rt
tupargs = Expr(:tuple,args...)
quote
_f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
return ccall(_f.ptr,$rt,$tupargs,args...)
end
end
_df7 = differentiate7(f)
_df7(Float64,6.0)
_df7(Int,6)
@btime $_df7(Float64,$6.0)
@btime $_df7($Int,$6)
function failure(f)
_f = differentiate7(f)
_f(Int64,6)
end
failure(f)
function rewrap(f,x)
_f = differentiate7(f)
x -> _f(typeof(x),x)
end
_df8 = rewrap(f,6)
@btime $_df8($6.0)
@btime $_df8($6)
### Answer
using ModelingToolkit, BenchmarkTools
f(x) = x^4 - 3x^3 + 46*(x-x-x-x)
@variables x
function to_expr(O)
if O isa ModelingToolkit.Constant
return O.value
elseif isa(O.op, Variable)
isempty(O.args) && return O.op.name
return Expr(:call, O.op,name, to_expr.(O.args)...)
end
return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end
function generate(O)
if isa(O.op, Variable)
isempty(O.args) && return O.op.name
return Expr(:call, O.op,name, to_expr.(O.args)...)
end
return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end
function differentiate1(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
eval(:($(x.op.name) -> $ex2))
end
function _differentiate(f)
@variables x
@derivatives D'~x
op = f(x)
op2 = expand_derivatives(D(op))
ex2 = to_expr(op2)
_f = eval(quote
$(x.op.name) -> begin
$ex2
end
end)
(rt,x) -> superdeadlyunsafe_invokelatest(_f,rt,x)
end
@inline @generated function superdeadlyunsafe_invokelatest(f, ::Type{rt}, args...) where rt
tupargs = Expr(:tuple,args...)
quote
_f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
return ccall(_f.ptr,rt,$tupargs,$((:(getindex(args,$i)) for i in 1:length(args))...))
end
end
function differentiate2(f)
_f = _differentiate(f)
x -> _f(typeof(x),x)
end
_df1 = differentiate1(f)
_df2 = differentiate2(f)
_df1(6)
_df2(6)
_df1(6.0)
_df2(6.0)
@btime $_df1($6) # 12
@btime $_df2($6) # 12
@btime $_df1($6.0) # 12
@btime $_df2($6.0) # 12 |
MWE of world-age tricks: function builder()
_f = eval(:(x -> x))
@show first(methods(_f)).min_world
@show first(methods(_f)).max_world
first(methods(_f)).min_world=0
first(methods(_f)).max_world=typemax(Int)
Base.invokelatest(_f,2)
first(methods(_f)).specializations.min_world=0
first(methods(_f)).specializations.max_world=typemax(Int)
first(methods(_f)).specializations.func.min_world=0
first(methods(_f)).specializations.func.max_world=typemax(Int)
first(methods(_f)).specializations.func.def.min_world=0
first(methods(_f)).specializations.func.def.max_world=typemax(Int)
first(methods(_f)).specializations.func.backedges = nothing
return _f
end
function worldage_failure()
_f = builder()
_f(5)
end
worldage_failure() |
@ChrisRackauckas this is great! So you're finding the Jacobian symbolically through ModelingToolkit, not using Zygote or other AD? Do you see this approach generalizing to probabilistic modeling applications outside the DiffEq space? |
Yes, it's symbolic and not AD. You can do this on any numerical function, so not just DiffEq. |
Okay, this is absolutely bonkers. This is the final result. Here's how you'd write a function that would compute the Jacobian of a known ODEProblem symbolically:
All symbolic manipulations and simplifications can thus be applied in here. As demonstrated that, it is world-age safe (function can be called in its callee context) and only has a 5ns overhead over writing down the Jacobian function yourself. A version with
safe=Val{false}
has no overhead but doesn't allow local evaluation of course due to world-age. A nicer fix might be to completely remove backedges, but that would be digging into Julia's compiler more directly and I will likely talk to someone before continuing down that route.