Skip to content

Function-local safe symbolic Jacobians of numerical functions #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 28, 2019

Conversation

ChrisRackauckas
Copy link
Member

Okay, this is absolutely bonkers. This is the final result. Here's how you'd write a function that would compute the Jacobian of a known ODEProblem symbolically:

using ModelingToolkit, OrdinaryDiffEq

function lotka(du,u,p,t)
  x = u[1]
  y = u[2]
  du[1] = p[1]*x - p[2]*x*y
  du[2] = -p[3]*y + p[4]*x*y
end

prob = ODEProblem(lotka,[1.0,1.0],(0.0,1.0),[1.5,1.0,3.0,1.0])

function calcjac(prob)
  de, vars, params = modelingtoolkitize(prob)
  J = prob.u0*prob.u0'
  ODEFunction(de, vars, params, Val{true}, jac = true).jac(J,prob.u0,prob.p,0.0)
  J
end
calcjac(prob)

All symbolic manipulations and simplifications can thus be applied in here. As demonstrated that, it is world-age safe (function can be called in its callee context) and only has a 5ns overhead over writing down the Jacobian function yourself. A version with safe=Val{false} has no overhead but doesn't allow local evaluation of course due to world-age. A nicer fix might be to completely remove backedges, but that would be digging into Julia's compiler more directly and I will likely talk to someone before continuing down that route.

@codecov
Copy link

codecov bot commented Jul 28, 2019

Codecov Report

Merging #155 into master will decrease coverage by 2.65%.
The diff coverage is 75%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master     #155      +/-   ##
==========================================
- Coverage   95.01%   92.36%   -2.66%     
==========================================
  Files          11       11              
  Lines         401      432      +31     
