Skip to content

Add derivatives for the splines #72

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6932efe
Add derivatives for the splines
Aug 5, 2020
40c620b
Attempt to add frule and rrule
Aug 5, 2020
8e2fd11
Use _interpolate function
Aug 5, 2020
f8fe5d9
Apply frule and rrule directly on the functors
Aug 6, 2020
1b85892
Remove _interpolate
Aug 6, 2020
854c48f
Fix derivative for BSpline
Aug 6, 2020
35a9e6b
Tests passing for methods other than BSpline
Aug 6, 2020
6c64d89
Tests passing for methods other than BSpline
Aug 6, 2020
8265a22
Apply suggestions from code review
Dec 29, 2020
05b8eef
Introduce Caching in Lagrange Interpolation
Dec 29, 2020
d15d17e
Remove BSpline derivatives for the time being
Dec 29, 2020
81966e1
Add derivatives for the splines
Aug 5, 2020
304a858
Attempt to add frule and rrule
Aug 5, 2020
a144fd9
Use _interpolate function
Aug 5, 2020
48ce251
Apply frule and rrule directly on the functors
Aug 6, 2020
e4e654d
Remove _interpolate
Aug 6, 2020
a63321a
Fix derivative for BSpline
Aug 6, 2020
0110b85
Tests passing for methods other than BSpline
Aug 6, 2020
27eaa7e
Tests passing for methods other than BSpline
Aug 6, 2020
31cdccc
Apply suggestions from code review
Dec 29, 2020
611acfc
Introduce Caching in Lagrange Interpolation
Dec 29, 2020
484749d
Remove BSpline derivatives for the time being
Dec 29, 2020
908cfdc
Merge branch 'ap/derivatives' of github.com:avik-pal/DataInterpolatio…
Jan 8, 2021
da73445
Update code to be consistent with master
Jan 8, 2021
526e150
Fix the scale parameter for Bsplines
Jan 8, 2021
38455ef
change CI tested versions
ChrisRackauckas Jan 8, 2021
bc3738f
change CI tested versions
ChrisRackauckas Jan 8, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ os:
- linux
- osx
julia:
- 1.3
- 1.4
- nightly
- 1
notifications:
email: false
git:
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
version = "3.2.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand All @@ -17,8 +18,9 @@ Reexport = "0.2, 1.0"
julia = "1.3"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random"]
test = ["Test", "Random", "FiniteDifferences"]
12 changes: 11 additions & 1 deletion src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,23 @@ Base.setindex!(A::AbstractInterpolation,x,i) = A.u[i] = x
Base.setindex!(A::AbstractInterpolation{true},x,i) =
i <= length(A.u) ? (A.u[i] = x) : (A.t[i-length(A.u)] = x)

using LinearAlgebra, RecursiveArrayTools, RecipesBase, Reexport
using ChainRulesCore, LinearAlgebra, RecursiveArrayTools, RecipesBase, Reexport
@reexport using Optim

include("interpolation_caches.jl")
include("interpolation_utils.jl")
include("interpolation_methods.jl")
include("plot_rec.jl")
include("derivatives.jl")

function ChainRulesCore.rrule(::typeof(_interpolate), A::AbstractInterpolation, t::Number)
interpolate_pullback(Δ) = (NO_FIELDS, DoesNotExist(), derivative(A, t) * Δ)
return _interpolate(A, t), interpolate_pullback
end

ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, t::Number) = _interpolate(A, t), derivative(A, t) * Δt

(interp::AbstractInterpolation)(t::Number) = _interpolate(interp, t)

export LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, AkimaInterpolation,
ConstantInterpolation, QuadraticSpline, CubicSpline, BSplineInterpolation, BSplineApprox, Curvefit
Expand Down
184 changes: 184 additions & 0 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
function derivative(A::LinearInterpolation{<:AbstractVector}, t::Number)
idx = searchsortedfirst(A.t, t)
if A.t[idx] >= t
idx -= 1
end
idx == 0 ? idx += 1 : nothing
θ = 1 / (A.t[idx+1] - A.t[idx])
(A.u[idx+1] - A.u[idx]) / (A.t[idx+1] - A.t[idx])
end

