
Using @turbo with ForwardDiff.Dual for logsumexp #437

Open
@magerton


Using `@turbo` loops gives large performance gains (roughly 10x) over the LogExpFunctions library for arrays of `Float64`s. However, `@turbo` doesn't seem to play well with arrays of `ForwardDiff.Dual` and prints the warning below. Is there a way to leverage LoopVectorization to accelerate operations on `Dual` numbers?

```
`LoopVectorization.check_args` on your inputs failed; running fallback `@inbounds @fastmath` loop instead.
Use `warn_check_args=false`, e.g. `@turbo warn_check_args=false ...`, to disable this warning.
```
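One possible workaround, offered as a sketch under assumptions rather than as LoopVectorization's own API: the derivative of a row-wise logsumexp with respect to `X[i,j]` is the softmax weight `exp(X[i,j] - V[i])`, so the `Dual` result can be assembled from plain `Float64` arrays of values and partials, an element type `@turbo` handles natively. The helper `logsumexp_with_partials!` and the `n×k×p` layout of `dX` are hypothetical. Unpacking the `Dual` array (for example with `ForwardDiff.value` and `ForwardDiff.partials`) and re-wrapping the result are omitted so the sketch runs in Base Julia, and whether putting `@turbo` on these loops is actually faster has not been measured.

```julia
# Hypothetical helper (not from LoopVectorization or ForwardDiff):
# row-wise logsumexp of X (n×k) plus its directional derivatives, given
# dX[i, j, q] = derivative of X[i, j] with respect to parameter q.
# Every array is plain Float64, so these loops are candidates for @turbo
# (an untested assumption; plain loops are used here).
function logsumexp_with_partials!(V, dV, X, dX)
    n, k = size(X)
    p = size(dX, 3)
    for i in 1:n
        # Row maximum, subtracted before exp for numerical stability.
        m = -Inf
        for j in 1:k
            m = max(m, X[i, j])
        end
        s = 0.0
        for j in 1:k
            s += exp(X[i, j] - m)
        end
        V[i] = log(s) + m
        # Chain rule: d logsumexp = sum over j of softmax weight * dX.
        for q in 1:p
            dV[i, q] = 0.0
        end
        for j in 1:k
            w = exp(X[i, j] - V[i])   # softmax weight of entry (i, j)
            for q in 1:p
                dV[i, q] += w * dX[i, j, q]
            end
        end
    end
    return V, dV
end
```

Because the softmax weights in each row sum to one, setting every entry of `dX` to `1.0` should give `dV ≈ 1` in every row, which makes a quick sanity check.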

I'm uploading a Pluto notebook with some benchmarks, which I reproduce below.

Not sure if this is related to #93. @chriselrod, I think this is related to your posts at https://discourse.julialang.org/t/speeding-up-my-logsumexp-function/42380/9?page=2 and https://discourse.julialang.org/t/fast-logsumexp-over-4th-dimension/64182/26

Thanks!

```
2-element BenchmarkTools.BenchmarkGroup:
  tags: []
  "Float64" => 6-element BenchmarkTools.BenchmarkGroup:
      tags: ["Float64"]
      "Vanilla Loop" => Trial(29.500 μs)
      "Tullio" => Trial(5.200 μs)
      "LogExpFunctions" => Trial(35.700 μs)
      "Turbo" => Trial(3.000 μs)
      "SIMD Loop" => Trial(25.500 μs)
      "Vmap" => Trial(3.800 μs)
  "Dual" => 6-element BenchmarkTools.BenchmarkGroup:
      tags: ["Dual"]
      "Vanilla Loop" => Trial(45.300 μs)
      "Tullio" => Trial(53.100 μs)
      "LogExpFunctions" => Trial(62.800 μs)
      "Turbo" => Trial(311.900 μs)
      "SIMD Loop" => Trial(37.600 μs)
      "Vmap" => Trial(44.300 μs)
```

The LoopVectorization-based functions are:

```julia
using LoopVectorization

"""
using `LoopVectorization.@turbo` loops

**NOTE** - not compatible with `ForwardDiff.Dual` numbers!
"""
function logsumexp_turbo!(Vbar, tmp_max, X)
	n, k = size(X)
	maximum!(tmp_max, X)    # row maxima, subtracted for numerical stability
	fill!(Vbar, 0)
	# accumulate the shifted exponentials of each row
	@turbo for i in 1:n, j in 1:k
		Vbar[i] += exp(X[i,j] - tmp_max[i])
	end
	# take the log and add the row maximum back
	@turbo for i in 1:n
		Vbar[i] = log(Vbar[i]) + tmp_max[i]
	end
	return Vbar
end

"""
using `LoopVectorization` `vmap` convenience functions

**NOTE** - this DOES work with `ForwardDiff.Dual` numbers!
"""
function logsumexp_vmap!(Vbar, tmp_max, X, Xtmp)
	maximum!(tmp_max, X)    # row maxima
	k = size(X, 2)
	# Xtmp[:, j] = exp.(X[:, j] .- tmp_max), one column at a time
	for j in 1:k
		Xtmpj = view(Xtmp, :, j)
		Xj    = view(X, :, j)
		vmap!((xij, mi) -> exp(xij - mi), Xtmpj, Xj, tmp_max)
	end
	Vbartmp = vreduce(+, Xtmp; dims=2)    # row sums
	vmap!((vi, mi) -> log(vi) + mi, Vbar, Vbartmp, tmp_max)
	return Vbar
end
```
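To sanity-check either kernel, a plain-Base reference such as the hypothetical `logsumexp_rows` below may help. It applies the same max-shift trick (subtract each row's maximum before exponentiating, then add it back after the `log`) using only broadcasting and `dims` reductions, so it is generic over element type and should also accept `ForwardDiff.Dual` inputs; it allocates temporaries and is not meant to be fast.

```julia
# Hypothetical reference implementation (not from the issue):
# row-wise logsumexp of an n×k matrix, returning a length-n vector.
function logsumexp_rows(X::AbstractMatrix)
    m = maximum(X; dims=2)            # n×1 row maxima
    s = sum(exp.(X .- m); dims=2)     # n×1 sums of shifted exponentials
    return vec(log.(s) .+ m)          # shift back and flatten to a vector
end
```

With `Vbar = zeros(n)` and `tmp_max = zeros(n)` preallocated, a check might read `logsumexp_turbo!(Vbar, tmp_max, X) ≈ logsumexp_rows(X)`. The shift also keeps large inputs such as `[1000.0 1000.0]` from overflowing `exp`.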
