Add support for missing quants in CPY (Metal & CUDA). #11987

Status: Closed · wants to merge 2 commits
156 changes: 156 additions & 0 deletions ggml/src/ggml-cuda/cpy.cu
@@ -225,6 +225,91 @@ static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
memcpy(dsti->qh, &qh, sizeof(qh));
}

static __device__ void cpy_blck_q5_0_f32(const char * cxi, char * cdsti) {
const block_q5_0 * xi = (const block_q5_0 *) cxi;
float * dst = (float *) cdsti;
float d = xi->d; // scale factor (computed as vmax / -16)
const float shift = 16.0f;

// Safely copy the 32-bit qh value to avoid misaligned access.
unsigned int qh;
memcpy(&qh, xi->qh, sizeof(qh));

// First half: lower nibble stores element j.
for (int j = 0; j < QK5_0/2; j++) {
uint8_t lower = xi->qs[j] & 0xF;
uint8_t high = (qh >> j) & 1;
uint8_t q = (high << 4) | lower;
dst[j] = ((float)q - shift) * d;
}
// Second half: upper nibble stores element j + QK5_0/2.
for (int j = QK5_0/2; j < QK5_0; j++) {
int k = j - QK5_0/2;
uint8_t lower = (xi->qs[k] >> 4) & 0xF;
uint8_t high = (qh >> j) & 1;
uint8_t q = (high << 4) | lower;
dst[j] = ((float)q - shift) * d;
}
}

static __device__ void cpy_blck_q5_1_f32(const char * cxi, char * cdsti) {
const block_q5_1 * xi = (const block_q5_1 *) cxi;
float * dst = (float *) cdsti;
float d = xi->dm.x; // scale
float min_val = xi->dm.y; // minimum value

// Safely copy the 32-bit qh value to avoid misaligned access.
unsigned int qh;
memcpy(&qh, xi->qh, sizeof(qh));

// Decode first half: lower nibble of xi->qs[j] holds element j.
for (int j = 0; j < QK5_1/2; j++) {
uint8_t lower = xi->qs[j] & 0xF;
uint8_t high = (qh >> j) & 1;
uint8_t q = (high << 4) | lower;
dst[j] = min_val + d * (float)q;
}
// Decode second half: upper nibble of xi->qs[j] holds element j+QK5_1/2.
for (int j = QK5_1/2; j < QK5_1; j++) {
int k = j - QK5_1/2;
uint8_t lower = (xi->qs[k] >> 4) & 0xF;
uint8_t high = (qh >> j) & 1;
uint8_t q = (high << 4) | lower;
dst[j] = min_val + d * (float)q;
}
}

static __device__ void cpy_blck_q4_0_f32(const char * cxi, char * cdsti) {
const block_q4_0 * xi = (const block_q4_0 *) cxi;
float * dst = (float *) cdsti;
float d = xi->d;
const float shift = 8.0f;

// Each byte packs two 4-bit quantized values.
for (int j = 0; j < QK4_0/2; j++) {
uint8_t q_val = xi->qs[j];
uint8_t q0 = q_val & 0x0F;
uint8_t q1 = (q_val >> 4) & 0x0F;
dst[j] = ((float)q0 - shift) * d;
dst[j + QK4_0/2] = ((float)q1 - shift) * d;
}
}

static __device__ void cpy_blck_q4_1_f32(const char * cxi, char * cdsti) {
const block_q4_1 * xi = (const block_q4_1 *) cxi;
float * dst = (float *) cdsti;
const float d = xi->dm.x;
const float vmin = xi->dm.y;

// Each byte packs two 4-bit quantized values.
for (int j = 0; j < QK4_1/2; ++j) {
uint8_t byte_val = xi->qs[j];
uint8_t q0 = byte_val & 0x0F;
uint8_t q1 = (byte_val >> 4) & 0x0F;
dst[j] = vmin + d * (float)q0;
dst[j + QK4_1/2] = vmin + d * (float)q1;
}
}

static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
@@ -420,6 +505,58 @@ static void ggml_cpy_f32_q5_1_cuda(
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_1_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q5_1_f32, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_0_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q5_0_f32, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_1_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q4_1_f32, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_0_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02,
const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12,
const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) {
const int num_blocks = ne;
cpy_q_f32<cpy_blck_q4_0_f32, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -488,14 +625,25 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
@@ -524,14 +672,22 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
} else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q4_0_f32, QK4_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
} else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q4_1_f32, QK4_1>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
} else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q5_0_f32, QK5_0>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
} else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_q_f32<cpy_blck_q5_1_f32, QK5_1>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
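
Reviewer note: the dequantizing block kernels added above follow the scalar layout used elsewhere in ggml — the low 4 bits of each element live in `qs`, and for Q5_0/Q5_1 the fifth bit of element `j` is bit `j` of the 32-bit `qh` mask. A minimal host-side sketch of the same Q5_0 decode, useful for eyeballing the kernel logic (the struct and helper names here are illustrative, not part of this PR):

```c
// Host-side reference for the decode performed by cpy_blck_q5_0_f32 above.
// The fp16 scale is assumed to already be converted to float (d) by the caller.
#include <stdint.h>
#include <string.h>

#define QK5_0 32

struct q5_0_block_ref {
    uint8_t qh[4];        // fifth bit of each of the 32 elements
    uint8_t qs[QK5_0/2];  // low 4 bits, two elements per byte
};

static void dequant_q5_0_block_ref(const struct q5_0_block_ref * x, float d, float * dst) {
    uint32_t qh;
    memcpy(&qh, x->qh, sizeof(qh)); // unaligned-safe load, same as in the kernel

    for (int j = 0; j < QK5_0/2; ++j) {
        const uint8_t h0 = (qh >> j) & 1;              // high bit of element j
        const uint8_t h1 = (qh >> (j + QK5_0/2)) & 1;  // high bit of element j + 16
        const int q0 = ((x->qs[j] & 0x0F) | (h0 << 4)) - 16;
        const int q1 = ((x->qs[j] >>   4) | (h1 << 4)) - 16;
        dst[j]           = q0 * d;
        dst[j + QK5_0/2] = q1 * d;
    }
}
```
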
12 changes: 12 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3073,15 +3073,27 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
return true;
}
if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
return true;
}
if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
return true;
}
if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
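
With the `supports_op` additions above, `GGML_OP_CPY` from Q4_0/Q4_1/Q5_0/Q5_1 (and the pre-existing Q8_0) to F32 now runs on the CUDA backend instead of falling back to the CPU. For context, a rough sketch of how such a copy is expressed through the public ggml API — the sizes, device index, and allocation pattern are assumptions for illustration, not taken from this PR:

```c
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead()*8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data will live in a backend buffer
    };
    struct ggml_context * ctx = ggml_init(params);

    // 256 elements = 8 Q4_0 blocks; the copy dequantizes them to F32 on the GPU
    struct ggml_tensor * src = ggml_new_tensor_1d(ctx, GGML_TYPE_Q4_0, 256);
    struct ggml_tensor * dst = ggml_new_tensor_1d(ctx, GGML_TYPE_F32,  256);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, ggml_cpy(ctx, src, dst)); // dispatches the Q4_0 -> F32 path

    ggml_backend_t backend = ggml_backend_cuda_init(0);
    ggml_gallocr_t galloc  = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend));
    ggml_gallocr_alloc_graph(galloc, gf);

    // ... upload quantized data into src with ggml_backend_tensor_set(), then:
    ggml_backend_graph_compute(backend, gf);

    ggml_gallocr_free(galloc);
    ggml_backend_free(backend);
    ggml_free(ctx);
    return 0;
}
```
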
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -84,6 +84,7 @@ typedef struct {
} ggml_metal_kargs_repeat;

typedef struct {
int64_t ne;
int64_t ne00;
int64_t ne01;
int64_t ne02;
56 changes: 51 additions & 5 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -407,6 +407,11 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR,
GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1017,11 @@ @implementation GGMLMetalClass
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32, cpy_q4_0_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32, cpy_q4_1_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32, cpy_q5_0_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32, cpy_q5_1_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32, cpy_q8_0_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT, sqrt, true);
@@ -1287,6 +1297,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
default:
return false;
}
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return (op->type == GGML_TYPE_F32);
default:
return false;
};
@@ -1615,7 +1631,10 @@ static void ggml_metal_encode_node(

const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;

const int64_t ne = ggml_nelements(src0);

ggml_metal_kargs_cpy args = {
/*.ne =*/ ne,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
@@ -3899,10 +3918,7 @@ static void ggml_metal_encode_node(
case GGML_OP_CPY:
case GGML_OP_CONT:
{
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);

int nth = MIN(1024, ne00/ggml_blck_size(src0->type));

const int64_t ne = ggml_nelements(src0);
id<MTLComputePipelineState> pipeline = nil;

switch (src0t) {
@@ -3936,13 +3952,33 @@ switch (dstt) {
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
default: GGML_ASSERT(false && "not implemented");
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
{
if (dstt == GGML_TYPE_F32) {
switch (src0t) {
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
default: GGML_ABORT("not implemented");
}
} else {
GGML_ABORT("not implemented");
}
} break;
default: GGML_ABORT("not implemented");
}

ggml_metal_kargs_cpy args = {
/*.ne =*/ ne,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
@@ -3966,7 +4002,17 @@ static void ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];

int nth;

if (src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
GGML_ASSERT(dstt == GGML_TYPE_F32);
nth = MIN(1024, ne);
} else {
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
nth = MIN(1024, ne00/ggml_blck_size(src0->type));
}
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];

} break;
case GGML_OP_SET:
{
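
On the Metal side the new kernels plug into the existing CPY plumbing: the `ne` field added to `ggml_metal_kargs_cpy` carries the total element count, and the quant→F32 paths size the threadgroup as `nth = MIN(1024, ne)` so each thread can work from a flat element index. A simple way to cross-check either backend is to compare the CPY output against the CPU `to_float` reference for the source type — a sketch only, with the buffer handling and tolerance being assumptions rather than part of this PR:

```c
#include <math.h>
#include <stdio.h>
#include "ggml.h"

// Compare a backend CPY (quant -> F32) result against the CPU reference dequantization.
// q_data: the raw quantized blocks; backend_out: the F32 tensor read back with
// ggml_backend_tensor_get(); n: number of elements (assumed <= 256 for this sketch).
static int check_against_reference(enum ggml_type type, const void * q_data,
                                   const float * backend_out, int64_t n) {
    float ref[256];
    const struct ggml_type_traits * tt = ggml_get_type_traits(type);
    tt->to_float(q_data, ref, n); // CPU reference dequantization

    for (int64_t i = 0; i < n; ++i) {
        if (fabsf(ref[i] - backend_out[i]) > 1e-6f) {
            printf("mismatch at %lld: %f vs %f\n", (long long) i, ref[i], backend_out[i]);
            return 1;
        }
    }
    return 0;
}
```
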