Skip to content

Commit f8c7cae

Browse files
committed
cuda : implement ssm scan for Mamba2
There is still room for improvement, but it works! * cuda : adapt Mamba1 ssm scan to shape changes from Mamba2
1 parent a42f239 commit f8c7cae

File tree

4 files changed

+195
-54
lines changed

4 files changed

+195
-54
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3191,7 +3191,18 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
31913191
case GGML_OP_COS:
31923192
case GGML_OP_CLAMP:
31933193
case GGML_OP_LOG:
3194-
case GGML_OP_SSM_SCAN:
3194+
return true;
3195+
case GGML_OP_SSM_SCAN: {
3196+
if (op->src[3]->ne[0] == 1) {
3197+
// Mamba2
3198+
// (kernel only supports d_state == 128 && d_head % 16 == 0)
3199+
return op->src[0]->ne[0] == 128 && op->src[0]->ne[1] % 16 == 0;
3200+
} else {
3201+
// Mamba
3202+
// (kernel only supports d_state == 16, n_group == 1, d_head == 1)
3203+
return op->src[0]->ne[0] == 16 && op->src[4]->ne[1] == 1 && op->src[0]->ne[1] == 1;
3204+
}
3205+
}
31953206
case GGML_OP_SSM_CONV:
31963207
return true;
31973208
case GGML_OP_CONT:

ggml/src/ggml-cuda/ssm-scan.cu

Lines changed: 181 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
44
__global__ void __launch_bounds__(splitD, 2)
55
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
66
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
7-
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
8-
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
9-
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
10-
float * __restrict__ dst, const int64_t L) {
11-
GGML_UNUSED(src1_nb0);
12-
GGML_UNUSED(src2_nb0);
7+
const int32_t * __restrict__ src6, float * __restrict__ dst,
8+
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
9+
const int src2_nb1, const int src2_nb2, const int src3_nb1,
10+
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
11+
const int64_t s_off, const int64_t d_inner, const int64_t L) {
1312

1413
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
15-
const int bidx = blockIdx.x; // split along B
16-
const int bidy = blockIdx.y; // split along D
14+
const int bidx = blockIdx.x; // split along B (sequences)
15+
const int bidy = blockIdx.y; // split along D (d_inner)
1716
const int tid = threadIdx.x;
1817
const int wid = tid / 32;
1918
const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
2423
float * smem_A = smem;
2524
float * smem_s0 = smem_A + splitD * stride_sA;
2625

27-
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
28-
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
26+
const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
27+
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
2928
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
3029
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
31-
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
32-
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
33-
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
34-
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
30+
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
31+
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
32+
float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
33+
float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);
3534

36-
const int stride_s0 = src0_nb1 / sizeof(float);
37-
const int stride_x = src1_nb1 / sizeof(float);
35+
const int stride_s0 = src0_nb2 / sizeof(float);
36+
const int stride_x = src1_nb2 / sizeof(float);
3837
const int stride_dt = src2_nb1 / sizeof(float);
3938
const int stride_A = src3_nb1 / sizeof(float);
40-
const int stride_B = src4_nb1 / sizeof(float);
41-
const int stride_C = src5_nb1 / sizeof(float);
39+
const int stride_B = src4_nb2 / sizeof(float);
40+
const int stride_C = src5_nb2 / sizeof(float);
4241
const int stride_s = stride_s0;
43-
const int stride_y = stride_x;
42+
const int stride_y = d_inner;
4443

4544
// can N not be 16? for example 32?
4645
if (N == 16) {
@@ -84,24 +83,157 @@ __global__ void __launch_bounds__(splitD, 2)
8483
}
8584
}
8685

