@@ -2831,13 +2831,19 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }

+enum class Mul_Mat_Algo {
+    DMMV = 0,
+    MMVQ = 1,
+    MUL_MAT_SYCL = 2,
+};
+
 inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
     // TODO: accuracy issues in MMQ
     GGML_UNUSED(type);
     return false;
 }

-inline bool ggml_sycl_supports_reorder_dequantize(enum ggml_type type) {
+inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) {
     switch (type) {
         case GGML_TYPE_Q4_0:
             return true;
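The new `Mul_Mat_Algo` enum ties each mul_mat code path to its own reorder-support predicate (`ggml_sycl_supports_reorder_dmmv`, `ggml_sycl_supports_reorder_mmvq`, and the renamed `ggml_sycl_supports_reorder_mul_mat_sycl`). Below is a minimal standalone sketch of that dispatch pattern; `ggml_type_stub` and the `reorder_*_stub` predicates are illustrative stand-ins rather than the real ggml-sycl helpers, and only the Q4_0-only rule for the mul_mat_sycl path mirrors the diff above.

```cpp
// Sketch: one enum value per mul_mat algorithm, a single switch picks the
// matching reorder-support predicate. Stubs are illustrative, not ggml API.
#include <cstdio>

enum class Mul_Mat_Algo { DMMV = 0, MMVQ = 1, MUL_MAT_SYCL = 2 };
enum ggml_type_stub { Q4_0, Q4_K, F16 };                           // stand-in for ggml_type

static bool reorder_dmmv_stub(ggml_type_stub t)         { return t == Q4_0; }  // assumption
static bool reorder_mmvq_stub(ggml_type_stub t)         { return t == Q4_0; }  // assumption
static bool reorder_mul_mat_sycl_stub(ggml_type_stub t) { return t == Q4_0; }  // per the diff

static bool supports_reorder(Mul_Mat_Algo algo, ggml_type_stub type) {
    switch (algo) {
        case Mul_Mat_Algo::DMMV:         return reorder_dmmv_stub(type);
        case Mul_Mat_Algo::MMVQ:         return reorder_mmvq_stub(type);
        case Mul_Mat_Algo::MUL_MAT_SYCL: return reorder_mul_mat_sycl_stub(type);
    }
    return false;
}

int main() {
    std::printf("MUL_MAT_SYCL + Q4_0 -> %d\n", supports_reorder(Mul_Mat_Algo::MUL_MAT_SYCL, Q4_0));
    std::printf("MUL_MAT_SYCL + F16  -> %d\n", supports_reorder(Mul_Mat_Algo::MUL_MAT_SYCL, F16));
}
```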
@@ -2927,20 +2933,37 @@ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_ten
            dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1;
 }

-/*
-* This function could be called when the OP (mul_mat) function support reorder optimizition.
-*/
-static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1,
-    ggml_tensor * dst) {
-    if (should_reorder_tensor(*ctx, dst)) {
-        ggml_tensor_extra_gpu* extra = (ggml_tensor_extra_gpu*)src0->extra;
-        if (!extra) return; // only happen in CI/UT permute case.
+static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */,
+                            ggml_tensor * dst, Mul_Mat_Algo mul_mat_algo) {
+    if (!should_reorder_tensor(*ctx, dst)) {
+        return;
+    }

-        if (extra->optimized_feature.reorder) return; // skip the tensor which is handled for reorder.
+    ggml_tensor_extra_gpu * extra = static_cast<ggml_tensor_extra_gpu *>(src0->extra);
+    if (!extra || extra->optimized_feature.reorder) {
+        return; // Skip permutations and already reordered tensors
+    }

-        reorder_qw(src0, ctx->stream());
-        extra->optimized_feature.reorder = true; // used to decode/dequan in next steps.
+    switch (mul_mat_algo) {
+        case Mul_Mat_Algo::DMMV:
+            if (!ggml_sycl_supports_reorder_dmmv(src0->type)) {
+                return;
+            }
+            break;
+        case Mul_Mat_Algo::MMVQ:
+            if (!ggml_sycl_supports_reorder_mmvq(src0->type)) {
+                return;
+            }
+            break;
+        case Mul_Mat_Algo::MUL_MAT_SYCL:
+            if (!ggml_sycl_supports_reorder_mul_mat_sycl(src0->type)) {
+                return;
+            }
+            break;
     }
+
+    reorder_qw(src0, ctx->stream());
+    extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering
 }

 static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
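The rewritten `opt_for_reorder` replaces the nested `if` with early returns: bail out when the batch shape rules out reordering, when `src0->extra` is missing or already reordered, or when the chosen algorithm does not support a reordered layout for `src0->type`; only then call `reorder_qw` once and set the flag. Here is a small self-contained model of that control flow; the struct and function names are illustrative stand-ins, not the ggml-sycl API.

```cpp
// Model of the new opt_for_reorder flow: guard clauses first, reorder once,
// then mark the tensor so later calls become no-ops.
#include <cstdio>

struct TensorExtraStub { bool reorder = false; };   // mirrors optimized_feature.reorder

static void reorder_once(TensorExtraStub * extra, bool should_reorder, bool algo_supports) {
    if (!should_reorder)          return;           // batch shape rules out reorder
    if (!extra || extra->reorder) return;           // missing extra or already reordered
    if (!algo_supports)           return;           // chosen mul_mat algo can't use it
    std::puts("reorder_qw would run here");         // placeholder for reorder_qw(src0, stream)
    extra->reorder = true;                          // remember so we never reorder twice
}

int main() {
    TensorExtraStub extra;
    reorder_once(&extra, true, true);               // reorders and sets the flag
    reorder_once(&extra, true, true);               // second call is a no-op
}
```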
@@ -3013,24 +3036,19 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
         ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
         constexpr bool convert_src1_to_q8_1 = false;
-        if (ggml_sycl_supports_reorder_dmmv(src0->type)) {
-            opt_for_reorder(&ctx, src0, src1, dst);
-        }
+        opt_for_reorder(&ctx, src0, src1, dst, Mul_Mat_Algo::DMMV);
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, convert_src1_to_q8_1);
     } else if (use_mul_mat_vec_q) {
         constexpr bool convert_src1_to_q8_1 = true;
-        if (ggml_sycl_supports_reorder_mmvq(src0->type)) {
-            opt_for_reorder(&ctx, src0, src1, dst);
-        }
+        opt_for_reorder(&ctx, src0, src1, dst, Mul_Mat_Algo::MMVQ);
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, convert_src1_to_q8_1);
     } else if (use_mul_mat_q) {
         constexpr bool convert_src1_to_q8_1 = true;
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, convert_src1_to_q8_1);
     } else {
         constexpr bool convert_src1_to_q8_1 = false;
-        if (ggml_sycl_supports_reorder_dequantize(src0->type)) {
-            opt_for_reorder(&ctx, src0, src1, dst); // the OP function in this branch support reorder.
-        }
+        // MUL_MAT_SYCL supports reorder
+        opt_for_reorder(&ctx, src0, src1, dst, Mul_Mat_Algo::MUL_MAT_SYCL);
         ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, convert_src1_to_q8_1);
     }
     GGML_SYCL_DEBUG("call %s done\n", __func__);
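With the support check folded into `opt_for_reorder`, each branch of `ggml_sycl_mul_mat` now just tags its algorithm and calls it unconditionally. The stub below shows only the resulting call-site shape; the real functions take the full ggml tensors and context, and all names suffixed `_stub` are illustrative.

```cpp
// Rough shape of the simplified call sites after this change: one tag per branch,
// reorder decision delegated entirely to opt_for_reorder.
#include <cstdio>

enum class Mul_Mat_Algo { DMMV, MMVQ, MUL_MAT_SYCL };

static void opt_for_reorder_stub(Mul_Mat_Algo algo) {
    std::printf("reorder check for algo %d\n", static_cast<int>(algo));
}

static void mul_mat_stub(bool use_dmmv, bool use_mmvq, bool use_mmq) {
    if (use_dmmv) {
        opt_for_reorder_stub(Mul_Mat_Algo::DMMV);          // then dequantize_mul_mat_vec
    } else if (use_mmvq) {
        opt_for_reorder_stub(Mul_Mat_Algo::MMVQ);          // then mul_mat_vec_q
    } else if (use_mmq) {
        // MMQ branch does not call opt_for_reorder in the diff
    } else {
        opt_for_reorder_stub(Mul_Mat_Algo::MUL_MAT_SYCL);  // then mul_mat_sycl
    }
}

int main() { mul_mat_stub(true, false, false); }
```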