
Commit 4ddf19d

Author: jeremiedb
Message: drop CUDNN rnn
Parent: f93a114

File tree

7 files changed: +201, -400 lines

src/Flux.jl

Lines changed: 1 addition & 2 deletions
@@ -38,8 +38,7 @@ include("functor.jl")
 include("layers/stateless.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
-# include("layers/recurrent.jl")
-include("layers/recurrent_jdb.jl")
+include("layers/recurrent.jl")
 include("layers/normalise.jl")

 include("data/Data.jl")

src/cuda/cuda.jl

Lines changed: 5 additions & 1 deletion
@@ -3,8 +3,12 @@ module CUDAint
 using ..CUDA

 using CUDA: CUDNN
+
+import ..Flux: Flux
+import Zygote
+using Zygote: @adjoint
+
 # include("curnn.jl")
-include("curnn_jdb_v1.jl")
 include("cudnn.jl")

 end
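
The `Flux` and `Zygote` imports move to the module level here because they previously came in via the RNN integration file, which is no longer included. Reconstructed from the diff, the whole module now reads:

module CUDAint

using ..CUDA

using CUDA: CUDNN

import ..Flux: Flux
import Zygote
using Zygote: @adjoint

# include("curnn.jl")
include("cudnn.jl")

end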

src/cuda/curnn.jl

Lines changed: 95 additions & 89 deletions
@@ -1,89 +1,95 @@
-import ..Flux: Flux, relu
-
-CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
-CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
-CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
-CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
-
-function CUDNN.RNNDesc(m::CuRNNs{T}) where T
-  h, i = length(m.h), size(m.Wi, 2)
-  mode = m isa CuRNN ?
-    (m.σ == tanh ? CUDNN.CUDNN_RNN_TANH : CUDNN.CUDNN_RNN_RELU) :
-    m isa CuGRU ? CUDNN.CUDNN_GRU : CUDNN.CUDNN_LSTM
-  r = CUDNN.RNNDesc{T}(mode, i, h)
-  return r
-end
-
-const descs = WeakKeyDict()
-
-function desc(rnn)
-  d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
-  CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
-  return d
-end
-
-import Zygote
-using Zygote: @adjoint
-
-function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  y, h′ = CUDNN.forward(desc(m), x, h)
-  return h′, y
-end
-
-function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  y, h′ = CUDNN.forward(desc(m), x, h)
-  return h′, y
-end
-
-function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2])
-  return (h′, c′), y
-end
-
-(m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-(m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
-
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-
-unbroadcast(x::AbstractArray, Δ) =
-  size(x) == size(Δ) ? Δ :
-  length(x) == length(Δ) ? trim(x, Δ) :
-    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
-
-coerce_cuda(x::Union{CuArray,Nothing}) = x
-coerce_cuda(x::Tuple) = coerce_cuda.(x)
-
-coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0)
-
-function struct_grad!(cx::Zygote.Context, x, x̄)
-  for f in fieldnames(typeof(x))
-    Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
-  end
-  dx = Zygote.grad_mut(cx, x)
-  dx[] = Zygote.accum(dx[], x̄)
-  return dx
-end
-
-for RNN in (CuRNN, CuGRU)
-  @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
-    (y, ho), back = CUDNN.pullback(desc(m), x, h)
-    (ho, y), function (Δ)
-      dho, dy = coerce_cuda(Δ) # Support FillArrays etc.
-      m̄ = back(dy, dho)
-      dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing))
-      (dm, unbroadcast(h, m̄.h), m̄.x)
-    end
-  end
-end
-
-@adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
-  (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c)
-  ((ho, co), y), function (Δ)
-    dhc, dy = coerce_cuda(Δ) # Support FillArrays etc.
-    dho, dco = dhc === nothing ? (nothing, nothing) : dhc
-    m̄ = back(dy, dho, dco)
-    dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing))
-    (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x)
-  end
-end
+# import ..Flux: relu
+#
+# CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
+# CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
+# CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
+# CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
+#
+# function CUDNN.RNNDesc(m::CuRNNs{T}) where T
+#   if isa(m, CuRNN)
+#     m.σ == tanh ? mode = CUDNN.CUDNN_RNN_TANH : mode = CUDNN.CUDNN_RNN_RELU
+#     h, i = length(m.b), size(m.Wi, 2)
+#   elseif isa(m, CuGRU)
+#     mode = CUDNN.CUDNN_GRU
+#     h, i = length(m.b)÷3, size(m.Wi, 2)
+#   elseif isa(m, CuLSTM)
+#     mode = CUDNN.CUDNN_LSTM
+#     h, i = length(m.b)÷4, size(m.Wi, 2)
+#     println("h: ", h, ", i:", i)
+#   else
+#     error("typeof m ∉ {CuRNN, CuGRU, CuLSTM}")
+#   end
+#   r = CUDNN.RNNDesc{T}(mode, i, h)
+#   return r
+# end
+#
+# const descs = WeakKeyDict()
+#
+# function desc(rnn)
+#   d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
+#   CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
+#   return d
+# end
+#
+# function (m::CuRNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+#   y, h′ = CUDNN.forward(desc(m), x, h)
+#   return h′, y
+# end
+#
+# function (m::CuGRU{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+#   y, h′ = CUDNN.forward(desc(m), x, h)
+#   return h′, y
+# end
+#
+# function (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
+#   y, h′, c′ = CUDNN.forward(desc(m), x, h[1], h[2])
+#   return (h′, c′), y
+# end
+#
+# (m::CuRNN{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+# (m::CuGRU{T})(h::CuArray{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+# (m::CuLSTM{T})(h::NTuple{2,CuArray{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+#
+# trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
+#
+# unbroadcast(x::AbstractArray, Δ) =
+#   size(x) == size(Δ) ? Δ :
+#   length(x) == length(Δ) ? trim(x, Δ) :
+#     trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+#
+# coerce_cuda(x::Union{CuArray,Nothing}) = x
+# coerce_cuda(x::Tuple) = coerce_cuda.(x)
+#
+# coerce_cuda(x::AbstractArray) = x .+ CUDA.fill(0)
+#
+# function struct_grad!(cx::Zygote.Context, x, x̄)
+#   for f in fieldnames(typeof(x))
+#     Zygote.accum_param(cx, getfield(x, f), getfield(x̄, f))
+#   end
+#   dx = Zygote.grad_mut(cx, x)
+#   dx[] = Zygote.accum(dx[], x̄)
+#   return dx
+# end
+#
+# for RNN in (CuRNN, CuGRU)
+#   @eval @adjoint function (m::$RNN{T})(h::CuArray{T}, x::CuArray{T}) where T <: Union{Float32,Float64}
+#     (y, ho), back = CUDNN.pullback(desc(m), x, h)
+#     (ho, y), function (Δ)
+#       dho, dy = coerce_cuda(Δ) # Support FillArrays etc.
+#       m̄ = back(dy, dho)
+#       dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing))
+#       (dm, unbroadcast(h, m̄.h), m̄.x)
+#     end
+#   end
+# end
+#
+# @adjoint function (m::CuLSTM)((h, c)::Tuple{CuArray{T},CuArray{T}}, x::CuArray{T}) where T <: Union{Float32,Float64}
+#   (y, ho, co), back = CUDNN.pullback(desc(m), x, h, c)
+#   ((ho, co), y), function (Δ)
+#     dhc, dy = coerce_cuda(Δ) # Support FillArrays etc.
+#     dho, dco = dhc === nothing ? (nothing, nothing) : dhc
+#     m̄ = back(dy, dho, dco)
+#     dm = struct_grad!(__context__, m, (σ=nothing,Wi=transpose(m̄.Wi),Wh=transpose(m̄.Wh),b=m̄.b,h=nothing,c=nothing))
+#     (dm, (unbroadcast(h, m̄.h), unbroadcast(c, m̄.c)), m̄.x)
+#   end
+# end

src/cuda/curnn_jdb_v1.jl

Lines changed: 0 additions & 98 deletions
This file was deleted.
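
Net effect of the commit: with `curnn.jl` fully commented out and `curnn_jdb_v1.jl` deleted, GPU RNN, GRU, and LSTM cells no longer dispatch to CUDNN's fused kernels (`CUDNN.forward` / `CUDNN.pullback`). They instead run the generic Julia cell code from `layers/recurrent.jl`, executed on the GPU through CUDA.jl's array operations, with gradients taken by Zygote's ordinary reverse pass rather than the hand-written `@adjoint` rules above. A minimal sketch of the surviving path, assuming a working CUDA setup and the standard Flux API (`gpu`, `GRU`, `Flux.params`); sizes and the loss are illustrative:

using Flux, CUDA

m = GRU(8, 16) |> gpu        # parameters become CuArrays
x = CUDA.rand(Float32, 8)    # one timestep of input, already on the GPU

y = m(x)                     # generic Julia cell on GPU, no fused CUDNN kernel

# Gradients now come from Zygote's generic adjoints, not CUDNN.pullback.
gs = gradient(() -> sum(m(x)), Flux.params(m))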