function derivative(A::LinearInterpolation{<:AbstractMatrix}, t::Number)
idx = searchsortedfirst(A.t, t)
if A.t[idx] >= t
idx -= 1
end
idx == 0 ? idx += 1 : nothing
θ = 1 / (A.t[idx+1] - A.t[idx])
@views @. (A.u[:, idx+1] - A.u[:, idx]) / (A.t[idx+1] - A.t[idx])
end

function derivative(A::QuadraticInterpolation{<:AbstractVector}, t::Number)
idx = searchsortedfirst(A.t, t)
if A.t[idx] >= t
idx -= 1
end
idx == 0 ? idx += 1 : nothing
if idx == length(A.t) - 1
i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
else
i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
end
dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
A.u[i₀] * dl₀ + A.u[i₁] * dl₁ + A.u[i₂] * dl₂
end

function derivative(A::QuadraticInterpolation{<:AbstractMatrix}, t::Number)
idx = searchsortedfirst(A.t, t)
if A.t[idx] >= t
idx -= 1
end
idx == 0 ? idx += 1 : nothing
if idx == length(A.t) - 1
i₀ = idx - 1; i₁ = idx; i₂ = i₁ + 1;
else
i₀ = idx; i₁ = i₀ + 1; i₂ = i₁ + 1;
end
dl₀ = (2t - A.t[i₁] - A.t[i₂]) / ((A.t[i₀] - A.t[i₁]) * (A.t[i₀] - A.t[i₂]))
dl₁ = (2t - A.t[i₀] - A.t[i₂]) / ((A.t[i₁] - A.t[i₀]) * (A.t[i₁] - A.t[i₂]))
dl₂ = (2t - A.t[i₀] - A.t[i₁]) / ((A.t[i₂] - A.t[i₀]) * (A.t[i₂] - A.t[i₁]))
@views @. A.u[:, i₀] * dl₀ + A.u[:, i₁] * dl₁ + A.u[:, i₂] * dl₂
end

function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[idxs[1]])
end
G = zero(A.u[1]); F = zero(A.t[1])
DG = zero(A.u[1]); DF = zero(A.t[1])
tmp = G
for i = 1:length(idxs)
if isnan(A.bcache[idxs[i]])
mult = one(A.t[1])
for j = 1:(i - 1)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
end
for j = (i+1):length(idxs)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
end
A.bcache[idxs[i]] = mult
else
mult = A.bcache[idxs[i]]
end
wi = inv(mult)
tti = t - A.t[idxs[i]]
tmp = wi / (t - A.t[idxs[i]])
g = tmp * A.u[idxs[i]]
G += g
DG -= g / (t - A.t[idxs[i]])
F += tmp
DF -= tmp / (t - A.t[idxs[i]])
end
(DG * F - G * DF) / (F ^ 2)
end

function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[:, idxs[1]])
end
G = zero(A.u[:, 1]); F = zero(A.t[1])
DG = zero(A.u[:, 1]); DF = zero(A.t[1])
tmp = G
for i = 1:length(idxs)
if isnan(A.bcache[idxs[i]])
mult = one(A.t[1])
for j = 1:(i - 1)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
end
for j = (i+1):length(idxs)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
end
A.bcache[idxs[i]] = mult
else
mult = A.bcache[idxs[i]]
end
wi = inv(mult)
tti = t - A.t[idxs[i]]
tmp = wi / (t - A.t[idxs[i]])
g = tmp * A.u[:, idxs[i]]
@. G += g
@. DG -= g / (t - A.t[idxs[i]])
F += tmp
DF -= tmp / (t - A.t[idxs[i]])
end
@. (DG * F - G * DF) / (F ^ 2)
end

function derivative(A::AkimaInterpolation{<:AbstractVector}, t::Number)
i = searchsortedlast(A.t, t)
i == 0 && return zero(A.u[1])
i == length(A.t) && return zero(A.u[end])
wj = t - A.t[i]
@evalpoly wj A.b[i] 2A.c[i] 3A.d[i]
end

