Description
I was still confused after my last issue, so I dove a little deeper. Specifically, I wanted to understand how RNN training works with the loss function used here:

```julia
loss(x, y) = sum((Flux.stack(m.(x), 1) .- y) .^ 2)
```
I tested a much simplified version of this, using a 1 -> 1 RNN cell without an activation function, and the same loss function without the square:

```julia
m = Flux.RNN(1, 1, x -> x)
L(x, y) = sum(Flux.stack(m.(x), 1) .- y)
```
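To make explicit what this model computes, here is the same recurrence in plain Python (a sketch, assuming the standard RNN cell update `h_t = σ(Wi*x_t + Wh*h_{t-1} + b)` with the identity activation used above):

```python
# Manual re-implementation of the simplified 1 -> 1 model (identity activation).
# Assumes the usual RNN cell recurrence: h_t = Wi*x_t + Wh*h_{t-1} + b.
def forward(xs, Wi, Wh, b, h0=0.0):
    h = h0
    outs = []
    for x in xs:
        h = Wi * x + Wh * h + b  # new hidden state, which is also the output
        outs.append(h)
    return outs

def L(xs, ys, Wi, Wh, b):
    # Same loss as above: sum of (output - target) over the sequence.
    return sum(o - y for o, y in zip(forward(xs, Wi, Wh, b), ys))
```

With the parameter values used below, `forward` reproduces the `m.(x)` outputs step by step.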
Here's where I had a problem: when there's more than one input/output sample, what's the derivative of `L` with respect to `Wi` (`m.cell.Wi`)? Using some semi-random values:
```julia
x = [[0.3], [2.5]]
y = [0.5, 1.0]
m.cell.Wi .= [0.7]   # Wi
m.cell.Wh .= [0.001] # Wh
m.cell.b .= [0.85]   # b
m.state = [0.0]      # h0 (becomes h1, h2, etc. as the sequence is consumed)
```
If you evaluate `m.(x)` or `L(x, y)`, you get the results you expect (`[[1.06], [2.60106]]` and `2.16106`, respectively). For dL/dWi, it's easy to derive by inspection that it should be `x1 + x2 + Wh*x1 = 2.8003`. You could also get it by finite difference:
```julia
m.state = [0.0] # reset h0 before each evaluation
q = L(x, y)
m.cell.Wi .+= 0.01
m.state = [0.0]
r = L(x, y)
abs(r - q) / 0.01 # = 2.8003
```
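The same finite-difference check in plain Python, with the recurrence re-implemented inline so it's self-contained (a sketch, not Flux code):

```python
def L(Wi, xs=(0.3, 2.5), ys=(0.5, 1.0), Wh=0.001, b=0.85):
    # Unrolled RNN: h_t = Wi*x_t + Wh*h_{t-1} + b, starting from h0 = 0.
    h, total = 0.0, 0.0
    for x, y in zip(xs, ys):
        h = Wi * x + Wh * h + b
        total += h - y
    return total

eps = 0.01
print((L(0.7 + eps) - L(0.7)) / eps)  # ≈ 2.8003 (exact here: L is linear in Wi)
```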
But when you use `gradient`:

```julia
m.state = [0.0]
m.cell.Wi .= [0.7]
g = gradient(() -> L(x, y), params(m.cell.Wi))
g[m.cell.Wi] # = 2.8025
```
This result is equal to `x1 + x2 + Wh*x2` instead of `x1 + x2 + Wh*x1`. Am I overlooking something, or is something weird happening here?
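For what it's worth, the numeric gap matches that interpretation exactly (plain Python comparing the two expressions; this is just arithmetic, not a claim about Zygote internals):

```python
x1, x2, Wh = 0.3, 2.5, 0.001
expected = x1 + x2 + Wh * x1  # chain rule: the cross term goes through h1, which saw x1
reported = x1 + x2 + Wh * x2  # what gradient() returned: the cross term uses x2
print(expected, reported)  # ≈ 2.8003 vs 2.8025
```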