==========================================
+ Hits          381      399      +18     
- Misses         20       33      +13
Impacted Files Coverage Δ
src/ModelingToolkit.jl 75% <ø> (ø) ⬆️
src/variables.jl 94.73% <100%> (+0.19%) ⬆️
src/utils.jl 92.98% <100%> (-0.47%) ⬇️
src/systems/diffeqs/diffeqsystem.jl 90.08% <67.56%> (-9.92%) ⬇️
src/simplify.jl 92.5% <0%> (-2.5%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 0c833cb...9ab9df3. Read the comment docs.

@ChrisRackauckas ChrisRackauckas merged commit bdcf035 into master Jul 28, 2019
@ChrisRackauckas ChrisRackauckas deleted the builder branch July 28, 2019 03:57
@ChrisRackauckas
Copy link
Member Author

ChrisRackauckas commented Jul 28, 2019

using ModelingToolkit, BenchmarkTools
f(x) = x^4 - 3x^3 + 46*(x-x-x-x)
@variables x

function clean_expr(O)
  if O isa ModelingToolkit.Constant
    return O.value
  elseif isa(O.op, Variable)
    isempty(O.args) && return O.op.name
    return Expr(:call, O.op,name, to_expr.(O.args)...)
  end
  return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end
const to_expr = clean_expr #oops

Base.@pure function differentiate1(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(:($(x.op.name) -> $ex2))
  @show first(methods(_f)).min_world
  @show first(methods(_f)).max_world
  first(methods(_f)).min_world-=10
  first(methods(_f)).max_world=typemax(Int)
  Base.invokelatest(_f,2)
  first(methods(_f)).specializations.min_world-=10
  first(methods(_f)).specializations.max_world=typemax(Int)
  first(methods(_f)).specializations.func.min_world-=10
  first(methods(_f)).specializations.func.max_world=typemax(Int)
  first(methods(_f)).specializations.func.def.min_world-=10
  first(methods(_f)).specializations.func.def.max_world=typemax(Int)
  return _f
end

function differentiate2(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(:($(x.op.name) -> $ex2))
  x -> Base.invokelatest(_f,x)
end

function differentiate3(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(quote
    $(x.op.name) -> begin
    $ex2
    end
  end)
  Base.invokelatest(_f,2)
  x -> superdeadlyunsafe_invokelatest(_f,x)
end

@inline function superdeadlyunsafe_invokelatest(f, args...)
  _f = @cfunction $f Int (Int,)
  return ccall(_f.ptr,Int,(Int,),args...)
end

function failure(f)
  _f = differentiate3(f)
  _f(5)
end
failure(f)

_df1 = differentiate1(f)
_df2 = differentiate2(f)
_df3 = differentiate3(f)
@btime $_df1($6) # 12
@btime $_df2($6) # 12
@btime $_df3($6) # 12























###############################################################
#### Generic version?
###############################################################

using ModelingToolkit, BenchmarkTools
f(x) = x^4 - 3x^3 + 46*(x-x-x-x)
@variables x

function to_expr(O)
  if O isa ModelingToolkit.Constant
    return O.value
  elseif isa(O.op, Variable)
    isempty(O.args) && return O.op.name
    return Expr(:call, O.op,name, to_expr.(O.args)...)
  end
  return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end

function generate(O)
  if isa(O.op, Variable)
    isempty(O.args) && return O.op.name
    return Expr(:call, O.op,name, to_expr.(O.args)...)
  end
  return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end

Base.@pure function differentiate1(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(:($(x.op.name) -> $ex2))
  @show first(methods(_f)).min_world
  @show first(methods(_f)).max_world
  first(methods(_f)).min_world-=10
  first(methods(_f)).max_world=typemax(Int)
  Base.invokelatest(_f,2)
  first(methods(_f)).specializations.min_world-=10
  first(methods(_f)).specializations.max_world=typemax(Int)
  first(methods(_f)).specializations.func.min_world-=10
  first(methods(_f)).specializations.func.max_world=typemax(Int)
  first(methods(_f)).specializations.func.def.min_world-=10
  first(methods(_f)).specializations.func.def.max_world=typemax(Int)
  return _f
end

Base.@pure function differentiate12(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(:($(x.op.name) -> $ex2))
  return _f
end

function differentiate2(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(:($(x.op.name) -> $ex2))
  x -> Base.invokelatest(_f,x)
end

function differentiate3(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(quote
    $(x.op.name) -> begin
    $ex2
    end
  end)
  Base.invokelatest(_f,2)
  rt = Base.return_types(_f,(typeof(2),))
  let rt=first(rt)
    x -> superdeadlyunsafe_invokelatest(_f,rt,x)
  end
end

@inline @generated function superdeadlyunsafe_invokelatest(f, ::Type{T}, args...) where T
  tupargs = Expr(:tuple,args...)
  ex = quote
    _f = $(Expr(:cfunction, Base.CFunction, :f, T, :((Core.svec)($args...)), :(:ccall)))
    return ccall(_f.ptr,T,$tupargs,args...)
  end
end

function failure(f)
  _f = differentiate3(f)
  _f(5)
end
failure(f)

_df1 = differentiate12(f)
_df2 = differentiate2(f)
_df3 = differentiate3(f)
_df2(6)
_df3(6)

@btime $_df1($6) # 12
@btime $_df2($6) # 12
@btime $_df3($6) # 12

function differentiate4(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(quote
    $(x.op.name) -> begin
    $ex2
    end
  end)
  Base.invokelatest(_f,2.0)
  x -> superdeadlyunsafe_invokelatest2(_f,x)
end

@inline @generated function superdeadlyunsafe_invokelatest2(f, args...)
  tupargs = Expr(:tuple,args...)
  ex = quote
    _f = $(Expr(:cfunction, Base.CFunction, :f, Float64, :((Core.svec)($args...)), :(:ccall)))
    return ccall(_f.ptr,Float64,$tupargs,args...)
  end
  Core.println(ex)
  ex
end

_df4 = differentiate4(f)
_df4(6.0)
@btime $_df4($6.0)

Base.@pure super_rt(f,args) = first(Base.return_types(f,args))
function differentiate5(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(quote
    $(x.op.name) -> begin
    $ex2
    end
  end)
  Base.invokelatest(_f,2.0)
  rt = super_rt(_f,(typeof(2.0),))
  let rt=rt
    x -> superdeadlyunsafe_invokelatest4(_f,rt,x)
  end
end

@inline @generated function superdeadlyunsafe_invokelatest4(f, ::Type{rt}, args...) where rt
  tupargs = Expr(:tuple,args...)
  ex = quote
    _f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
    return ccall(_f.ptr,$rt,$tupargs,args...)
  end
  Core.println(ex)
  ex
end

_df5 = differentiate5(f)
_df5(6.0)
@btime $_df5($6.0)

function differentiate6(f,::Type{rt}) where rt
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(quote
    $(x.op.name) -> begin
    $ex2
    end
  end)
  Base.invokelatest(_f,2.0)
  x -> superdeadlyunsafe_invokelatest4(_f,rt,x)
end

@inline @generated function superdeadlyunsafe_invokelatest4(f, ::Type{rt}, args...) where rt
  tupargs = Expr(:tuple,args...)
  ex = quote
    _f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
    return ccall(_f.ptr,$rt,$tupargs,args...)
  end
  Core.println(ex)
  ex
end

_df6 = differentiate6(f,Float64)
_df6(6.0)
@btime $_df6($6.0)

function differentiate7(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(quote
    $(x.op.name) -> begin
    $ex2
    end
  end)
  Base.invokelatest(_f,2.0)
  (rt,x) -> superdeadlyunsafe_invokelatest7(_f,rt,x)
end

@inline @generated function superdeadlyunsafe_invokelatest7(f, ::Type{rt}, args...) where rt
  tupargs = Expr(:tuple,args...)
  quote
    _f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
    return ccall(_f.ptr,$rt,$tupargs,args...)
  end
end

_df7 = differentiate7(f)
_df7(Float64,6.0)
_df7(Int,6)
@btime $_df7(Float64,$6.0)
@btime $_df7($Int,$6)

function failure(f)
  _f = differentiate7(f)
  _f(Int64,6)
end
failure(f)

function rewrap(f,x)
  _f = differentiate7(f)
  x -> _f(typeof(x),x)
end
_df8 = rewrap(f,6)
@btime $_df8($6.0)
@btime $_df8($6)









### Answer

using ModelingToolkit, BenchmarkTools
f(x) = x^4 - 3x^3 + 46*(x-x-x-x)
@variables x

function to_expr(O)
  if O isa ModelingToolkit.Constant
    return O.value
  elseif isa(O.op, Variable)
    isempty(O.args) && return O.op.name
    return Expr(:call, O.op,name, to_expr.(O.args)...)
  end
  return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end

function generate(O)
  if isa(O.op, Variable)
    isempty(O.args) && return O.op.name
    return Expr(:call, O.op,name, to_expr.(O.args)...)
  end
  return Expr(:call, Symbol(O.op), to_expr.(O.args)...)
end

function differentiate1(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  eval(:($(x.op.name) -> $ex2))
end

function _differentiate(f)
  @variables x
  @derivatives D'~x
  op = f(x)
  op2 = expand_derivatives(D(op))
  ex2 = to_expr(op2)
  _f = eval(quote
    $(x.op.name) -> begin
    $ex2
    end
  end)
  (rt,x) -> superdeadlyunsafe_invokelatest(_f,rt,x)
end

@inline @generated function superdeadlyunsafe_invokelatest(f, ::Type{rt}, args...) where rt
  tupargs = Expr(:tuple,args...)
  quote
    _f = $(Expr(:cfunction, Base.CFunction, :f, rt, :((Core.svec)($args...)), :(:ccall)))
    return ccall(_f.ptr,rt,$tupargs,$((:(getindex(args,$i)) for i in 1:length(args))...))
  end
end

function differentiate2(f)
  _f = _differentiate(f)
  x -> _f(typeof(x),x)
end

_df1 = differentiate1(f)
_df2 = differentiate2(f)
_df1(6)
_df2(6)
_df1(6.0)
_df2(6.0)

@btime $_df1($6) # 12
@btime $_df2($6) # 12
@btime $_df1($6.0) # 12
@btime $_df2($6.0) # 12

@ChrisRackauckas
Copy link
Member Author

MWE of world-age tricks:

function builder()
  _f = eval(:(x -> x))
  @show first(methods(_f)).min_world
  @show first(methods(_f)).max_world
  first(methods(_f)).min_world=0
  first(methods(_f)).max_world=typemax(Int)
  Base.invokelatest(_f,2)
  first(methods(_f)).specializations.min_world=0
  first(methods(_f)).specializations.max_world=typemax(Int)
  first(methods(_f)).specializations.func.min_world=0
  first(methods(_f)).specializations.func.max_world=typemax(Int)
  first(methods(_f)).specializations.func.def.min_world=0
  first(methods(_f)).specializations.func.def.max_world=typemax(Int)
  first(methods(_f)).specializations.func.backedges = nothing
  return _f
end

function worldage_failure()
  _f = builder()
  _f(5)
end
worldage_failure()

@cscherrer
Copy link

@ChrisRackauckas this is great! So you're finding the Jacobian symbolically through ModelingToolkit, not using Zygote or other AD? Do you see this approach generalizing to probabilistic modeling applications outside the DiffEq space?

@ChrisRackauckas
Copy link
Member Author

Yes, it's symbolic and not AD. You can do this on any numerical function, so not just DiffEq.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants