Using `@turbo` loops gives large performance gains (roughly 10x: 3.0 μs vs. 35.7 μs in the benchmarks below) over the `LogExpFunctions` library for arrays of `Float64`s. However, `@turbo` doesn't seem to play well with `ForwardDiff.Dual` arrays, and it prints the warning below. Is there a way to leverage `LoopVectorization` to accelerate operations on `Dual` numbers?
```
`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.
Use `warn_check_args=false`, e.g. `@turbo warn_check_args=false ...`, to disable this warning.
```
I've attached a Pluto notebook with some benchmarks; the results are reproduced below.
I'm not sure whether this is related to #93. @chriselrod, I think this is related to your posts at https://discourse.julialang.org/t/speeding-up-my-logsumexp-function/42380/9?page=2 and https://discourse.julialang.org/t/fast-logsumexp-over-4th-dimension/64182/26
Thanks!
```
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "Float64" => 6-element BenchmarkTools.BenchmarkGroup:
      tags: ["Float64"]
      "Vanilla Loop" => Trial(29.500 μs)
      "Tullio" => Trial(5.200 μs)
      "LogExpFunctions" => Trial(35.700 μs)
      "Turbo" => Trial(3.000 μs)
      "SIMD Loop" => Trial(25.500 μs)
      "Vmap" => Trial(3.800 μs)
  "Dual" => 6-element BenchmarkTools.BenchmarkGroup:
      tags: ["Dual"]
      "Vanilla Loop" => Trial(45.300 μs)
      "Tullio" => Trial(53.100 μs)
      "LogExpFunctions" => Trial(62.800 μs)
      "Turbo" => Trial(311.900 μs)
      "SIMD Loop" => Trial(37.600 μs)
      "Vmap" => Trial(44.300 μs)
```
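To make the speedup claim concrete, here is a small plain-Julia sketch (no packages) that turns the times reported above into speedups relative to `LogExpFunctions`. The numbers are copied from the benchmark output; the variable names are only illustrative.

```julia
# Times (μs) copied from the BenchmarkGroup output above.
times = Dict(
    "Float64" => Dict("Vanilla Loop" => 29.5, "Tullio" => 5.2, "LogExpFunctions" => 35.7,
                      "Turbo" => 3.0, "SIMD Loop" => 25.5, "Vmap" => 3.8),
    "Dual"    => Dict("Vanilla Loop" => 45.3, "Tullio" => 53.1, "LogExpFunctions" => 62.8,
                      "Turbo" => 311.9, "SIMD Loop" => 37.6, "Vmap" => 44.3),
)

# Speedup = baseline time / method time (> 1 means faster than LogExpFunctions).
for group in ("Float64", "Dual")
    base = times[group]["LogExpFunctions"]
    for (name, t) in sort(collect(times[group]); by = last)  # fastest first
        println(rpad(group, 8), rpad(name, 16), round(base / t; digits = 2), "x")
    end
end
```

On these numbers, `Turbo` is about 11.9x faster than `LogExpFunctions` for `Float64`, but about 5x slower for `Dual` (311.9 μs vs. 62.8 μs), which is consistent with the fallback warning.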
The `LoopVectorization` functions are:
```julia
using LoopVectorization

"""
Row-wise `logsumexp` of `X` using `LoopVectorization.@turbo` loops.
**NOTE** - not compatible with `ForwardDiff.Dual` numbers (falls back to a slow non-`@turbo` loop)!
"""
function logsumexp_turbo!(Vbar, tmp_max, X)
    n, k = size(X)
    maximum!(tmp_max, X)   # row-wise maxima, subtracted for numerical stability
    fill!(Vbar, 0)
    @turbo for i in 1:n, j in 1:k
        Vbar[i] += exp(X[i,j] - tmp_max[i])
    end
    @turbo for i in 1:n
        Vbar[i] = log(Vbar[i]) + tmp_max[i]
    end
    return Vbar
end
```
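As a baseline for checking results, here is a minimal sketch of the same max-shifted algorithm written with plain `@inbounds` loops and no `@turbo`. The name `logsumexp_plain!` is hypothetical. Because it uses only generic Base operations (`maximum!`, `exp`, `log`), it should accept any element type that supports them, including `ForwardDiff.Dual`, but it gives up the SIMD speedups; it is not benchmarked here.

```julia
"""
Hypothetical generic reference version of `logsumexp_turbo!`:
same algorithm, plain loops, no `@turbo`.
"""
function logsumexp_plain!(Vbar, tmp_max, X)
    n, k = size(X)
    maximum!(tmp_max, X)                 # tmp_max[i] = max over j of X[i, j]
    fill!(Vbar, zero(eltype(Vbar)))
    @inbounds for j in 1:k, i in 1:n     # inner loop over rows (column-major order)
        Vbar[i] += exp(X[i, j] - tmp_max[i])
    end
    @inbounds for i in 1:n
        Vbar[i] = log(Vbar[i]) + tmp_max[i]
    end
    return Vbar
end

# Usage with Float64 data; a Dual-valued X would go through the same code path.
X = randn(100, 10)
Vbar = similar(X, 100)
tmp_max = similar(X, 100)
logsumexp_plain!(Vbar, tmp_max, X)
```

Subtracting the row maximum means a shifted input such as `X .+ 1000.0` still works, where a naive `log(sum(exp.(row)))` would overflow.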
```julia
"""
Row-wise `logsumexp` of `X` using the `LoopVectorization` `vmap!`/`vreduce` convenience functions.
**NOTE** - this DOES work with `ForwardDiff.Dual` numbers!
"""
function logsumexp_vmap!(Vbar, tmp_max, X, Xtmp)
    maximum!(tmp_max, X)
    n = size(X, 2)
    for j in 1:n
        Xtmpj = view(Xtmp, :, j)
        Xj = view(X, :, j)
        # Xtmp[:, j] = exp.(X[:, j] .- tmp_max)
        vmap!((xij, mi) -> exp(xij - mi), Xtmpj, Xj, tmp_max)
    end
    Vbartmp = vreduce(+, Xtmp; dims=2)   # row sums
    vmap!((vi, mi) -> log(vi) + mi, Vbar, Vbartmp, tmp_max)
    return Vbar
end
```
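Both functions above compute the same max-shifted identity row by row, with $m_i$ playing the role of `tmp_max[i]`; subtracting the row maximum keeps every `exp` argument at most zero, so nothing overflows. When `X` holds `ForwardDiff.Dual` numbers, the derivative propagated through this arithmetic for each entry of `X` is the second line below, i.e. a row-wise softmax (and zero with respect to entries in other rows):

```math
\begin{aligned}
\operatorname{LSE}_i(X) &= \log \sum_{j=1}^{k} e^{X_{ij}} = m_i + \log \sum_{j=1}^{k} e^{X_{ij} - m_i}, \qquad m_i = \max_{j} X_{ij} \\
\frac{\partial \operatorname{LSE}_i}{\partial X_{ij}} &= \frac{e^{X_{ij}}}{\sum_{l=1}^{k} e^{X_{il}}} = e^{X_{ij} - \operatorname{LSE}_i(X)}
\end{aligned}
```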