function derivative(A::ConstantInterpolation{<:AbstractVector}, t::Number)
return isempty(searchsorted(A.t, t)) ? zero(A.u[1]) : eltype(A.u)(NaN)
end

function derivative(A::ConstantInterpolation{<:AbstractMatrix}, t::Number)
return isempty(searchsorted(A.t, t)) ? zero(A.u[:, 1]) : eltype(A.u)(NaN) .* A.u[:, 1]
end

# QuadraticSpline Interpolation
function derivative(A::QuadraticSpline{<:AbstractVector{<:Number}}, t::Number)
i = searchsortedfirst(A.t, t)
i == 1 ? i += 1 : nothing
σ = 1//2 * (A.z[i] - A.z[i - 1]) / (A.t[i] - A.t[i - 1])
A.z[i-1] + 2σ * (t - A.t[i-1])
end

# CubicSpline Interpolation
function derivative(A::CubicSpline{<:AbstractVector{<:Number}}, t::Number)
i = searchsortedfirst(A.t, t)
isnothing(i) ? i = length(A.t) - 1 : i -= 1
i == 0 ? i += 1 : nothing
dI = -3A.z[i] * (A.t[i + 1] - t)^2 / (6A.h[i + 1]) + 3A.z[i + 1] * (t - A.t[i])^2 / (6A.h[i + 1])
dC = A.u[i + 1] / A.h[i + 1] - A.z[i + 1] * A.h[i + 1] / 6
dD = -(A.u[i] / A.h[i + 1] - A.z[i] * A.h[i + 1] / 6)
dI + dC + dD
end

function derivative(A::BSplineInterpolation{<:AbstractVector{<:Number}}, t::Number)
# change t into param [0 1]
idx = searchsortedlast(A.t,t)
idx == length(A.t) ? idx -= 1 : nothing
n = length(A.t)
scale = (A.p[idx+1] - A.p[idx]) / (A.t[idx+1] - A.t[idx])
t_ = A.p[idx] + (t - A.t[idx]) * scale
N = DataInterpolations.spline_coefficients(n, A.d-1, A.k, t_)
ducum = zero(eltype(A.u))
for i = 1:(n - 1)
ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1])
end
ducum * A.d * scale
end

# BSpline Curve Approx
function derivative(A::BSplineApprox{<:AbstractVector{<:Number}}, t::Number)
# change t into param [0 1]
idx = searchsortedlast(A.t,t)
idx == 0 ? idx += 1 : nothing
scale = (A.p[idx+1] - A.p[idx]) / (A.t[idx+1] - A.t[idx])
t_ = A.p[idx] + (t - A.t[idx]) * scale
N = spline_coefficients(A.h, A.d-1, A.k, t_)
ducum = zero(eltype(A.u))
for i = 1:(A.h - 1)
ducum += N[i + 1] * (A.c[i + 1] - A.c[i]) / (A.k[i + A.d + 1] - A.k[i + 1])
end
ducum * A.d * scale
end
11 changes: 8 additions & 3 deletions src/interpolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,21 @@ function QuadraticInterpolation(u,t)
end

### Lagrange Interpolation
struct LagrangeInterpolation{uType,tType,FT,T} <: AbstractInterpolation{FT,T}
struct LagrangeInterpolation{uType,tType,FT,T,bcacheType} <: AbstractInterpolation{FT,T}
u::uType
t::tType
n::Int
LagrangeInterpolation{FT}(u,t,n) where FT = new{typeof(u),typeof(t),FT,eltype(u)}(u,t,n)
bcache::bcacheType
function LagrangeInterpolation{FT}(u,t,n) where FT
bcache = zeros(eltype(u),n+1)
fill!(bcache, NaN)
new{typeof(u),typeof(t),FT,eltype(u),typeof(bcache)}(u,t,n,bcache)
end
end

function LagrangeInterpolation(u,t,n=nothing)
u, t = munge_data(u, t)
if n == nothing
if isnothing(n)
n = length(t) - 1 # degree
end
if n != length(t) - 1
Expand Down
Loading