86+
// assumes as many threads as d_state
87+
template <int splitH, int d_state>
88+
__global__ void __launch_bounds__(d_state, 1)
89+
ssm_scan_f32_group(
90+
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
91+
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
92+
const int32_t * __restrict__ src6, float * __restrict__ dst,
93+
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
94+
const int src2_nb1, const int src2_nb2, const int src3_nb1,
95+
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
96+
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
97+
98+
const int head_idx = (blockIdx.x * splitH) / d_head;
99+
const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
100+
const int seq_idx = blockIdx.y;
101+
102+
const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
103+
104+
const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
105+
const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
106+
const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
107+
const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
108+
const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
109+
const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
110+
float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
111+
float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
112+
113+
// strides across n_seq_tokens
114+
const int stride_x = src1_nb2 / sizeof(float);
115+
const int stride_dt = src2_nb1 / sizeof(float);
116+
const int stride_B = src4_nb2 / sizeof(float);
117+
const int stride_C = src5_nb2 / sizeof(float);
118+
const int stride_y = n_head * d_head;
119+
120+
float state[splitH];
121+
// for the parallel accumulation
122+
__shared__ float stateC[splitH * d_state];
123+
124+
#pragma unroll
125+
for (int j = 0; j < splitH; j++) {
126+
state[j] = s0_block[j * d_state + threadIdx.x];
127+
}
128+
129+
for (int64_t i = 0; i < n_tok; i++) {
130+
// TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
131+
// TODO: only calculate B and C once per head group
132+
// NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
133+
float dt_soft_plus = dt_block[i * stride_dt];
134+
if (dt_soft_plus <= 20.0f) {
135+
dt_soft_plus = log1pf(expf(dt_soft_plus));
136+
}
137+
const float dA = expf(dt_soft_plus * A_block[0]);
138+
const float B = B_block[i * stride_B + threadIdx.x];
139+
const float C = C_block[i * stride_C + threadIdx.x];
140+
141+
// across d_head
142+
#pragma unroll
143+
for (int j = 0; j < splitH; j++) {
144+
const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
145+
146+
state[j] = (state[j] * dA) + (B * x_dt);
147+
148+
stateC[j * d_state + threadIdx.x] = state[j] * C;
149+
}
150+
151+
__syncthreads();
152+
153+
// parallel accumulation for stateC
154+
// TODO: simplify
155+
{
156+
static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
157+
static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
158+
159+
// reduce until w matches the warp size
160+
// TODO: does this work even when the physical warp size is 64?
161+
#pragma unroll
162+
for (int w = d_state; w > WARP_SIZE; w >>= 1) {
163+
// (assuming there are d_state threads)
164+
#pragma unroll
165+
for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
166+
// TODO: check for bank conflicts
167+
const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
168+
stateC[k] += stateC[k + (w >> 1)];
169+
170+
}
171+
__syncthreads();
172+
}
173+
174+
static_assert(splitH >= d_state / WARP_SIZE);
175+
176+
#pragma unroll
177+
for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
178+
float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
179+
y = warp_reduce_sum(y);
180+
181+
// store the above accumulations
182+
if (threadIdx.x % WARP_SIZE == 0) {
183+
const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
184+
y_block[i * stride_y + k] = y;
185+
}
186+
}
187+
}
188+
}
189+
190+
// write back the state
191+
#pragma unroll
192+
for (int j = 0; j < splitH; j++) {
193+
s_block[j * d_state + threadIdx.x] = state[j];
194+
}
195+
}
196+
87197
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
88-
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
89-
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
90-
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
91-
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
92-
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
198+
const float * src4, const float * src5, const int32_t * src6, float * dst,
199+
const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
200+
const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
201+
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
202+
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
93203
cudaStream_t stream) {
94204
const int threads = 128;
95-
// todo: consider D cannot be divided,does this situation exist?
96-
GGML_ASSERT(D % threads == 0);
97-
const dim3 blocks(B, (D + threads - 1) / threads, 1);
98-
const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
99-
if (N == 16) {
100-
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
101-
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
102-
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
205+
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
206+
if (src3_nb1 == sizeof(float)) {
207+
// Mamba2
208+
if (d_state == 128) {
209+
GGML_ASSERT(d_state % threads == 0);
210+
// NOTE: can be any power of two between 4 and 64
211+
const int splitH = 16;
212+
GGML_ASSERT(head_dim % splitH == 0);
213+
const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
214+
ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
215+
src0, src1, src2, src3, src4, src5, src6, dst,
216+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
217+
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
218+
} else {
219+
GGML_ABORT("doesn't support d_state!=128.");
220+
}
103221
} else {
104-
GGML_ABORT("doesn't support N!=16.");
222+
// Mamba1
223+
// todo: consider n_head cannot be divided, does this situation exist?
224+
GGML_ASSERT(n_head % threads == 0);
225+
GGML_ASSERT(head_dim == 1);
226+
GGML_ASSERT(n_group == 1);
227+
const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
228+
const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
229+
if (d_state == 16) {
230+
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
231+
src0, src1, src2, src3, src4, src5, src6, dst,
232+
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
233+
src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
234+
} else {
235+
GGML_ABORT("doesn't support d_state!=16.");
236+
}
105237
}
106238
}
107239

@@ -112,44 +244,42 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
112244
const struct ggml_tensor * src3 = dst->src[3]; // A
113245
const struct ggml_tensor * src4 = dst->src[4]; // B
114246
const struct ggml_tensor * src5 = dst->src[5]; // C
115-
116-
// const int64_t d_state = src0->ne[0];
117-
// const int64_t d_inner = src0->ne[1];
118-
// const int64_t l = src1->ne[1];
119-
// const int64_t b = src0->ne[2];
247+
const struct ggml_tensor * src6 = dst->src[6]; // ids
120248

121249
const int64_t nc = src0->ne[0]; // d_state
122-
const int64_t nr = src0->ne[1]; // d_inner
123-
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
124-
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
250+
const int64_t nr = src0->ne[1]; // head_dim or 1
251+
const int64_t nh = src1->ne[1]; // n_head
252+
const int64_t ng = src4->ne[1]; // n_group
253+
const int64_t n_t = src1->ne[2]; // number of tokens per sequence
254+
const int64_t n_s = src1->ne[3]; // number of sequences in the batch
255+
256+
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
125257

126-
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
258+
GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
127259
GGML_ASSERT(src0->nb[0] == sizeof(float));
128260
GGML_ASSERT(src1->nb[0] == sizeof(float));
129261
GGML_ASSERT(src2->nb[0] == sizeof(float));
130262
GGML_ASSERT(src3->nb[0] == sizeof(float));
131263
GGML_ASSERT(src4->nb[0] == sizeof(float));
132264
GGML_ASSERT(src5->nb[0] == sizeof(float));
133-
// required for the dot product between s and C
134-
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
135-
// required for per-sequence offsets for states
136-
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
137-
// required to get correct offset for state destination (i.e. src1->nb[3])
138-
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
265+
GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
139266

140267
const float * src0_d = (const float *) src0->data;
141268
const float * src1_d = (const float *) src1->data;
142269
const float * src2_d = (const float *) src2->data;
143270
const float * src3_d = (const float *) src3->data;
144271
const float * src4_d = (const float *) src4->data;
145272
const float * src5_d = (const float *) src5->data;
273+
const int32_t * src6_d = (const int32_t *) src6->data;
146274
float * dst_d = (float *) dst->data;
147275
cudaStream_t stream = ctx.stream();
148276

149277
GGML_ASSERT(src0->type == GGML_TYPE_F32);
278+
GGML_ASSERT(src6->type == GGML_TYPE_I32);
150279
GGML_ASSERT(dst->type == GGML_TYPE_F32);
151280

152-
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
153-
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
154-
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
281+
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
282+
src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
283+
src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
284+
s_off, nc, nr, nh, ng, n_t, n_s, stream);
155285
}

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w
215215
const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0];
216216
const int64_t n_head = w->ne[1];
217217
const int64_t head_dim = hparams.ssm_d_inner / n_head;
218-
const int64_t n_group = hparams.ssm_n_group;
218+
const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1;
219219
const int64_t n_seq_tokens = 512;
220220
const int64_t n_seqs = 3;
221221
ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs);

tests/test-backend-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4225,7 +4225,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
42254225
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
42264226

42274227
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
4228-
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2
4228+
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
42294229

42304230
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1));
42314231
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));

0 commit comments

Comments
 (0)