Speeding up my logsumexp function

The awkward thing about broadcasting is that types and syntax together still do not fully specify the behavior: axis n is broadcast only if isone(size(A, n)), which is a runtime property.
Currently LoopVectorization handles broadcasting by setting the stride of every size-1 dimension to 0. The linear index of A[n...] is (up to the 1-based offset) dot(n, strides(A)), so a zero stride makes the index along that axis irrelevant, and the load "broadcasts" along it.
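A toy sketch of that stride trick, written out by hand (only an illustration, not how LoopVectorization actually lowers the loop; at and stride_A are made-up names):

A = [10.0 20.0 30.0]        # 1×3, so axis 1 is the axis being broadcast
B = rand(4, 3)
stride_A = (0, 1)           # pretend the size-1 axis has stride 0
at(i, j) = A[1 + (i - 1) * stride_A[1] + (j - 1) * stride_A[2]]  # linear index via dot(n, strides)
C = [B[i, j] - at(i, j) for i in 1:4, j in 1:3]
C ≈ B .- A                  # true: the zero stride reuses A's single row for every i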

However, this doesn’t work well when loading data along a contiguous axis.
For loads along that axis to be efficient, we need the vmovup* instructions, which load contiguous elements, so the stride = 0 trick won’t work there.
It would work with gather instructions, available with AVX2 and AVX512, but those are many times slower; while better than nothing, they cripple performance compared to contiguous loads.
This means that to be efficient we have to use the contiguous loads (and stores, but stores aren’t a problem when broadcasting).

Obviously base Julia + LLVM don’t have a problem with this. I’m guessing they use a few runtime checks to switch between different versions of the loop, and I should probably follow that approach.
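Roughly what I mean by runtime checks, as a hypothetical hand-written version (not what LLVM actually emits; sub_bcast! is a made-up name): the branch on size(max_, 1) picks between a loop that reuses one row of max_ and a plain elementwise loop, and both inner loops read mat contiguously.

function sub_bcast!(out, mat, max_)
    if isone(size(max_, 1))                 # broadcast max_ along axis 1
        @inbounds for j in axes(mat, 2), i in axes(mat, 1)
            out[i, j] = mat[i, j] - max_[1, j]
        end
    else                                    # same shape: plain elementwise loop
        @inbounds for j in axes(mat, 2), i in axes(mat, 1)
            out[i, j] = mat[i, j] - max_[i, j]
        end
    end
    return out
end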

But for now, a workaround (incompatible with the dims argument) is to make it known at compile time that isone(size(max_,1)):

function lsexp_mat4(mat; dims=1) # @avx broadcasting, with the row-vector workaround
    @assert dims == 1
    max_ = vec(maximum(mat, dims=1))' # row vector, so size(max_, 1) == 1 is known from the type
    # subtracting (mat .== max_) cancels the exp(0) == 1 term, so log1p can be used below
    exp_mat = @avx exp.(mat .- max_) .- (mat .== max_)
    sum_exp_ = sum(exp_mat, dims=dims)
    @avx sum_exp_ .= log1p.(sum_exp_) .+ max_
end
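
The point of the adjoint is that the row-vector wrapper encodes the length-1 first axis in the type, so no runtime check is needed:

mat = rand(3, 4)
max_mat = maximum(mat, dims=1)   # 1×4 Matrix{Float64}: the size-1 axis is only a runtime property
max_row = vec(max_mat)'          # Adjoint of a Vector: axis 1 has length 1 by construction
size(max_mat, 1) == size(max_row, 1) == 1   # true for both, but only the second is visible from the type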

So now I get:

julia> n = 1_000; A = rand(n,n);

julia> lsexp_mat(A) ≈ lsexp_mat1(A) ≈ lsexp_mat2(A)
true

julia> lsexp_mat(A) ≈ lsexp_mat4(A)
true

julia> lsexp_mat(A) ≈ lsexp_mat3(A) ≈ lsexp_mat5(A)
true

julia> @btime lsexp_mat($A);
  10.844 ms (13 allocations: 15.41 MiB)

julia> @btime lsexp_mat2($A);
  10.726 ms (6 allocations: 7.65 MiB)

julia> @btime lsexp_mat3($A);
  8.750 ms (21 allocations: 7.65 MiB)

julia> @btime lsexp_mat4($A);
  2.402 ms (17 allocations: 7.65 MiB)

julia> @btime lsexp_mat5($A);
  2.557 ms (19 allocations: 7.65 MiB)

julia> using Tracker, Zygote #, ForwardDiff

julia> Zygote.@nograd safeeq

julia> gA = Tracker.gradient(sum∘lsexp_mat, A)[1];

julia> Zygote.gradient(sum∘lsexp_mat1, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat3, A)[1] ≈ gA
true

julia> Zygote.gradient(sum∘lsexp_mat5, A)[1] ≈ gA
true

julia> @btime Zygote.gradient(sum∘lsexp_mat1, $A);
  67.391 ms (3003130 allocations: 137.61 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat3, $A);
  22.924 ms (133 allocations: 38.22 MiB)

julia> @btime Zygote.gradient(sum∘lsexp_mat5, $A);
  10.519 ms (128 allocations: 38.22 MiB)

Overall, our performance numbers are really similar except for lsexp_mat3, where my computer is much slower.
