Commit f56013d

ggml-cpu: support IQ4_NL_4_4 by runtime repack
1 parent 4a57d36 commit f56013d

7 files changed: +332 -15 lines changed


ggml/include/ggml.h

Lines changed: 3 additions & 0 deletions
@@ -389,6 +389,9 @@ extern "C" {
         GGML_TYPE_Q4_0_8_8    = 33,
         GGML_TYPE_TQ1_0       = 34,
         GGML_TYPE_TQ2_0       = 35,
+        GGML_TYPE_IQ4_NL_4_4  = 36,
+        // GGML_TYPE_IQ4_NL_4_8 = 37,
+        // GGML_TYPE_IQ4_NL_8_8 = 38,
         GGML_TYPE_COUNT,
     };

ggml/src/ggml-common.h

Lines changed: 6 additions & 0 deletions
@@ -418,6 +418,12 @@ typedef struct {
 } block_iq4_xs;
 static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");

+typedef struct {
+    ggml_half d[4];            // deltas for 4 iq4_nl blocks
+    uint8_t   qs[QK4_NL * 2];  // nibbles / quants for 4 iq4_nl blocks
+} block_iq4_nlx4;
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
 #endif // GGML_COMMON_DECL
 #endif // GGML_COMMON_DECL
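For reference, a block_iq4_nlx4 packs four consecutive IQ4_NL blocks, one per tensor row, with their quant bytes interleaved in 4-byte groups. Below is a minimal scalar sketch of how the packed bytes map back to the four source rows, assuming the definitions above plus kvalues_iq4nl and GGML_FP16_TO_FP32 from the CPU backend are in scope; it mirrors the scalar fallback indexing in ggml-cpu-aarch64.c further down.

// Hedged sketch, not part of the commit: dequantize one block_iq4_nlx4
// back into 4 rows x 32 floats, using the same index math as the scalar
// gemv/gemm fallbacks below.
static void dequant_iq4_nlx4_ref(const block_iq4_nlx4 * b, float out[4][32]) {
    for (int k = 0; k < 4; k++) {              // 4-byte interleave group
        for (int j = 0; j < 4; j++) {          // interleaved row
            const float d = GGML_FP16_TO_FP32(b->d[j]);
            for (int i = 0; i < 4; i++) {      // byte within the group
                const uint8_t q = b->qs[k * 16 + j * 4 + i];
                out[j][k * 4 + i]      = d * kvalues_iq4nl[q & 0x0F]; // low nibbles: first half
                out[j][k * 4 + i + 16] = d * kvalues_iq4nl[q >> 4];   // high nibbles: second half
            }
        }
    }
}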

ggml/src/ggml-cpu/ggml-cpu-aarch64.c

Lines changed: 301 additions & 14 deletions
@@ -187,6 +187,8 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 }
 #endif

+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
 static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
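kvalues_iq4nl is the 16-entry non-linear IQ4_NL codebook (the same table ggml's quantization code uses); every 4-bit quant is an index into it rather than a value. A quick illustration of decoding one packed byte:

// Illustration only: one byte carries two codebook indices.
const uint8_t byte = 0xA3;
const int8_t  lo = kvalues_iq4nl[byte & 0x0F]; // index  3 -> -65
const int8_t  hi = kvalues_iq4nl[byte >> 4];   // index 10 ->  25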
@@ -996,6 +998,102 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }

+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    if (ggml_cpu_has_neon()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        float * res_ptr = s;
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+            float32x4_t sumf = vdupq_n_f32(0);
+            for (int l = 0; l < nb; l++) {
+                uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+                uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+                uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+                uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+                int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+                int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+                int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+                int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+                int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+                int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+                int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+                int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+                int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+                int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+                int32x4_t sumi = vdupq_n_s32(0);
+                sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+                sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+                sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+                sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+                sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+                sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+                sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+                sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+                float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+                float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+                float32x4_t d = a_d * b_d;
+
+                sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+            }
+
+            vst1q_f32(res_ptr + x * 4, sumf);
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4];
+        int sumi;
+
+        const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+
+            for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int j = 0; j < ncols_interleaved; j++) {
+                        sumi = 0;
+                        for (int i = 0; i < blocklen; ++i) {
+                            const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                            const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                            sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                        }
+                        sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d);
+                    }
+                }
+            }
+            for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
     const int qk = QK8_0;
     const int nb = n / qk;
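The NEON path above hinges on two intrinsics: vqtbl1q_s8 decodes 16 nibbles to codebook values in one instruction, and vdotq_laneq_s32 accumulates int8 dot products. One detail worth noting: that path reads the weights through block_q4_0x4, which is layout-identical to block_iq4_nlx4 (four ggml_half deltas followed by 64 quant bytes), so the cast is benign. A scalar model of the two intrinsics' semantics, as a reader aid rather than ggml code:

#include <stdint.h>

// vqtbl1q_s8(table, idx): per-lane table lookup. Nibble indices are 0..15,
// so every lane hits the table (a lane with an index >= 16 would produce 0).
static void tbl16_model(const int8_t t[16], const uint8_t idx[16], int8_t out[16]) {
    for (int i = 0; i < 16; i++) out[i] = idx[i] < 16 ? t[idx[i]] : 0;
}

// vdotq_laneq_s32(acc, b, a, lane): each int32 lane i of acc accumulates a
// 4-way dot product of b's bytes 4i..4i+3 with the 4 activation bytes
// selected by `lane` - the same 4 bytes feed all 4 output lanes.
static void sdot_lane_model(int32_t acc[4], const int8_t b[16], const int8_t a[16], const int lane) {
    for (int i = 0; i < 4; i++)
        for (int k = 0; k < 4; k++)
            acc[i] += b[4 * i + k] * a[4 * lane + k];
}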
@@ -3386,6 +3484,117 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
     }
 }

+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    if (ggml_cpu_has_neon()) {
+        const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl);
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
+
+                float32x4_t sumf[4];
+                for (int m = 0; m < 4; m++) {
+                    sumf[m] = vdupq_n_f32(0);
+                }
+
+                for (int l = 0; l < nb; l++) {
+                    float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                    float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d));
+
+                    int32x4_t sumi_0 = vdupq_n_s32(0);
+                    int32x4_t sumi_1 = vdupq_n_s32(0);
+                    int32x4_t sumi_2 = vdupq_n_s32(0);
+                    int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                    for (int k = 0; k < 4; k++) {
+                        int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                        int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                        uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                        int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                        int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                        sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                        sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                        sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                        sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                    }
+
+                    sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                    sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                    sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                    sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+                }
+
+                for (int m = 0; m < 4; m++) {
+                    vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+                }
+            }
+        }
+        return;
+    }
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    {
+        float sumf[4][4];
+        int sumi;
+
+        for (int y = 0; y < nr / 4; y++) {
+            const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+            for (int x = 0; x < nc / ncols_interleaved; x++) {
+                const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+                }
+                for (int l = 0; l < nb; l++) {
+                    for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                        for (int m = 0; m < 4; m++) {
+                            for (int j = 0; j < ncols_interleaved; j++) {
+                                sumi = 0;
+                                for (int i = 0; i < blocklen; ++i) {
+                                    const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                    const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                    sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                             (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                                }
+                                sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_FP16_TO_FP32(a_ptr[l].d[m]);
+                            }
+                        }
+                    }
+                }
+                for (int m = 0; m < 4; m++) {
+                    for (int j = 0; j < ncols_interleaved; j++)
+                        s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+                }
+            }
+        }
+    }
+}
+
 // FIXME: this code is duplicated from ggml-aarch64.c
 static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
     block_q4_0x4 out;
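The gemm variant mirrors its q4_0 counterpart: each (y, x) iteration computes a 4x4 output tile, four quantized activation rows (block_q8_0x4) against four interleaved weight columns, with sumi_m collecting row m's four column sums. The `+ 64` in the a_1 load follows from quantize_mat_q8_0's 4x4 layout, where the four rows' first halves (64 bytes) precede their second halves. A sketch of the index map, assuming that layout:

// Hedged index map for block_q8_0x4 as consumed above (4-byte interleave,
// qs holds 4 rows x 32 bytes = 128 bytes):
//   row m, element k*4 + i       -> qs[k*16 + m*4 + i]        (bytes  0..63)
//   row m, element k*4 + i + 16  -> qs[k*16 + m*4 + i + 64]   (bytes 64..127)
// Output tile store: s[(y*4 + m) * bs + x*4 + j] = tile[m][j] for m, j in 0..3.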
@@ -3518,27 +3727,101 @@ static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block,
     GGML_UNUSED(data_size);
 }

+static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
+    }
+
+    const int end = QK4_NL * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            // Using memcpy to avoid unaligned memory accesses
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+
+    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nl dst_tmp[4];
+    int nrow = t->ne[1]; // Number of rows
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 // Prepare for optimized kernels if applicable
 void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * restrict data, size_t data_size) {
     if (cur->type == repack_type) {
         memcpy(cur->data, data, data_size);
         return;
     }

-    GGML_ASSERT(cur->type == GGML_TYPE_Q4_0);
-
-    switch (repack_type) {
-        case GGML_TYPE_Q4_0_8_8:
-            repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
-            break;
-        case GGML_TYPE_Q4_0_4_8:
-            repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
-            break;
-        case GGML_TYPE_Q4_0_4_4:
-            repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
-            break;
-        default:
-            GGML_ABORT("Unsupported type");
+    if (cur->type == GGML_TYPE_Q4_0) {
+        switch (repack_type) {
+            case GGML_TYPE_Q4_0_8_8:
+                repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_8:
+                repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
+                break;
+            case GGML_TYPE_Q4_0_4_4:
+                repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        switch (repack_type) {
+            case GGML_TYPE_IQ4_NL_4_4:
+                repack_iq4_nl_to_iq4_nl_4_bl(cur, 4, data, data_size);
+                break;
+            default:
+                GGML_ABORT("Unsupported type");
+        }
+    } else {
+        GGML_ABORT("Unsupported type");
     }
 }
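repack_iq4_nl_to_iq4_nl_4_bl walks the tensor four rows at a time: for each block column x it gathers the x-th block of each row (src[x + i * nblocks]) and hands the four to make_block_iq4_nlx4, which round-robins fixed-size chunks from the four sources into the output. A trace of the byte movement for blck_size_interleave == 4 (end = QK4_NL * 2 / 4 = 16 chunks, each source block contributing its 16 qs bytes):

// i = 0..15: src_id = i % 4, src_offset = (i / 4) * 4, dst_offset = i * 4
//   out.qs[ 0.. 3] = row0.qs[0..3]   // i = 0
//   out.qs[ 4.. 7] = row1.qs[0..3]   // i = 1
//   out.qs[ 8..11] = row2.qs[0..3]   // i = 2
//   out.qs[12..15] = row3.qs[0..3]   // i = 3
//   out.qs[16..19] = row0.qs[4..7]   // i = 4: source offset advances every 4 chunks
//   ...and so on through out.qs[60..63] = row3.qs[12..15]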

@@ -3554,6 +3837,10 @@ enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * c
         if (ggml_cpu_has_neon()) {
             return GGML_TYPE_Q4_0_4_4;
         }
+    } else if (cur->type == GGML_TYPE_IQ4_NL) {
+        if (ggml_cpu_has_neon()) {
+            return GGML_TYPE_IQ4_NL_4_4;
+        }
     }

     return cur->type;
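Taken together: at weight-load time ggml_aarch64_get_optimal_repack_type now reports GGML_TYPE_IQ4_NL_4_4 for IQ4_NL tensors on NEON-capable CPUs, and ggml_aarch64_repack_tensor converts the data in place, after which matmuls route to the new gemv/gemm kernels. A hedged caller-side sketch; the actual call site lives in the CPU backend outside this diff, so the surrounding names here are illustrative:

// Illustrative flow only (cur, data, data_size stand in for the backend's own state):
enum ggml_type want = ggml_aarch64_get_optimal_repack_type(cur); // IQ4_NL -> IQ4_NL_4_4 on NEON
if (want != cur->type) {
    ggml_aarch64_repack_tensor(cur, want, data, data_size);      // interleave 4 rows per block
}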

ggml/src/ggml-cpu/ggml-cpu-aarch64.h

Lines changed: 2 additions & 0 deletions
@@ -15,11 +15,13 @@ void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

 // GEMM
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

 void ggml_aarch64_repack_tensor(struct ggml_tensor * cur, enum ggml_type repack_type, const void * data, size_t data_size);
 enum ggml_type ggml_aarch64_get_optimal_repack_type(const struct ggml_tensor * cur);
