The key words to look up here are “Wirtinger derivatives”. AFAIK you may have to wait until various packages use ChainRules.jl before this works really nicely. (I thought Zygote already supported complex derivatives to some extent; maybe I got the wrong impression about that.)
You don’t actually need Wirtinger derivatives for complex functions with real input. That case is pretty straightforward with forward-mode AD, so implementing it in ForwardDiff.jl shouldn’t be that difficult. Because Zygote.jl uses reverse-mode AD, it is much better suited to differentiating functions with complex input but real output. First-class complex differentiation support in ChainRules is still very WIP and probably won’t be part of v1.0, since there are many challenges in supporting it for both forward- and reverse-mode AD. This is the PR working on this in the underlying ChainRulesCore.jl, and quite a bit of ChainRules.jl will have to be changed accordingly as well.
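To spell out why the real-input case is simple, here is a short sketch. The names u and v are introduced here only as labels for the real and imaginary parts; they do not appear elsewhere in the thread.

```latex
f(t) = u(t) + i\,v(t), \quad t \in \mathbb{R}
\quad\Longrightarrow\quad
f'(t) = u'(t) + i\,v'(t),
\qquad
\frac{d}{dt}\,\operatorname{Re} f(t) = \operatorname{Re} f'(t) = u'(t).
```

Since t is real, the derivative is an ordinary limit along the real line, so differentiating the real and imaginary parts separately is enough, and the derivative of the real part is always real.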
My use case is rather simple. I have a real function of a real argument whose derivative I need to calculate. Several function layers deep inside that function I have a complex-valued function of a real argument. Further up the chain I take the real part. Despite this, the fact that the nested function is complex-valued prevents ForwardDiff and Zygote from working. Until recently that deeply nested function was real-valued and everything worked fine. I was hoping to find the minimum change to my code/packages to get this working. In principle I could separate the real and imaginary parts analytically, but I would prefer to avoid that, as I would need to do it for every case.
If it helps, the “deeply nested” complex valued function of real argument I described above knows how to calculate its own derivative, so this can be provided. But I do not understand how to pass that information to e.g. ForwardDiff (so that it doesn’t need to look inside, but can just take it as an opaque function with derivative).
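One way this is sometimes done is to add a method for `ForwardDiff.Dual` inputs that builds the result from the known derivative. The sketch below is only an illustration under assumptions: `h` and `dh` are hypothetical stand-ins for the nested function and its derivative, and the overload relies on `ForwardDiff.value`, `ForwardDiff.partials`, and the `Dual{T}(value, partials)` constructor.

```julia
using ForwardDiff
using ForwardDiff: Dual, value, partials

# Hypothetical nested function R -> C and its hand-written derivative.
h(t::Real) = cis(2t)           # exp(2im * t)
dh(t::Real) = 2im * cis(2t)

# Overload for dual numbers: ForwardDiff never looks inside h,
# it only sees the value and the supplied derivative.
function h(d::Dual{T}) where {T}
    t = value(d)
    v, dv = h(t), dh(t)
    # Propagate the real and imaginary parts as two separate duals.
    return Complex(Dual{T}(real(v), real(dv) * partials(d)),
                   Dual{T}(imag(v), imag(dv) * partials(d)))
end

# Outer real function of a real argument that takes the real part further up.
f(t) = real(h(t) * (1 - im))

df = ForwardDiff.derivative(f, 0.3)
```

Here `df` can be checked against the analytic value `2cos(0.6) - 2sin(0.6)`. The design choice is that a complex number of duals, rather than a dual of complex numbers, is returned, since `Dual` only wraps real values.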
Zygote.jl should definitely be able to handle R -> R functions with intermediary complex functions. Have you actually tried it on your whole function, not just the part that is R -> C?
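As a small check of that claim, something like the following could be tried. It is a sketch with hypothetical functions `inner` and `outer`, not the code from this thread. Depending on the Zygote version, the gradient with respect to a real input may come back as a complex number, so the real part is taken explicitly as a precaution.

```julia
using Zygote

inner(t) = exp(im * t)          # hypothetical R -> C intermediate
outer(t) = real(inner(t)^2)     # R -> R overall, equal to cos(2t)

# Gradient with respect to the real input; real(...) guards against
# versions that return an unprojected complex value.
dt = real(Zygote.gradient(outer, 0.3)[1])
```

The result can be compared with the analytic derivative `-2sin(0.6)`.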
using LinearAlgebra

struct CSpline{Tx,Ty}
    x::AbstractArray{Tx,1}
    y::AbstractArray{Ty,1}
    D::AbstractArray{Ty,1}
end

# make broadcast like a scalar
Broadcast.broadcastable(c::CSpline) = Ref(c)

function CSpline(x, y)
    # right-hand side of the tridiagonal system for the knot derivatives
    R = similar(y)
    R[1] = y[2] - y[1]
    for i in 2:(length(y)-1)
        R[i] = y[i+1] - y[i-1]
    end
    R[end] = y[end] - y[end - 1]
    @. R *= 3
    # tridiagonal matrix: diagonal 4 (2 at both ends), off-diagonals 1
    d = fill(4.0, size(y))
    d[1] = 2.0
    d[end] = 2.0
    dl = fill(1.0, length(y) - 1)
    M = LinearAlgebra.Tridiagonal(dl, d, dl)
    D = M \ R
    CSpline(x, y, D)
end

function (c::CSpline)(x0)
    # pick the interval; points outside the grid use the first or last interval
    if x0 <= c.x[1]
        i = 2
    elseif x0 >= c.x[end]
        i = length(c.x)
    else
        i = findfirst(x0 .< c.x)
    end
    t = (x0 - c.x[i - 1])/(c.x[i] - c.x[i - 1])
    c.y[i - 1] + c.D[i - 1]*t + (3*(c.y[i] - c.y[i - 1]) - 2*c.D[i - 1] - c.D[i])*t^2 + (2*(c.y[i - 1] - c.y[i]) + c.D[i - 1] + c.D[i])*t^3
end
and then I create g with (note this is an artificial test)
A_x = [1.0, 1.7, 8.0, 9.7, 10.3, 12.5, 32]
A = [4.0, 2.7, 1.0, 0.7, 4.3, 17.4, 43]
g = CSpline(A_x, A .+ im.*A./3.0)
But I believe the answer in my previous post is also incorrect: the derivative of the real function should be real, but it provided the full complex derivative of the nested function. Or am I missing some basic mathematics here?
Your function mutates arrays, which Zygote can’t differentiate through directly. You could take a look at Zygote.Buffer and see if that works for your case. The best solution might also be to just implement your own custom adjoint. BTW, your struct CSpline still has abstractly typed fields, so the compiler can’t infer their concrete types and the struct can’t be stack-allocated. Try
struct CSpline{Tx,Ty,Vx<:AbstractVector{Tx},Vy<:AbstractVector{Ty}}
    x::Vx
    y::Vy
    D::Vy
end
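On the array mutation mentioned above: besides Zygote.Buffer or a custom adjoint, another option that might work is to build the right-hand side vector without in-place writes. The sketch below uses only Base Julia. The function names `rhs_mutating` and `rhs_nonmutating` are hypothetical, and whether the rest of the constructor (the tridiagonal solve) differentiates is not addressed here.

```julia
# Loop version as in the constructor above, kept only for comparison.
function rhs_mutating(y)
    R = similar(y)
    R[1] = y[2] - y[1]
    for i in 2:(length(y) - 1)
        R[i] = y[i + 1] - y[i - 1]
    end
    R[end] = y[end] - y[end - 1]
    return 3 .* R
end

# Same values built from slices and vcat, with no in-place writes.
function rhs_nonmutating(y)
    interior = y[3:end] .- y[1:(end - 2)]
    return 3 .* vcat(y[2] - y[1], interior, y[end] - y[end - 1])
end

# Complex test data in the spirit of the artificial example in the thread.
y = [4.0, 2.7, 1.0, 0.7, 4.3, 17.4, 43.0] .* (1 + im / 3)
same = rhs_mutating(y) ≈ rhs_nonmutating(y)   # the two constructions should agree
```

The slice-based version allocates a few temporaries, which is usually an acceptable trade for a constructor that runs once per spline.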