@@ -69,15 +69,25 @@ function add_var!(q, argtup, gcpres, ::Type{T}, argtupname, gcpresname, k) where
69
69
end
70
70
end
71
71
@generated function _batch_no_reserve (
72
- f!:: F , threadmask_tuple:: NTuple{N} , nthread_tuple, torelease_tuple, Nr, Nd, ulen, args:: Vararg{Any,K} ; threadlocal:: Bool = false
73
- ) where {F,K,N}
72
+ f!:: F , threadmask_tuple:: NTuple{N} , nthread_tuple, torelease_tuple, Nr, Nd, ulen, args:: Vararg{Any,K} ; threadlocal:: Val{thread_local} = Val ( false )
73
+ ) where {F,K,N,thread_local }
74
74
q = quote
75
75
$ (Expr (:meta ,:inline ))
76
76
# threads = UnsignedIteratorEarlyStop(threadmask, nthread)
77
77
# threads_tuple = map(UnsignedIteratorEarlyStop, threadmask_tuple, nthread_tuple)
78
78
# nthread_total = sum(nthread_tuple)
79
79
Ndp = Nd + one (Nd)
80
80
end
81
+ launch_quote = if thread_local
82
+ :(launch_batched_thread! (cfunc, tid, argtup, start, stop, i% UInt))
83
+ else
84
+ :(launch_batched_thread! (cfunc, tid, argtup, start, stop))
85
+ end
86
+ rem_quote = if thread_local
87
+ :(f! (arguments, (start+ one (UInt)) % Int, ulen % Int, (sum (nthread_tuple)+ 1 )% Int))
88
+ else
89
+ :(f! (arguments, (start+ one (UInt)) % Int, ulen % Int))
90
+ end
81
91
block = quote
82
92
start = zero (UInt)
83
93
tid = 0x00000000
92
102
tz += 0x00000001
93
103
tid += tz
94
104
tm >>>= tz
95
- if threadlocal
96
- launch_batched_thread! (cfunc, tid, argtup, start, stop, i% UInt)
97
- else
98
- launch_batched_thread! (cfunc, tid, argtup, start, stop)
99
- end
105
+ $ launch_quote
100
106
start = stop
101
107
end
102
108
Nr -= nthread
103
109
end
104
- if threadlocal
105
- f! (arguments, (start+ one (UInt)) % Int, ulen % Int, (sum (nthread_tuple)+ 1 )% Int)
106
- else
107
- f! (arguments, (start+ one (UInt)) % Int, ulen % Int)
108
- end
110
+ $ rem_quote
109
111
for (threadmask, nthread, torelease) ∈ zip (threadmask_tuple, nthread_tuple, torelease_tuple)
110
112
tm = mask (UnsignedIteratorEarlyStop (threadmask, nthread))
111
113
tid = 0x00000000
127
129
for k ∈ 1 : K
128
130
add_var! (q, argt, gcpr, args[k], :args , :gcp , k)
129
131
end
130
- push! (q. args, :(arguments = $ argt), :(argtup = Reference (arguments)), :(cfunc = batch_closure (f!, argtup, Val {false} (), Val {threadlocal } ())), gcpr)
132
+ push! (q. args, :(arguments = $ argt), :(argtup = Reference (arguments)), :(cfunc = batch_closure (f!, argtup, Val {false} (), Val {$thread_local } ())), gcpr)
131
133
push! (q. args, nothing )
132
134
q
133
135
end
@@ -227,15 +229,15 @@ end
227
229
228
230
229
231
@inline function batch (
230
- f!:: F , (len, nbatches):: Tuple{Vararg{Integer,2}} , args:: Vararg{Any,K} ; threadlocal:: Bool = false
231
- ) where {F,K}
232
+ f!:: F , (len, nbatches):: Tuple{Vararg{Integer,2}} , args:: Vararg{Any,K} ; threadlocal:: Val{thread_local} = Val { false} ()
233
+ ) where {F,K,thread_local }
232
234
# threads, torelease = request_threads(Base.Threads.threadid(), nbatches - one(nbatches))
233
235
threads, torelease = request_threads (nbatches - one (nbatches))
234
236
nthreads = map (length,threads)
235
237
nthread = sum (nthreads)
236
238
ulen = len % UInt
237
239
if nthread % Int32 ≤ zero (Int32)
238
- if threadlocal
240
+ if thread_local
239
241
f! (args, one (Int), ulen % Int, 1 )
240
242
else
241
243
f! (args, one (Int), ulen % Int)
@@ -246,12 +248,12 @@ end
246
248
Nd = Base. udiv_int (ulen, nbatch % UInt) # reasonable for `ulen` to be ≥ 2^32
247
249
Nr = ulen - Nd * nbatch
248
250
249
- _batch_no_reserve (f!, map (mask,threads), nthreads, torelease, Nr, Nd, ulen, args... ; threadlocal= threadlocal )
251
+ _batch_no_reserve (f!, map (mask,threads), nthreads, torelease, Nr, Nd, ulen, args... ; threadlocal)
250
252
end
251
253
function batch (
252
- f!:: F , (len, nbatches, reserve_per_worker):: Tuple{Vararg{Integer,3}} , args:: Vararg{Any,K} ; threadlocal:: Bool = false
253
- ) where {F,K}
254
- batch (f!, (len, nbatches), args... ; threadlocal= false )
254
+ f!:: F , (len, nbatches, reserve_per_worker):: Tuple{Vararg{Integer,3}} , args:: Vararg{Any,K} ; threadlocal:: Val{thread_local} = Val ( false )
255
+ ) where {F,K,thread_local }
256
+ batch (f!, (len, nbatches), args... ; threadlocal)
255
257
# ulen = len % UInt
256
258
# if nbatches > 1
257
259
# requested_threads = reserve_per_worker*nbatches
0 commit comments