Skip to content

Commit bb157ae

Browse files
committed
Wrapping extension in general function
Create a general function that enable the enqueue_functions extension if it is enable in the compiler, otherwise call the general sycl function to launch kernels. Signed-off-by: nscipione <[email protected]>
1 parent 72fabb4 commit bb157ae

19 files changed

+265
-237
lines changed

ggml/src/ggml-sycl/binbcast.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ struct bin_bcast_sycl {
225225
dpct::has_capability_or_fail(stream->get_device(),
226226
{sycl::aspect::fp16});
227227

228-
syclex::nd_launch(*stream,
228+
sycl_parallel_for(stream,
229229
sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
230230
sycl::range<3>(1, 1, block_size),
231231
sycl::range<3>(1, 1, block_size)),
@@ -246,7 +246,7 @@ struct bin_bcast_sycl {
246246
dpct::has_capability_or_fail(stream->get_device(),
247247
{sycl::aspect::fp16});
248248

249-
syclex::nd_launch(*stream,
249+
sycl_parallel_for(stream,
250250
sycl::nd_range<3>(block_nums * block_dims, block_dims),
251251
[=](sycl::nd_item<3> item_ct1) {
252252
k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,

ggml/src/ggml-sycl/concat.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
8989
sycl::range<3> gridDim(ne2, ne1, num_blocks);
9090
switch (dim) {
9191
case 0:
92-
syclex::nd_launch(*stream,
92+
sycl_parallel_for(stream,
9393
sycl::nd_range<3>(gridDim *
9494
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
9595
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
@@ -98,7 +98,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
9898
});
9999
break;
100100
case 1:
101-
syclex::nd_launch(*stream,
101+
sycl_parallel_for(stream,
102102
sycl::nd_range<3>(gridDim *
103103
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
104104
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
@@ -108,7 +108,7 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
108108
break;
109109
// dim >=2 will be dispatched to the default path
110110
default:
111-
syclex::nd_launch(*stream,
111+
sycl_parallel_for(stream,
112112
sycl::nd_range<3>(gridDim *
113113
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE),
114114
sycl::range<3>(1, 1, SYCL_CONCAT_BLOCK_SIZE)),
@@ -129,7 +129,7 @@ static void concat_f32_sycl_non_cont(
129129
int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
130130
uint64_t nb3, int32_t dim) {
131131
sycl::range<3> gridDim(ne3, ne2, ne1);
132-
syclex::nd_launch(*stream,
132+
sycl_parallel_for(stream,
133133
sycl::nd_range<3>(gridDim, sycl::range<3>(1, 1, 1)),
134134
[=](sycl::nd_item<3> item_ct1) {
135135
int64_t i3 = item_ct1.get_group(0);

ggml/src/ggml-sycl/conv.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ static void conv_transpose_1d_f32_f32_sycl(
5959
const int num_blocks = (output_size + SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE - 1) / SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE;
6060
const sycl::range<3> block_dims(1, 1, SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE);
6161
const sycl::range<3> block_nums(1, 1, num_blocks);
62-
syclex::nd_launch(*stream,
62+
sycl_parallel_for(stream,
6363
sycl::nd_range<3>(
6464
block_nums * block_dims, block_dims),
6565
[=](sycl::nd_item<3> item_ct1) {

ggml/src/ggml-sycl/convert.cpp

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ static void dequantize_block_sycl(const void *__restrict__ vx,
3333
{
3434
dpct::has_capability_or_fail(stream->get_device(),
3535
{sycl::aspect::fp16});
36-
syclex::nd_launch(*stream,
36+
sycl_parallel_for(stream,
3737
sycl::nd_range<3>(
3838
sycl::range<3>(1, 1, num_blocks) *
3939
sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
@@ -53,7 +53,7 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
5353
dpct::has_capability_or_fail(stream->get_device(),
5454
{sycl::aspect::fp16});
5555

56-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
56+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
5757
sycl::range<3>(1, 1, 64),
5858
sycl::range<3>(1, 1, 64)),
5959
[=](sycl::nd_item<3> item_ct1) {
@@ -65,7 +65,7 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int64_t k,
6565
dpct::has_capability_or_fail(stream->get_device(),
6666
{sycl::aspect::fp16});
6767

68-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
68+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
6969
sycl::range<3>(1, 1, 32),
7070
sycl::range<3>(1, 1, 32)),
7171
[=](sycl::nd_item<3> item_ct1) {
@@ -85,7 +85,7 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
8585
dpct::has_capability_or_fail(stream->get_device(),
8686
{sycl::aspect::fp16});
8787

88-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
88+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
8989
sycl::range<3>(1, 1, 64),
9090
sycl::range<3>(1, 1, 64)),
9191
[=](sycl::nd_item<3> item_ct1) {
@@ -97,7 +97,7 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k,
9797
dpct::has_capability_or_fail(stream->get_device(),
9898
{sycl::aspect::fp16});
9999

100-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
100+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
101101
sycl::range<3>(1, 1, 32),
102102
sycl::range<3>(1, 1, 32)),
103103
[=](sycl::nd_item<3> item_ct1) {
@@ -116,7 +116,7 @@ static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k,
116116
dpct::has_capability_or_fail(stream->get_device(),
117117
{sycl::aspect::fp16});
118118

119-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
119+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
120120
sycl::range<3>(1, 1, 32),
121121
sycl::range<3>(1, 1, 32)),
122122
[=](sycl::nd_item<3> item_ct1) {
@@ -135,7 +135,7 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int
135135
int constexpr WARP_K = WARP_SIZE * QK4_0;
136136
const int n_warp = (k + WARP_K - 1) / WARP_K;
137137
GGML_ASSERT(k % 2 == 0);
138-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
138+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) *
139139
sycl::range<3>(1, 1, WARP_SIZE),
140140
sycl::range<3>(1, 1, WARP_SIZE)),
141141
[=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{
@@ -153,7 +153,7 @@ static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k,
153153
dpct::has_capability_or_fail(stream->get_device(),
154154
{sycl::aspect::fp16});
155155

156-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
156+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
157157
sycl::range<3>(1, 1, 32),
158158
sycl::range<3>(1, 1, 32)),
159159
[=](sycl::nd_item<3> item_ct1) {
@@ -171,9 +171,9 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int64_t k,
171171
dpct::has_capability_or_fail(stream->get_device(),
172172
{sycl::aspect::fp16});
173173

174-
syclex::submit(*stream,[&](sycl::handler &cgh) {
174+
sycl_launch(stream,[&](sycl::handler &cgh) {
175175
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
176-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
176+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
177177
sycl::range<3>(1, 1, 32),
178178
sycl::range<3>(1, 1, 32)),
179179
[=](sycl::nd_item<3> item_ct1) {
@@ -191,10 +191,10 @@ static void dequantize_row_q4_K_sycl_reorder(const void * vx, dst_t * y, const i
191191

192192
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
193193

194-
syclex::submit(*stream,[&](sycl::handler & cgh) {
194+
sycl_launch(stream,[&](sycl::handler & cgh) {
195195
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
196196

197-
syclex::nd_launch(cgh,sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
197+
sycl_parallel_for<1>(cgh,sycl::nd_range<1>(sycl::range<1>(global_size), sycl::range<1>(local_size)),
198198
[=](sycl::nd_item<1> item_ct1) {
199199
dequantize_block_q4_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb);
200200
});
@@ -210,7 +210,7 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
210210
dpct::has_capability_or_fail(stream->get_device(),
211211
{sycl::aspect::fp16});
212212

213-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
213+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
214214
sycl::range<3>(1, 1, 64),
215215
sycl::range<3>(1, 1, 64)),
216216
[=](sycl::nd_item<3> item_ct1) {
@@ -222,7 +222,7 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k,
222222
dpct::has_capability_or_fail(stream->get_device(),
223223
{sycl::aspect::fp16});
224224

225-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
225+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
226226
sycl::range<3>(1, 1, 32),
227227
sycl::range<3>(1, 1, 32)),
228228
[=](sycl::nd_item<3> item_ct1) {
@@ -242,7 +242,7 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
242242
dpct::has_capability_or_fail(stream->get_device(),
243243
{sycl::aspect::fp16});
244244

245-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
245+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
246246
sycl::range<3>(1, 1, 64),
247247
sycl::range<3>(1, 1, 64)),
248248
[=](sycl::nd_item<3> item_ct1) {
@@ -254,7 +254,7 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k,
254254
dpct::has_capability_or_fail(stream->get_device(),
255255
{sycl::aspect::fp16});
256256

257-
syclex::nd_launch(*stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
257+
sycl_parallel_for(stream,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
258258
sycl::range<3>(1, 1, 32),
259259
sycl::range<3>(1, 1, 32)),
260260
[=](sycl::nd_item<3> item_ct1) {
@@ -271,7 +271,7 @@ static void dequantize_row_q6_K_sycl_reorder(const void * vx, dst_t * y, const i
271271

272272
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
273273

274-
syclex::nd_launch(*stream,
274+
sycl_parallel_for(stream,
275275
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)),
276276
[=](sycl::nd_item<3> item_ct1) { dequantize_block_q6_K_reorder(vx, y, item_ct1, nb); });
277277
}
@@ -284,8 +284,8 @@ static void dequantize_row_iq1_s_sycl(const void *vx, dst_t *y, const int64_t k,
284284
dpct::has_capability_or_fail(stream->get_device(),
285285
{sycl::aspect::fp16});
286286

287-
syclex::submit(*stream,[&](sycl::handler &cgh) {
288-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
287+
sycl_launch(stream,[&](sycl::handler &cgh) {
288+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
289289
sycl::range<3>(1, 1, 32),
290290
sycl::range<3>(1, 1, 32)),
291291
[=](sycl::nd_item<3> item_ct1) {
@@ -305,8 +305,8 @@ static void dequantize_row_iq1_m_sycl(const void *vx, dst_t *y, const int64_t k,
305305
dpct::has_capability_or_fail(stream->get_device(),
306306
{sycl::aspect::fp16});
307307

308-
syclex::submit(*stream,[&](sycl::handler &cgh) {
309-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
308+
sycl_launch(stream,[&](sycl::handler &cgh) {
309+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
310310
sycl::range<3>(1, 1, 32),
311311
sycl::range<3>(1, 1, 32)),
312312
[=](sycl::nd_item<3> item_ct1) {
@@ -326,8 +326,8 @@ static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int64_t
326326
dpct::has_capability_or_fail(stream->get_device(),
327327
{sycl::aspect::fp16});
328328

329-
syclex::submit(*stream,[&](sycl::handler &cgh) {
330-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
329+
sycl_launch(stream,[&](sycl::handler &cgh) {
330+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
331331
sycl::range<3>(1, 1, 32),
332332
sycl::range<3>(1, 1, 32)),
333333
[=](sycl::nd_item<3> item_ct1) {
@@ -347,8 +347,8 @@ static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int64_t k
347347
dpct::has_capability_or_fail(stream->get_device(),
348348
{sycl::aspect::fp16});
349349

350-
syclex::submit(*stream,[&](sycl::handler &cgh) {
351-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
350+
sycl_launch(stream,[&](sycl::handler &cgh) {
351+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
352352
sycl::range<3>(1, 1, 32),
353353
sycl::range<3>(1, 1, 32)),
354354
[=](sycl::nd_item<3> item_ct1) {
@@ -368,8 +368,8 @@ static void dequantize_row_iq2_s_sycl(const void *vx, dst_t *y, const int64_t k,
368368
dpct::has_capability_or_fail(stream->get_device(),
369369
{sycl::aspect::fp16});
370370

371-
syclex::submit(*stream,[&](sycl::handler &cgh) {
372-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
371+
sycl_launch(stream,[&](sycl::handler &cgh) {
372+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
373373
sycl::range<3>(1, 1, 32),
374374
sycl::range<3>(1, 1, 32)),
375375
[=](sycl::nd_item<3> item_ct1) {
@@ -388,8 +388,8 @@ static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int64_t
388388
dpct::has_capability_or_fail(stream->get_device(),
389389
{sycl::aspect::fp16});
390390

391-
syclex::submit(*stream,[&](sycl::handler &cgh) {
392-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
391+
sycl_launch(stream,[&](sycl::handler &cgh) {
392+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
393393
sycl::range<3>(1, 1, 32),
394394
sycl::range<3>(1, 1, 32)),
395395
[=](sycl::nd_item<3> item_ct1) {
@@ -409,8 +409,8 @@ static void dequantize_row_iq3_s_sycl(const void *vx, dst_t *y, const int64_t k,
409409
dpct::has_capability_or_fail(stream->get_device(),
410410
{sycl::aspect::fp16});
411411

412-
syclex::submit(*stream,[&](sycl::handler &cgh) {
413-
syclex::nd_launch(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
412+
sycl_launch(stream,[&](sycl::handler &cgh) {
413+
sycl_parallel_for(cgh,sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
414414
sycl::range<3>(1, 1, 32),
415415
sycl::range<3>(1, 1, 32)),
416416
[=](sycl::nd_item<3> item_ct1) {
@@ -432,8 +432,8 @@ static void dequantize_row_iq4_xs_sycl(const void *vx, dst_t *y, const int64_t k
432432
dpct::has_capability_or_fail(stream->get_device(),
433433
{sycl::aspect::fp16});
434434

435-
syclex::submit(*stream,[&](sycl::handler &cgh) {
436-
syclex::nd_launch(cgh,
435+
sycl_launch(stream,[&](sycl::handler &cgh) {
436+
sycl_parallel_for(cgh,
437437
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
438438
sycl::range<3>(1, 1, 32),
439439
sycl::range<3>(1, 1, 32)),
@@ -453,8 +453,8 @@ static void dequantize_row_iq4_nl_sycl(const void *vx, dst_t *y, const int64_t k
453453
dpct::has_capability_or_fail(stream->get_device(),
454454
{sycl::aspect::fp16});
455455

456-
syclex::submit(*stream,[&](sycl::handler &cgh) {
457-
syclex::nd_launch(cgh,
456+
sycl_launch(stream,[&](sycl::handler &cgh) {
457+
sycl_parallel_for(cgh,
458458
sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
459459
sycl::range<3>(1, 1, 32),
460460
sycl::range<3>(1, 1, 32)),

0 commit comments

Comments
 (0)