Skip to content

Commit 5e6ae4c

Browse files
committed
Type stability fixes surrounding passing of threadlocal variable
1 parent 6603529 commit 5e6ae4c

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

src/batch.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,25 @@ function add_var!(q, argtup, gcpres, ::Type{T}, argtupname, gcpresname, k) where
6969
end
7070
end
7171
@generated function _batch_no_reserve(
72-
f!::F, threadmask_tuple::NTuple{N}, nthread_tuple, torelease_tuple, Nr, Nd, ulen, args::Vararg{Any,K}; threadlocal::Bool=false
73-
) where {F,K,N}
72+
f!::F, threadmask_tuple::NTuple{N}, nthread_tuple, torelease_tuple, Nr, Nd, ulen, args::Vararg{Any,K}; threadlocal::Val{thread_local}=Val(false)
73+
) where {F,K,N,thread_local}
7474
q = quote
7575
$(Expr(:meta,:inline))
7676
# threads = UnsignedIteratorEarlyStop(threadmask, nthread)
7777
# threads_tuple = map(UnsignedIteratorEarlyStop, threadmask_tuple, nthread_tuple)
7878
# nthread_total = sum(nthread_tuple)
7979
Ndp = Nd + one(Nd)
8080
end
81+
launch_quote = if thread_local
82+
:(launch_batched_thread!(cfunc, tid, argtup, start, stop, i%UInt))
83+
else
84+
:(launch_batched_thread!(cfunc, tid, argtup, start, stop))
85+
end
86+
rem_quote = if thread_local
87+
:(f!(arguments, (start+one(UInt)) % Int, ulen % Int, (sum(nthread_tuple)+1)%Int))
88+
else
89+
:(f!(arguments, (start+one(UInt)) % Int, ulen % Int))
90+
end
8191
block = quote
8292
start = zero(UInt)
8393
tid = 0x00000000
@@ -92,20 +102,12 @@ end
92102
tz += 0x00000001
93103
tid += tz
94104
tm >>>= tz
95-
if threadlocal
96-
launch_batched_thread!(cfunc, tid, argtup, start, stop, i%UInt)
97-
else
98-
launch_batched_thread!(cfunc, tid, argtup, start, stop)
99-
end
105+
$launch_quote
100106
start = stop
101107
end
102108
Nr -= nthread
103109
end
104-
if threadlocal
105-
f!(arguments, (start+one(UInt)) % Int, ulen % Int, (sum(nthread_tuple)+1)%Int)
106-
else
107-
f!(arguments, (start+one(UInt)) % Int, ulen % Int)
108-
end
110+
$rem_quote
109111
for (threadmask, nthread, torelease) ∈ zip(threadmask_tuple, nthread_tuple, torelease_tuple)
110112
tm = mask(UnsignedIteratorEarlyStop(threadmask, nthread))
111113
tid = 0x00000000
@@ -127,7 +129,7 @@ end
127129
for k ∈ 1:K
128130
add_var!(q, argt, gcpr, args[k], :args, :gcp, k)
129131
end
130-
push!(q.args, :(arguments = $argt), :(argtup = Reference(arguments)), :(cfunc = batch_closure(f!, argtup, Val{false}(), Val{threadlocal}())), gcpr)
132+
push!(q.args, :(arguments = $argt), :(argtup = Reference(arguments)), :(cfunc = batch_closure(f!, argtup, Val{false}(), Val{$thread_local}())), gcpr)
131133
push!(q.args, nothing)
132134
q
133135
end
@@ -227,15 +229,15 @@ end
227229

228230

229231
@inline function batch(
230-
f!::F, (len, nbatches)::Tuple{Vararg{Integer,2}}, args::Vararg{Any,K}; threadlocal::Bool=false
231-
) where {F,K}
232+
f!::F, (len, nbatches)::Tuple{Vararg{Integer,2}}, args::Vararg{Any,K}; threadlocal::Val{thread_local}=Val{false}()
233+
) where {F,K,thread_local}
232234
# threads, torelease = request_threads(Base.Threads.threadid(), nbatches - one(nbatches))
233235
threads, torelease = request_threads(nbatches - one(nbatches))
234236
nthreads = map(length,threads)
235237
nthread = sum(nthreads)
236238
ulen = len % UInt
237239
if nthread % Int32 ≤ zero(Int32)
238-
if threadlocal
240+
if thread_local
239241
f!(args, one(Int), ulen % Int, 1)
240242
else
241243
f!(args, one(Int), ulen % Int)
@@ -246,12 +248,12 @@ end
246248
Nd = Base.udiv_int(ulen, nbatch % UInt) # reasonable for `ulen` to be ≥ 2^32
247249
Nr = ulen - Nd * nbatch
248250

249-
_batch_no_reserve(f!, map(mask,threads), nthreads, torelease, Nr, Nd, ulen, args...; threadlocal=threadlocal)
251+
_batch_no_reserve(f!, map(mask,threads), nthreads, torelease, Nr, Nd, ulen, args...; threadlocal)
250252
end
251253
function batch(
252-
f!::F, (len, nbatches, reserve_per_worker)::Tuple{Vararg{Integer,3}}, args::Vararg{Any,K}; threadlocal::Bool=false
253-
) where {F,K}
254-
batch(f!, (len, nbatches), args...; threadlocal=false)
254+
f!::F, (len, nbatches, reserve_per_worker)::Tuple{Vararg{Integer,3}}, args::Vararg{Any,K}; threadlocal::Val{thread_local}=Val(false)
255+
) where {F,K,thread_local}
256+
batch(f!, (len, nbatches), args...; threadlocal)
255257
# ulen = len % UInt
256258
# if nbatches > 1
257259
# requested_threads = reserve_per_worker*nbatches

src/closure.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ function enclose(exorig::Expr, reserve_per, minbatchsize, per::Symbol, threadloc
340340
push!(batchcall.args, esc(a))
341341
end
342342
if threadlocal !== Symbol("")
343-
push!(batchcall.args, Expr(:kw, :threadlocal, true))
343+
push!(batchcall.args, Expr(:kw, :threadlocal, Val(true)))
344344
end
345345
push!(q.args, batchcall)
346346
quote

0 commit comments

Comments
 (0)