Add ChainRules definitions #58
Nice. This should definitely be added to ChainRulesCore, and ChainRulesTestUtils reverse dependency tests if this is merged.
    Δy = fft(Δx, dims)
    return y, Δy
end
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims)
Since `fft(x)` just calls `plan_fft(x) * x`, shouldn't the chain rules be defined for `*` and `mul!` operations with a `Plan`?
I thought about this but then I noticed the following problem: the interface does not allow dispatching on something like `Plan{fft}` or `Plan{rfft}`, so we don't know which transform the plan computes. In the reverse-mode rules this would be desirable though (if we don't want to use heuristics based on e.g. the sizes of `x` and `P * x`), since the rules for `rfft`/`irfft`/`brfft` require some additional scaling (caused by the conjugate symmetry). (The other problem with `P.region` that I mentioned in the comment above seems easier to solve: we could just add a `region(::Plan)` function that falls back to the field.)
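To make the `region(::Plan)` idea concrete, here is a minimal sketch. It deliberately does not use AbstractFFTs: `ToyPlan`, `ToyFFTPlan` and `ToyScaledPlan` are hypothetical stand-ins, and the `region` function shown is only an assumption about what such an accessor could look like, not an existing API.

```julia
# Hypothetical stand-in types; these are NOT the AbstractFFTs types.
abstract type ToyPlan end

# A plan that stores the transformed dimensions in a `region` field.
struct ToyFFTPlan <: ToyPlan
    region::Tuple{Vararg{Int}}
end

# A wrapper plan that has no `region` field of its own.
struct ToyScaledPlan{P<:ToyPlan} <: ToyPlan
    p::P
    scale::Float64
end

# Generic fallback: read the field.
region(p::ToyPlan) = p.region

# Wrappers forward to the plan they wrap instead of erroring on a missing field.
region(p::ToyScaledPlan) = region(p.p)

p = ToyFFTPlan((1, 2))
sp = ToyScaledPlan(p, 0.5)
region(p), region(sp)
```

Generic rules would then call `region(P)` rather than `P.region`, so wrapper plans only need to add a method instead of carrying a duplicate field.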
Codecov Report
@@             Coverage Diff             @@
##           master      #58       +/-   ##
===========================================
+ Coverage   52.63%   82.75%   +30.12%
===========================================
  Files           2        2
  Lines          95      203     +108
===========================================
+ Hits           50      168     +118
+ Misses         45       35      -10
The test failures of the reverse-mode rules for … I thought I was able to fix the issues by changing …
julia> f(x) = circshift(x, (1, 1, 1))
f (generic function with 1 method)
julia> @code_warntype f(rand(3, 4, 2))
Body::Any
1 1 ─ %1 = (Base.arraysize)(x, 1)::Int64 │╻╷╷╷╷╷╷╷ circshift
│ %2 = (Base.arraysize)(x, 2)::Int64 ││╻ similar
│ %3 = (Base.arraysize)(x, 3)::Int64 │││╻ similar
│ %4 = (Base.slt_int)(%1, 0)::Bool ││││╻╷╷╷ axes
│ %5 = (Base.ifelse)(%4, 0, %1)::Int64 │││││┃│││ map
│ %6 = (Base.slt_int)(%2, 0)::Bool ││││││╻╷╷ Type
│ %7 = (Base.ifelse)(%6, 0, %2)::Int64 │││││││┃│ Type
│ %8 = (Base.slt_int)(%3, 0)::Bool │││││││╻╷╷ Type
│ %9 = (Base.ifelse)(%8, 0, %3)::Int64 ││││││││┃ max
│ %10 = $(Expr(:foreigncall, :(:jl_alloc_array_3d), Array{Float64,3}, svec(Any, Int64, Int64, Int64), :(:ccall), 4, Array{Float64,3}, :(%5), :(%7), :(%9), :(%9), :(%7), :(%5)))::Array{Float64,3} │││││╻╷ Type
│ %11 = invoke Base.circshift!(%10::Array{Float64,3}, _2::Array{Float64,3}, (1, 1, 1)::Tuple{Int64,Int64,Int64})::Any ││
└── return %11 │
Thus the type inference for vectors and matrices works now, both for the primal and the reverse-mode rule, but the tests fail since type inference of the reverse-mode rules with …
I think it's fine to drop Julia 1.0 in the tests. Regarding …
I disabled the failing type inference tests of …
We could add new subtypes of …
What's missing from this PR? What should I do so it can be merged? 🙂 As discussed above, ideally we would define rules for …
So I think a somewhat nice implementation requires some changes to the types and interface in AbstractFFTs, which seems a bit much for this PR (I assume). I think the rules of …
LGTM.
Do you have any additional comments or suggestions, @oxinabox or @sethaxen (pinging you since you are familiar with ChainRules 😉)? If not, then I'll merge this PR at the end of the week. We can also iterate and improve the rules (in addition to the suggestions above) in a subsequent PR if I missed anything.
Partially addresses FluxML#1377. ChainRules for these have been added in JuliaMath/AbstractFFTs.jl#58.
* drop adjoints for [i,r,b]fft()
  Partially addresses #1377. ChainRules for these have been added in JuliaMath/AbstractFFTs.jl#58.
* add back gradient test for *fft without dims argument
* increase compat constraint for AbstractFFTs to 1.3.1
* fix typo
Co-authored-by: Brian Chen
@devmotion, any updates on this? Is it possible now to write rrules for …
This PR adds ChainRules derivatives (both forward- and reverse-mode) for `fft`, `ifft`, `bfft`, `rfft`, `irfft`, `brfft`, `fftshift` and `ifftshift`, as discussed in FluxML/Zygote.jl#835.

Large parts of the PR are an extended implementation of FFTs with the AbstractFFTs interface in the tests that supports real transforms and dimensions as well. I added it together with additional comparisons against FFTW to be able to properly check the ChainRules definitions.

The reverse-mode definitions are loosely based on the adjoints in Zygote (https://github.com/FluxML/Zygote.jl/blob/e6a86745d66b5974eaafa8a8f28bcd4b100374df/src/lib/array.jl#L764-L899). However, in contrast to them the rules in this PR

- … `dims` argument since all calls to `fft` etc. end up there as far as I understand,
- … `AbstractFFTs.normalization` instead of manually computing `prod(size(...))`,
- … `rfft`, `irfft`, and `brfft` in Zygote.

On purpose, I did not include any derivatives for `*(::Plan, ::AbstractArray)` and `\(::Plan, ::AbstractArray)` in this PR. Since a special scaling is needed for `rfft` etc., as in the rules for `rfft` etc., it would be helpful to be able to infer the corresponding "type" of the plan, i.e., whether it corresponds to e.g. `fft` or `rfft`. However, it seems this is not possible with the current interface. Currently one would have to use the sizes of the input and output as a heuristic for whether the plan is for a real transform or not. The implementation in Zygote is also problematic since it is too general: it works neither for DCT plans (FluxML/Zygote.jl#899) nor for `ScaledPlan`s (JuliaMath/FFTW.jl#182). It seems JuliaMath/FFTW.jl#182 could be solved with an additional API such as … that could be used instead of `p.region` in general derivatives for `*(::Plan, ...)` etc.
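For readers wondering where the special scaling for real transforms comes from, here is a small self-contained sketch for the 1-D case. It is not the code in this PR: `naive_rfft`, `naive_brfft`, `rfft_pullback` and `adjoint_gap` are hypothetical helpers built on a naive O(n²) DFT, so nothing beyond `LinearAlgebra` is needed. The assumption is the usual convention that the pullback is the adjoint with respect to the real inner product `real(dot(a, b))`. Because the unnormalized real inverse counts every coefficient except the first (and, for even `n`, the last) twice via conjugate symmetry, those coefficients have to be halved first.

```julia
using LinearAlgebra

# Naive stand-in for a real-input DFT: keeps the first n ÷ 2 + 1 coefficients.
function naive_rfft(x::AbstractVector{<:Real})
    n = length(x)
    return [sum(x[j + 1] * cis(-2π * j * k / n) for j in 0:(n - 1)) for k in 0:(n ÷ 2)]
end

# Naive stand-in for the unnormalized inverse real transform with output length n.
function naive_brfft(y::AbstractVector{<:Complex}, n::Integer)
    # Rebuild the full spectrum from conjugate symmetry.
    full = [k <= n ÷ 2 ? y[k + 1] : conj(y[n - k + 1]) for k in 0:(n - 1)]
    return [real(sum(full[k + 1] * cis(2π * j * k / n) for k in 0:(n - 1))) for j in 0:(n - 1)]
end

# Sketch of a pullback for naive_rfft: halve the coefficients that the
# symmetric extension duplicates, then apply the unnormalized inverse.
function rfft_pullback(ybar::AbstractVector{<:Complex}, n::Integer)
    scale = [(k == 0 || 2k == n) ? 1 : 2 for k in 0:(n ÷ 2)]
    return naive_brfft(ybar ./ scale, n)
end

# Adjoint check: <rfft(x), ybar> should equal <x, pullback(ybar)>.
function adjoint_gap(n::Integer)
    x = collect(1.0:n)
    ybar = [complex(k + 1.0, 2k - 1.0) for k in 0:(n ÷ 2)]
    lhs = real(dot(ybar, naive_rfft(x)))
    rhs = dot(x, rfft_pullback(ybar, n))
    return abs(lhs - rhs)
end

adjoint_gap(6), adjoint_gap(7)
```

Both gaps should be at floating-point noise level for even and odd lengths. A generic rule for `*(::Plan, ...)` would need to know whether to apply these weights, which is why knowing the kind of transform behind a plan matters.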