Similar to homework 1, let's define a relative-error function for quantifying the error from finite differences:
using LinearAlgebra # for norm, kron, I(n) etc.
relerr(approx, exact) = norm(approx - exact) / norm(exact)
relerr (generic function with 1 method)
Let's also define the $\otimes$ operator (type it by \otimes followed by TAB) to be the Kronecker product (the kron function in the LinearAlgebra library), so that we can use the math notation A ⊗ B instead of kron(A, B):
const ⊗ = kron
kron (generic function with 21 methods)
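As a reminder, the key identity behind all of the Kronecker manipulations below is $\operatorname{vec}(AXB) = (B^T \otimes A)\operatorname{vec}(X)$. Here is a quick numerical spot-check (the matrices A1, X1, B1 are arbitrary illustrative choices of our own, not from the homework):
# spot-check the "vec trick": vec(A*X*B) == (Bᵀ ⊗ A) * vec(X)
A1, X1, B1 = randn(3,3), randn(3,3), randn(3,3)
relerr(vec(A1*X1*B1), (B1' ⊗ A1) * vec(X1)) # should be ≈ machine precision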
We'll also load the Zygote AD library for analytical comparisons:
using Zygote
Here, $f(A) = \sqrt{A}$, and the homework says the Jacobian (acting on $\operatorname{vec}(A)$) should be $(I\otimes \sqrt{A} + \sqrt{A}^T \otimes I)^{-1}$.
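This follows from differentiating $\sqrt{A}\,\sqrt{A} = A$: the product rule gives $d(\sqrt{A})\,\sqrt{A} + \sqrt{A}\,d(\sqrt{A}) = dA$, and applying the vec identity $\operatorname{vec}(AXB) = (B^T \otimes A)\operatorname{vec}(X)$ to each term yields
$$(\sqrt{A}^T \otimes I + I \otimes \sqrt{A})\operatorname{vec}(d\sqrt{A}) = \operatorname{vec}(dA),$$
so the Jacobian is the inverse of the matrix on the left. Let's try it out for a random positive-definite $A$: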
n = 1 # start with a 1×1 matrix for easy inspection; larger n works too
B = randn(n,n)
A = B'B # random positive definite
1×1 Matrix{Float64}:
 1.3731197861011764
The matrix square root is just sqrt(A) in Julia:
# check that sqrt(A)² ≈ A, up to roundoff errors:
relerr(sqrt(A)^2, A)
0.0
# Jacobian:
J = (I(n) ⊗ sqrt(A) + sqrt(A)' ⊗ I(n))^-1
1×1 Matrix{Float64}:
 0.42669326873498054
Now, let's check this against finite difference for a small random $dA$:
dA = randn(n,n) * 1e-8
relerr( vec(sqrt(A+dA) - sqrt(A)), # finite-difference directional derivative
        J * vec(dA) )              # vs. exact expression
3.726301788530775e-8
Hooray, it worked!
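An aside, beyond what the homework asks: for larger $n$, explicitly forming the $n^2 \times n^2$ Kronecker matrix would be wasteful. The same directional derivative can be obtained by solving the Sylvester equation $\sqrt{A}\,dX + dX\,\sqrt{A} = dA$, e.g. with LinearAlgebra's sylvester function (which solves $AX + XB + C = 0$); a minimal sketch:
# sketch: directional derivative of sqrt(A) via a Sylvester solve,
# with no n²×n² Kronecker matrix.  sylvester(A, B, C) solves A*X + X*B + C = 0:
dX = sylvester(sqrt(A), sqrt(A), -dA) # solves √A*dX + dX*√A = dA
relerr(vec(dX), J * vec(dA))          # should match the Kronecker formula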
Let's define our function $f(p)$ as in the problem. We'll pass the extra parameters $A_0$ etcetera explicitly (or we could use global variables):
function f(p, A₀, a, B₀, b, F)
    A = A₀ + diagm(p)
    x = A \ a # A⁻¹a
    B = B₀ + diagm(x.*x)
    y = B \ b # B⁻¹b
    return y'*F*y
end
f (generic function with 1 method)
Now, we'll pick some random parameters and try it out:
n = 5
A₀ = randn(n,n)
a = randn(5)
B₀ = randn(n,n)
b = randn(5)
F = randn(n,n); F = F + F' # symmetrize F
p = rand(5)
f(p, A₀, a, B₀, b, F)
49.975441720918774
Now, let's implement the manual "adjoint" gradient from the solutions, using the same notation.
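As a reminder of where steps (i)–(v) below come from (with $\circ$ denoting the elementwise product): since $F$ is symmetric, $df = (2Fy)^T dy$; from $y = B^{-1}b$ we get $dy = -B^{-1}\,dB\,y$, so $df = u^T dB\,y$ with $u = -B^{-T}(2Fy)$; since $dB = \operatorname{diagm}(2x \circ dx)$, this equals $w^T dx$ with $w = 2u \circ x \circ y$; from $x = A^{-1}a$ we get $dx = -A^{-1}\,dA\,x$, so $w^T dx = z^T dA\,x$ with $z = -A^{-T}w$; and finally $dA = \operatorname{diagm}(dp)$ gives $df = (z \circ x)^T dp$, i.e. $\nabla f = z \circ x$. Here it is in code: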
function ∇f(p, A₀, a, B₀, b, F)
    # the forward solution, copied from above
    # (note that in serious computation we would want to re-use this from f)
    A = A₀ + diagm(p)
    x = A \ a # A⁻¹a
    B = B₀ + diagm(x.*x)
    y = B \ b # B⁻¹b
    # the reverse-mode gradient
    g′ᵀ = 2F*y       # step (i)
    u = B' \ -g′ᵀ    # step (ii): adjoint problem 1
    w = 2u .* x .* y # step (iii)
    z = A' \ -w      # step (iv): adjoint problem 2
    return z .* x    # step (v): ∇f
end
∇f (generic function with 1 method)
∇f(p, A₀, a, B₀, b, F)
5-element Vector{Float64}:
  262.55155103380923
 -273.8595112733488
   96.32881037878508
  217.43756038923235
 -171.79342086396048
First, we'll compare it against finite differences for a random small dp:
dp = randn(n) * 1e-8
relerr( f(p+dp, A₀, a, B₀, b, F) - f(p, A₀, a, B₀, b, F), # finite difference
        ∇f(p, A₀, a, B₀, b, F)'dp )                       # vs. exact directional derivative
1.428248665134681e-7
Hooray, it matches!
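As an aside (a sketch beyond the original check): a central difference has $O(\Vert dp \Vert^2)$ truncation error instead of $O(\Vert dp \Vert)$, so with a step size chosen to balance truncation against roundoff it should typically match a few more digits:
dp2 = randn(n) * 1e-5 # larger step, so roundoff error doesn't dominate
relerr( f(p+dp2, A₀, a, B₀, b, F) - f(p-dp2, A₀, a, B₀, b, F), # central difference
        2 * (∇f(p, A₀, a, B₀, b, F)'dp2) )                     # vs. 2 × exact directional derivative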
To compute $\nabla f$ with Zygote, we have to give Zygote a function of a single parameter vector p that we want to differentiate with respect to, along with the point p at which we want the derivative. To do that, we will define an anonymous function with p -> ... that captures the other parameters (also called a closure in computer science):
Zygote.gradient(p -> f(p, A₀, a, B₀, b, F), p)
([262.5515510338092, -273.8595112733488, 96.32881037878506, 217.43756038923235, -171.7934208639604],)
(Zygote returns a 1-component tuple of outputs, because it can potentially differentiate with respect to multiple arguments, though here we are just asking for 1.)
The above looks pretty good if we "eyeball" it compared to ∇f above. Let's compare it quantitatively:
relerr(∇f(p, A₀, a, B₀, b, F),                           # manual ∇f
       Zygote.gradient(p -> f(p, A₀, a, B₀, b, F), p)[1]) # vs AD
2.1572017312010847e-16
Hooray, it matches up to the limits of roundoff error (to essentially machine precision)!
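As the comment inside ∇f notes, a serious implementation would re-use the forward solution instead of recomputing it. A minimal sketch of what that might look like (the combined routine f_and_∇f is our own illustrative addition, not from the solutions):
# sketch: return (f, ∇f) together, sharing one forward solve:
function f_and_∇f(p, A₀, a, B₀, b, F)
    A = A₀ + diagm(p)
    x = A \ a                  # forward solve 1
    B = B₀ + diagm(x.*x)
    y = B \ b                  # forward solve 2
    u = B' \ (-2F*y)           # adjoint solve 1
    z = A' \ (-(2u .* x .* y)) # adjoint solve 2
    return y'*F*y, z .* x      # (f, ∇f)
end
f_and_∇f(p, A₀, a, B₀, b, F) # should return the same values as f and ∇f above
In a production code one would go further and factor $A$ and $B$ once (e.g. with lu), re-using each factorization for both the forward and the transposed adjoint solve.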
Demonstrate numerically that $d(e^A) = \sum_{k=0}^\infty \frac{1}{k!} \left( \sum_{\ell=0}^{k-1} (A^T)^{k-\ell-1} \otimes A^\ell \right) dA$
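Why: the product rule gives $d(A^k) = \sum_{\ell=0}^{k-1} A^\ell \, dA \, A^{k-\ell-1}$, the vec identity turns each term into $\left((A^T)^{k-\ell-1} \otimes A^\ell\right) \operatorname{vec}(dA)$, and summing these against the $1/k!$ coefficients of $e^A = \sum_{k=0}^\infty A^k/k!$ gives the formula above, understood as the Jacobian acting on $\operatorname{vec}(dA)$.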
A = rand(3,3)
3×3 Matrix{Float64}:
 0.989401  0.372355  0.475427
 0.94298   0.68495   0.646221
 0.794569  0.13317   0.353663
using ForwardDiff
e(A) = sum(A^k/factorial(k) for k=0:20) # the built-in exp(A) doesn't work with ForwardDiff, so use a Taylor series truncated at k=20
e (generic function with 1 method)
relerr(e(A), exp(A)) # check that our sum matches the built-in exp(A) function
1.2971024122201743e-15
J_AD = ForwardDiff.jacobian(e,A)
9×9 Matrix{Float64}:
 3.35707   0.5389     0.688955   1.50121   …  1.04456    0.123629   0.158014
 1.50121   2.82372    0.907286   0.466706     0.338304   0.92634    0.209389
 1.04456   0.265598   2.45771    0.338304     0.244335   0.0576795  0.83731
 0.5389    0.0624931  0.0798733  2.82372      0.265598   0.0290732  0.0371575
 0.170853  0.479248   0.105876   1.3384       0.079272   0.237991   0.0492994
 0.123629  0.0290732  0.434115   0.92634   …  0.0576795  0.0134119  0.216817
 0.688955  0.0798733  0.102087   0.907286     2.45771    0.434115   0.555028
 0.218367  0.612715   0.135322   0.289541     1.21467    2.02443    0.729772
 0.158014  0.0371575  0.555028   0.209389     0.83731    0.216817   1.73404
ϵ = 1e-8
J = zeros(9,0) # initialize J with 9 rows and 0 columns; the loop below appends one column per entry of A
for j=1:3, i=1:3
    dA = zeros(3,3)
    dA[i,j] = ϵ             # perturb the (i,j) entry only
    df = exp(A+dA) - exp(A) # finite-difference change in exp
    J = [J vec(df)]         # append this as a new column of J
end
J_FD = J/ϵ
9×9 Matrix{Float64}:
 3.35707   0.538899   0.688955   1.50121   …  1.04456    0.123629   0.158014
 1.50121   2.82372    0.907286   0.466706     0.338304   0.92634    0.209389
 1.04456   0.265598   2.45771    0.338304     0.244335   0.0576795  0.83731
 0.5389    0.062493   0.0798732  2.82372      0.265598   0.0290732  0.0371576
 0.170853  0.479248   0.105875   1.3384       0.0792719  0.237991   0.0492994
 0.123629  0.0290732  0.434115   0.92634   …  0.0576795  0.0134118  0.216817
 0.688955  0.0798732  0.102087   0.907286     2.45771    0.434115   0.555028
 0.218367  0.612715   0.135321   0.289541     1.21467    2.02443    0.729772
 0.158014  0.0371575  0.555028   0.209389     0.83731    0.216817   1.73404
# written as nested sum calls:
sum(sum( (A')^(k-ℓ-1) ⊗ A^ℓ for ℓ=0:(k-1))/factorial(k) for k=1:20)
9×9 Matrix{Float64}:
 3.35707   0.5389     0.688955   1.50121   …  1.04456    0.123629   0.158014
 1.50121   2.82372    0.907286   0.466706     0.338304   0.92634    0.209389
 1.04456   0.265598   2.45771    0.338304     0.244335   0.0576795  0.83731
 0.5389    0.0624931  0.0798733  2.82372      0.265598   0.0290732  0.0371575
 0.170853  0.479248   0.105876   1.3384       0.079272   0.237991   0.0492994
 0.123629  0.0290732  0.434115   0.92634   …  0.0576795  0.0134119  0.216817
 0.688955  0.0798733  0.102087   0.907286     2.45771    0.434115   0.555028
 0.218367  0.612715   0.135322   0.289541     1.21467    2.02443    0.729772
 0.158014  0.0371575  0.555028   0.209389     0.83731    0.216817   1.73404
# same thing written another way:
J_20 = sum( (A')^(k-ℓ-1) ⊗ A^ℓ / factorial(k) for ℓ=0:20, k=1:20 if k>ℓ)
9×9 Matrix{Float64}:
 3.35707   0.5389     0.688955   1.50121   …  1.04456    0.123629   0.158014
 1.50121   2.82372    0.907286   0.466706     0.338304   0.92634    0.209389
 1.04456   0.265598   2.45771    0.338304     0.244335   0.0576795  0.83731
 0.5389    0.0624931  0.0798733  2.82372      0.265598   0.0290732  0.0371575
 0.170853  0.479248   0.105876   1.3384       0.079272   0.237991   0.0492994
 0.123629  0.0290732  0.434115   0.92634   …  0.0576795  0.0134119  0.216817
 0.688955  0.0798733  0.102087   0.907286     2.45771    0.434115   0.555028
 0.218367  0.612715   0.135322   0.289541     1.21467    2.02443    0.729772
 0.158014  0.0371575  0.555028   0.209389     0.83731    0.216817   1.73404
Looks good! Let's check the match quantitatively:
relerr(J_20, J_AD) # should match to nearly machine precision
1.3738128345184809e-15
relerr(J_20, J_FD) # should match to ≈ 7 digits
1.293987776257493e-7
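One more cross-check, as an aside: a classic identity (due to Van Loan) says that the directional derivative of $e^A$ in the direction $dA$ is the upper-right block of $\exp \begin{pmatrix} A & dA \\ 0 & A \end{pmatrix}$. Since the derivative is linear in $dA$, this needs no small step and hence incurs no finite-difference error; a minimal sketch:
# sketch: Van Loan's block-matrix trick for the exact directional derivative
dA = randn(3,3)               # no need for a small perturbation here
M = exp([A dA; zeros(3,3) A]) # 6×6 block-triangular matrix exponential
relerr(vec(M[1:3, 4:6]), J_20 * vec(dA)) # should match to ≈ machine precision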
Hooray, math works!