Skip to content

Commit 48480c8

Browse files
committed
Refactored opt_for_reorder logic to simplify code path
1 parent d61dda3 commit 48480c8

File tree

1 file changed

+39
-21
lines changed

1 file changed

+39
-21
lines changed

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2831,13 +2831,19 @@ catch (sycl::exception const &exc) {
28312831
std::exit(1);
28322832
}
28332833

2834+
enum class Mul_Mat_Algo {  // Identifies which mul_mat path is requesting a reorder from opt_for_reorder().
2835+
DMMV = 0,  // Passed on the use_dequantize_mul_mat_vec branch (ggml_sycl_op_dequantize_mul_mat_vec).
2836+
MMVQ = 1,  // Passed on the use_mul_mat_vec_q branch (ggml_sycl_op_mul_mat_vec_q).
2837+
MUL_MAT_SYCL = 2,  // Passed on the final fallback branch (ggml_sycl_op_mul_mat_sycl).
2838+
};
2839+
28342840
inline bool ggml_sycl_supports_mmq(enum ggml_type type) {  // MMQ support query; currently answers "no" for every type.
28352841
// TODO: accuracy issues in MMQ
28362842
GGML_UNUSED(type);  // The argument is intentionally ignored while the result is unconditional.
28372843
return false;  // Always false until the TODO above is resolved.
28382844
}
28392845

2840-
inline bool ggml_sycl_supports_reorder_dequantize(enum ggml_type type) {
2846+
inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
28412847
switch (type) {
28422848
case GGML_TYPE_Q4_0:
28432849
return true;
@@ -2927,20 +2933,37 @@ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_ten
29272933
dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
29282934
}
29292935

2930-
/*
2931-
* This function could be called when the OP (mul_mat) function support reorder optimizition.
2932-
*/
2933-
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
2934-
ggml_tensor * dst) {
2935-
if (should_reorder_tensor(*ctx, dst)) {
2936-
ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
2937-
if (!extra) return; //only happen in CI/UT permute case.
2936+
static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
2937+
ggml_tensor * dst, Mul_Mat_Algo mul_mat_algo) {  // Calls reorder_qw() on src0 at most once, and only if mul_mat_algo supports reorder for src0->type.
2938+
if (!should_reorder_tensor(*ctx, dst)) {  // Gate 1: should_reorder_tensor() must accept this context/dst pair.
2939+
return;
2940+
}
29382941

2939-
if (extra->optimized_feature.reorder) return; //skip the tensor which is handled for reorder.
2942+
ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);  // src0->extra holds the per-tensor reorder flag.
2943+
if (!extra || extra->optimized_feature.reorder) {  // Gate 2: need an extra, and it must not already be marked reordered.
2944+
return; // Skip permutations and already reordered tensors
2945+
}
29402946

2941-
reorder_qw(src0, ctx->stream());
2942-
extra->optimized_feature.reorder = true; //used to decode/dequan in next steps.
2947+
switch (mul_mat_algo) {  // Gate 3: ask the per-algorithm predicate whether src0->type can be reordered.
2948+
case Mul_Mat_Algo::DMMV:
2949+
if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
2950+
return;
2951+
}
2952+
break;
2953+
case Mul_Mat_Algo::MMVQ:
2954+
if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
2955+
return;
2956+
}
2957+
break;
2958+
case Mul_Mat_Algo::MUL_MAT_SYCL:
2959+
if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
2960+
return;
2961+
}
2962+
break;
29432963
}
2964+
2965+
reorder_qw(src0, ctx->stream());  // All gates passed: run the reorder on the context's stream.
2966+
extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
29442967
}
29452968

29462969
static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -3013,24 +3036,19 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
30133036
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
30143037
} else if (use_dequantize_mul_mat_vec) {
30153038
constexpr bool convert_src1_to_q8_1 = false;
3016-
if (ggml_sycl_supports_reorder_dmmv(src0->type)) {
3017-
opt_for_reorder(&ctx, src0, src1, dst);
3018-
}
3039+
opt_for_reorder(&ctx, src0, src1, dst, Mul_Mat_Algo::DMMV);
30193040
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
30203041
} else if (use_mul_mat_vec_q) {
30213042
constexpr bool convert_src1_to_q8_1 = true;
3022-
if (ggml_sycl_supports_reorder_mmvq(src0->type)) {
3023-
opt_for_reorder(&ctx, src0, src1, dst);
3024-
}
3043+
opt_for_reorder(&ctx, src0, src1, dst, Mul_Mat_Algo::MMVQ);
30253044
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
30263045
} else if (use_mul_mat_q) {
30273046
constexpr bool convert_src1_to_q8_1 = true;
30283047
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
30293048
} else {
30303049
constexpr bool convert_src1_to_q8_1 = false;
3031-
if (ggml_sycl_supports_reorder_dequantize(src0->type)) {
3032-
opt_for_reorder(&ctx, src0, src1, dst); //the OP function in this branch support reorder.
3033-
}
3050+
// MUL_MAT_SYCL supports reorder
3051+
opt_for_reorder(&ctx, src0, src1, dst, Mul_Mat_Algo::MUL_MAT_SYCL);
30343052
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
30353053
}
30363054
GGML_SYCL_DEBUG("call %s done\n", __func__);

0 commit comments

Comments
 (0)