Add ChainRules definitions #58

Merged: 6 commits, Jan 10, 2022

devmotion (Member)

This PR adds ChainRules derivatives (both forward- and reverse-mode) for fft, ifft, bfft, rfft, irfft, brfft, fftshift and ifftshift, as discussed in FluxML/Zygote.jl#835.

A large part of the PR is an extended test implementation of FFTs based on the AbstractFFTs interface, which now also supports real transforms and the dims argument. I added it, together with additional comparisons against FFTW, to be able to check the ChainRules definitions properly.

The reverse-mode definitions are loosely based on the adjoints in Zygote (https://github.com/FluxML/Zygote.jl/blob/e6a86745d66b5974eaafa8a8f28bcd4b100374df/src/lib/array.jl#L764-L899). However, in contrast to them, the rules in this PR

  • define rules only for the versions with a dims argument, since all calls to fft etc. end up there as far as I understand,
  • use AbstractFFTs.normalization instead of manually computing prod(size(...)),
  • fix a bug in the implementation of the pullbacks for rfft, irfft, and brfft in Zygote.

On purpose, I did not include any derivatives for *(::Plan, ::AbstractArray) and \(::Plan, ::AbstractArray) in this PR. Since rfft etc. need a special scaling (as in their rules here), it would be helpful to be able to infer the "type" of a plan, e.g., whether it corresponds to fft or rfft. However, this does not seem possible with the current interface: currently one would have to use the sizes of the input and output as a heuristic for whether the plan is for a real transform. The implementation in Zygote is also problematic since it is too general: it works neither for DCT plans (FluxML/Zygote.jl#899) nor for ScaledPlans (JuliaMath/FFTW.jl#182). It seems JuliaMath/FFTW.jl#182 could be solved with an additional API such as

region(p::Plan) = p.region
region(p::ScaledPlan) = region(p.p)

that could be used instead of p.region in general derivatives for *(::Plan, ...) etc.
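
To illustrate the accessor idea, here is a minimal sketch with toy types. `ToyPlan`, `ToyFFTPlan`, and `ToyScaledPlan` are hypothetical stand-ins defined only for this example (they are not the AbstractFFTs types); the field names `p` and `region` are simply taken from the snippet above.

```julia
# Toy stand-ins for a plan type hierarchy (hypothetical, not AbstractFFTs types),
# defined here only so the sketch runs without any packages.
abstract type ToyPlan end

struct ToyFFTPlan <: ToyPlan
    region::Tuple{Vararg{Int}}   # dimensions that are transformed
end

struct ToyScaledPlan{P<:ToyPlan} <: ToyPlan
    p::P             # wrapped plan
    scale::Float64   # scalar factor applied to the output
end

# Accessor as proposed: fall back to the field, unwrap scaled plans recursively.
region(p::ToyPlan) = p.region
region(p::ToyScaledPlan) = region(p.p)

plan = ToyFFTPlan((1, 2))
scaled = ToyScaledPlan(plan, 0.25)
nested = ToyScaledPlan(scaled, 2.0)

println(region(plan), " ", region(scaled), " ", region(nested))
```

Generic derivative code could then call `region(P)` and would not need to know whether `P` is wrapped in a scaled plan.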

oxinabox commented Sep 6, 2021

Nice.

If this is merged, it should definitely be added to the reverse-dependency tests of ChainRulesCore and ChainRulesTestUtils, so that we never break it by mistake.

Δy = fft(Δx, dims)
return y, Δy
end
function ChainRulesCore.rrule(::typeof(fft), x::AbstractArray, dims)
stevengj (Member) commented Sep 7, 2021

Since fft(x) just calls plan_fft(x) * x, shouldn't the chain rules be defined for * and mul! operations with a Plan?

devmotion (Member, Author)

I thought about this but then noticed the following problem: the interface does not allow dispatching on something like Plan{fft} or Plan{rfft}, so we don't know which transform the plan computes. In the reverse-mode rules this would be desirable, though (unless we want to use heuristics based on, e.g., the sizes of x and P * x), since the rules for rfft/irfft/brfft require additional scaling caused by the conjugate symmetry. (The other problem with P.region that I mentioned in the comment above seems easier to solve: we could just add a region(::Plan) function that falls back to the field.)
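
To make the extra scaling concrete, here is a small package-free sketch for the 1-D case. It is not the implementation in this PR: `naive_rfft`, `naive_brfft`, and `rfft_pullback` are hypothetical helpers built on a naive O(n^2) DFT, written only to check the adjoint identity numerically. The assumption illustrated is that the half spectrum stores interior bins once although they appear twice in the full spectrum, so a pullback built on a brfft-like transform has to halve them, while the DC bin (and the Nyquist bin for even n) is left unscaled.

```julia
using LinearAlgebra

# Naive real-input DFT that keeps only the first n ÷ 2 + 1 outputs,
# mimicking the shape of an rfft result for a real vector.
function naive_rfft(x::AbstractVector{<:Real})
    n = length(x)
    return [sum(x[j] * cis(-2π * (j - 1) * (k - 1) / n) for j in 1:n) for k in 1:(n ÷ 2 + 1)]
end

# Naive brfft-like transform: rebuild the full spectrum from the half spectrum
# via conjugate symmetry, then apply the unnormalized inverse DFT.
function naive_brfft(y::AbstractVector{<:Complex}, n::Integer)
    full = [k <= n ÷ 2 + 1 ? y[k] : conj(y[n - k + 2]) for k in 1:n]
    return [real(sum(full[k] * cis(2π * (j - 1) * (k - 1) / n) for k in 1:n)) for j in 1:n]
end

# Pullback sketch: halve the interior bins, keep DC (and Nyquist for even n) as is.
function rfft_pullback(ybar::AbstractVector{<:Complex}, n::Integer)
    scale = [(k == 1 || (iseven(n) && k == n ÷ 2 + 1)) ? 1 : 2 for k in 1:(n ÷ 2 + 1)]
    return naive_brfft(ybar ./ scale, n)
end

# Adjoint identity: real(<rfft(x), ybar>) should equal <x, pullback(ybar)>.
adjoint_gap(x, ybar) =
    abs(real(dot(naive_rfft(x), ybar)) - dot(x, rfft_pullback(ybar, length(x))))

gap_even = adjoint_gap([0.3, -1.2, 2.0, 0.7, -0.5, 1.1],
                       [1.0 + 2.0im, -0.5 + 0.3im, 0.2 - 1.0im, 0.7 + 0.4im])
gap_odd = adjoint_gap([0.3, -1.2, 2.0, 0.7, -0.5],
                      [1.0 + 2.0im, -0.5 + 0.3im, 0.2 - 1.0im])
println((gap_even, gap_odd))
```

The scaling vector depends on the original length n and on which bins are mirrored, which is information a generic rule for `P * x` cannot read off the plan type alone.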

codecov bot commented Sep 9, 2021

Codecov Report

Merging #58 (491086b) into master (d007201) will increase coverage by 30.12%.
The diff coverage is 100.00%.

@@             Coverage Diff             @@
##           master      #58       +/-   ##
===========================================
+ Coverage   52.63%   82.75%   +30.12%     
===========================================
  Files           2        2               
  Lines          95      203      +108     
===========================================
+ Hits           50      168      +118     
+ Misses         45       35       -10     
Impacted Files Coverage Δ
src/chainrules.jl 100.00% <100.00%> (ø)
src/definitions.jl 64.64% <100.00%> (+12.51%) ⬆️
src/AbstractFFTs.jl

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

devmotion (Member, Author)

The test failures of the reverse-mode rules for fftshift and ifftshift revealed that currently fftshift and ifftshift can't be inferred on Julia 1.0.

I thought I could fix the issues by changing ntuple(..., ndims(x)) to ntuple(..., Val(ndims(x))), but it turns out this only fixes the type inference problems for vectors and matrices, not for arrays with >= 3 dimensions. Apparently, the return type of circshift for such higher-dimensional arrays can't be inferred on Julia 1.0:

julia> f(x) = circshift(x, (1, 1, 1))
f (generic function with 1 method)

julia> @code_warntype f(rand(3, 4, 2))
Body::Any
1 1 ─ %1  = (Base.arraysize)(x, 1)::Int64    │╻╷╷╷╷╷╷╷  circshift
  │   %2  = (Base.arraysize)(x, 2)::Int64    ││╻         similar
  │   %3  = (Base.arraysize)(x, 3)::Int64    │││╻         similar
  │   %4  = (Base.slt_int)(%1, 0)::Bool      ││││╻╷╷╷      axes
  │   %5  = (Base.ifelse)(%4, 0, %1)::Int64  │││││┃│││      map
  │   %6  = (Base.slt_int)(%2, 0)::Bool      ││││││╻╷╷       Type
  │   %7  = (Base.ifelse)(%6, 0, %2)::Int64  │││││││┃│        Type
  │   %8  = (Base.slt_int)(%3, 0)::Bool      │││││││╻╷╷       Type
  │   %9  = (Base.ifelse)(%8, 0, %3)::Int64  ││││││││┃         max
  │   %10 = $(Expr(:foreigncall, :(:jl_alloc_array_3d), Array{Float64,3}, svec(Any, Int64, Int64, Int64), :(:ccall), 4, Array{Float64,3}, :(%5), :(%7), :(%9), :(%9), :(%7), :(%5)))::Array{Float64,3}  │││││╻╷        Type
  │   %11 = invoke Base.circshift!(%10::Array{Float64,3}, _2::Array{Float64,3}, (1, 1, 1)::Tuple{Int64,Int64,Int64})::Any  ││
  └──       return %11

Thus the type inference for vectors and matrices works now, both for the primal and the reverse-mode rule, but the tests fail since type inference of the reverse-mode rules with Array{Float64,3} does not work on Julia 1.0. Should we just disable type inference tests of the reverse-mode rule for this case on Julia 1.0?
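
For reference, a minimal sketch of the ntuple(..., Val(...)) part of the change. `half_shifts` is a hypothetical helper, not a function from this PR; it only shows that a tuple whose length comes from `Val(N)` has an inferable type, and it says nothing about the separate circshift inference problem on Julia 1.0.

```julia
using Test

# The tuple length is a compile-time constant here, so the result type
# NTuple{N,Int} can be inferred for any array dimensionality N.
half_shifts(x::AbstractArray{T,N}) where {T,N} = ntuple(d -> size(x, d) ÷ 2, Val(N))

# `@inferred` throws if the inferred return type does not match the actual one.
shifts = @inferred half_shifts(rand(3, 4, 2))
println(shifts)
```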

stevengj (Member)

I think it's fine to drop Julia 1.0 in the tests.

Regarding Plan{rfft}, we could presumably add a new subtype of Plan if needed?

devmotion (Member, Author)

> I think it's fine to drop Julia 1.0 in the tests.

I disabled the failing type inference tests of rrule on Julia < 1.6. I don't think one should remove tests with Julia 1.0 completely (if this was your intention) since otherwise it would be quite easy to accidentally introduce changes that are not compatible with older Julia versions. I ran into such issues multiple times in other packages.

> Regarding Plan{rfft}, we could presumably add a new subtype of Plan if needed?

We could add new subtypes of Plan, but since the scaling is different for rfft, brfft and irfft, one might need different types for all of them. I also considered implementing a default frule and rrule for Plan, where the rrule would throw an error if the sizes of x and P * x are not equal. This would ensure that the rrule does not silently return incorrect values. Downstream packages that implement plans for real transforms would then have to define an rrule for their plan types.
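
A rough sketch of what such a guard could look like, under loud assumptions: `ToyMatrixPlan`, `dft`, and `guarded_rule` are hypothetical names, the plan is just a dense matrix, and the function only imitates the shape of an rrule (primal plus pullback) without using ChainRulesCore.

```julia
using LinearAlgebra

# Hypothetical toy plan that applies a fixed complex matrix (not an AbstractFFTs type).
struct ToyMatrixPlan
    A::Matrix{ComplexF64}
end
Base.:*(P::ToyMatrixPlan, x::AbstractVector) = P.A * x

# Dense DFT matrix, used only to build example plans.
dft(n) = [cis(-2π * (j - 1) * (k - 1) / n) for k in 1:n, j in 1:n]

# Rule-like function: returns the primal output and a pullback, but refuses
# plans whose output size differs from the input size (e.g. real transforms).
function guarded_rule(P::ToyMatrixPlan, x::AbstractVector)
    y = P * x
    size(y) == size(x) ||
        throw(ArgumentError("input and output sizes differ; define a specialized rule for this plan"))
    pullback(ybar) = P.A' * ybar   # adjoint of the linear map
    return y, pullback
end

x = ComplexF64[1, 2im, -1, 0.5]
ybar = ComplexF64[0.5, 1im, 2, -1]
y, back = guarded_rule(ToyMatrixPlan(dft(4)), x)

# A plan with 3 outputs for 4 inputs is rejected instead of returning a wrong gradient.
rejected = try
    guarded_rule(ToyMatrixPlan(dft(4)[1:3, :]), ones(4))
    false
catch err
    err isa ArgumentError
end
println(rejected)
```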

devmotion (Member, Author)

What's missing from this PR? What should I do so it can be merged? 🙂

As discussed above, ideally we would define rules for Plans as well. However, as noted, we can't currently dispatch on the type of the Plan:

> Since rfft etc. need a special scaling (as in their rules here), it would be helpful to be able to infer the "type" of a plan, e.g., whether it corresponds to fft or rfft. However, this does not seem possible with the current interface: currently one would have to use the sizes of the input and output as a heuristic for whether the plan is for a real transform.

So I think a somewhat nice implementation requires some changes to the types and interface in AbstractFFTs, which seems a bit much for this PR (I assume).

I think the rules of fft etc. here are already an improvement. Additionally, it seems these rules are useful in any case, even if rules for Plans are defined as well: The AbstractFFTs interface does not require (at least in the official documentation) that calls such as fft construct and work with a Plan internally (the fallback definition of e.g. fft uses plan_fft but downstream packages could in principle redefine it, e.g., for certain array types).

stevengj (Member)

LGTM.

devmotion (Member, Author)

Do you have any additional comments or suggestions, @oxinabox or @sethaxen (pinging you since you are familiar with ChainRules 😉)? If not, I'll merge this PR at the end of the week. We can also iterate and improve the rules (in addition to the suggestions above) in subsequent PRs if I missed anything.

devmotion merged commit 2bae074 into JuliaMath:master on Jan 10, 2022
devmotion deleted the dw/chainrules branch on January 10, 2022 23:49
trahflow added a commit to trahflow/Zygote.jl that referenced this pull request Mar 6, 2023
Partially addresses FluxML#1377

ChainRules for these have been added in JuliaMath/AbstractFFTs.jl#58
CarloLucibello pushed a commit to FluxML/Zygote.jl that referenced this pull request Mar 8, 2023
* drop adjoints for [i,r,b]fft()

Partially addresses #1377

ChainRules for these have been added in JuliaMath/AbstractFFTs.jl#58

* add back gradient test for *fft without dims argument

* increase compat constraint for AbstractFFTs to 1.3.1

* fix typo

Co-authored-by: Brian Chen
vpuri3 (Contributor) commented Jun 26, 2023

@devmotion, any updates on this?

Is it now possible to write rrules for *(::Plan, ::AbstractArray) and \(::Plan, ::AbstractArray)? Can we write separate dispatches for ScaledPlan, FFTW.DCTPlan, r2rFFTWPlan, etc. in downstream packages? And add frules here?
