@@ -4,16 +4,15 @@ template <size_t splitD, size_t N>
 __global__ void __launch_bounds__(splitD, 2)
     ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
                  const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
-                 const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
-                 const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                 const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                 float * __restrict__ dst, const int64_t L) {
-    GGML_UNUSED(src1_nb0);
-    GGML_UNUSED(src2_nb0);
+                 const int32_t * __restrict__ src6, float * __restrict__ dst,
+                 const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+                 const int src2_nb1, const int src2_nb2, const int src3_nb1,
+                 const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+                 const int64_t s_off, const int64_t d_inner, const int64_t L) {

     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int bidx = blockIdx.x;  // split along B
-    const int bidy = blockIdx.y;  // split along D
+    const int bidx = blockIdx.x;  // split along B (sequences)
+    const int bidy = blockIdx.y;  // split along D (d_inner)
     const int tid = threadIdx.x;
     const int wid = tid / 32;
     const int wtid = tid % 32;
@@ -24,23 +23,23 @@ __global__ void __launch_bounds__(splitD, 2)
     float * smem_A = smem;
     float * smem_s0 = smem_A + splitD * stride_sA;

-    const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
-    const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
+    const float * s0_block = (const float *) ((const char *) src0 + src6[bidx] * src0_nb3 + bidy * splitD * src0_nb2);
+    const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb3) + bidy * splitD * sizeof(float));
     const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
     const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
-    const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
-    const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
-    float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
-    float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
+    const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb3));
+    const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb3));
+    float * y_block = (float *) ((char *) dst + (bidx * d_inner * L * sizeof(float)) + bidy * splitD * sizeof(float));
+    float * s_block = (float *) ((char *) dst + s_off + bidx * src0_nb3 + bidy * splitD * src0_nb2);

-    const int stride_s0 = src0_nb1 / sizeof(float);
-    const int stride_x = src1_nb1 / sizeof(float);
+    const int stride_s0 = src0_nb2 / sizeof(float);
+    const int stride_x = src1_nb2 / sizeof(float);
     const int stride_dt = src2_nb1 / sizeof(float);
     const int stride_A = src3_nb1 / sizeof(float);
-    const int stride_B = src4_nb1 / sizeof(float);
-    const int stride_C = src5_nb1 / sizeof(float);
+    const int stride_B = src4_nb2 / sizeof(float);
+    const int stride_C = src5_nb2 / sizeof(float);
     const int stride_s = stride_s0;
-    const int stride_y = stride_x;
+    const int stride_y = d_inner;

     // can N not be 16? for example 32?
     if (N == 16) {
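Aside (not part of the patch): the *_block pointers above are built with ggml's byte strides (the nb* arguments), so each offset is added to a char pointer before reinterpreting as float, and per-row element strides are recovered with nb / sizeof(float). A minimal standalone sketch of that access pattern, with illustrative names only:

#include <cstddef>
#include <cstdio>

// Illustrative only: read a row-major 2D float buffer the way the kernel does,
// advancing between rows by a byte stride (nb1) instead of an element count.
static float read_elem(const float * base, size_t nb1, int row, int col) {
    const float * row_ptr = (const float *) ((const char *) base + row * nb1);
    return row_ptr[col];
}

int main() {
    float data[2][4] = { { 0.f, 1.f, 2.f, 3.f }, { 10.f, 11.f, 12.f, 13.f } };
    const size_t nb1 = 4 * sizeof(float);      // byte stride between rows
    const int stride = nb1 / sizeof(float);    // element stride, like stride_x above
    std::printf("%g (element stride %d)\n", read_elem(&data[0][0], nb1, 1, 2), stride); // prints: 12 (element stride 4)
    return 0;
}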
@@ -84,24 +83,157 @@ __global__ void __launch_bounds__(splitD, 2)
     }
 }

+// assumes as many threads as d_state
+template <int splitH, int d_state>
+__global__ void __launch_bounds__(d_state, 1)
+    ssm_scan_f32_group(
+        const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
+        const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
+        const int32_t * __restrict__ src6, float * __restrict__ dst,
+        const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3,
+        const int src2_nb1, const int src2_nb2, const int src3_nb1,
+        const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
+        const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
+
+    const int head_idx = (blockIdx.x * splitH) / d_head;
+    const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
+    const int seq_idx = blockIdx.y;
+
+    const int group_off = (head_idx & (n_group - 1)) * d_state * sizeof(float);
+
+    const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+    const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
+    const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
+    const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
+    const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
+    const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
+    float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
+    float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
+
+    // strides across n_seq_tokens
+    const int stride_x = src1_nb2 / sizeof(float);
+    const int stride_dt = src2_nb1 / sizeof(float);
+    const int stride_B = src4_nb2 / sizeof(float);
+    const int stride_C = src5_nb2 / sizeof(float);
+    const int stride_y = n_head * d_head;
+
+    float state[splitH];
+    // for the parallel accumulation
+    __shared__ float stateC[splitH * d_state];
+
+    #pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        state[j] = s0_block[j * d_state + threadIdx.x];
+    }
+
+    for (int64_t i = 0; i < n_tok; i++) {
+        // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
+        // TODO: only calculate B and C once per head group
+        // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
+        float dt_soft_plus = dt_block[i * stride_dt];
+        if (dt_soft_plus <= 20.0f) {
+            dt_soft_plus = log1pf(expf(dt_soft_plus));
+        }
+        const float dA = expf(dt_soft_plus * A_block[0]);
+        const float B = B_block[i * stride_B + threadIdx.x];
+        const float C = C_block[i * stride_C + threadIdx.x];
+
+        // across d_head
+        #pragma unroll
+        for (int j = 0; j < splitH; j++) {
+            const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
+
+            state[j] = (state[j] * dA) + (B * x_dt);
+
+            stateC[j * d_state + threadIdx.x] = state[j] * C;
+        }
+
+        __syncthreads();
+
+        // parallel accumulation for stateC
+        // TODO: simplify
+        {
+            static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
+            static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
+
+            // reduce until w matches the warp size
+            // TODO: does this work even when the physical warp size is 64?
+            #pragma unroll
+            for (int w = d_state; w > WARP_SIZE; w >>= 1) {
+                // (assuming there are d_state threads)
+                #pragma unroll
+                for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
+                    // TODO: check for bank conflicts
+                    const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
+                    stateC[k] += stateC[k + (w >> 1)];
+
+                }
+                __syncthreads();
+            }
+
+            static_assert(splitH >= d_state / WARP_SIZE);
+
+            #pragma unroll
+            for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
+                float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
+                y = warp_reduce_sum(y);
+
+                // store the above accumulations
+                if (threadIdx.x % WARP_SIZE == 0) {
+                    const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
+                    y_block[i * stride_y + k] = y;
+                }
+            }
+        }
+    }
+
+    // write back the state
+    #pragma unroll
+    for (int j = 0; j < splitH; j++) {
+        s_block[j * d_state + threadIdx.x] = state[j];
+    }
+}
+
 static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
-                              const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
-                              const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
-                              const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
-                              const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
-                              float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
+                              const float * src4, const float * src5, const int32_t * src6, float * dst,
+                              const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1,
+                              const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2,
+                              const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
+                              const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
                               cudaStream_t stream) {
     const int threads = 128;
-    // todo: consider D cannot be divided,does this situation exist?
-    GGML_ASSERT(D % threads == 0);
-    const dim3 blocks(B, (D + threads - 1) / threads, 1);
-    const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
-    if (N == 16) {
-        ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
-            src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
-            src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
+    // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
+    if (src3_nb1 == sizeof(float)) {
+        // Mamba2
+        if (d_state == 128) {
+            GGML_ASSERT(d_state % threads == 0);
+            // NOTE: can be any power of two between 4 and 64
+            const int splitH = 16;
+            GGML_ASSERT(head_dim % splitH == 0);
+            const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
+            ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
+                src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=128.");
+        }
     } else {
-        GGML_ABORT("doesn't support N!=16.");
+        // Mamba1
+        // todo: consider n_head cannot be divided, does this situation exist?
+        GGML_ASSERT(n_head % threads == 0);
+        GGML_ASSERT(head_dim == 1);
+        GGML_ASSERT(n_group == 1);
+        const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1);
+        const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float);
+        if (d_state == 16) {
+            ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
+                src0, src1, src2, src3, src4, src5, src6, dst,
+                src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2,
+                src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok);
+        } else {
+            GGML_ABORT("doesn't support d_state!=16.");
+        }
     }
 }

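Aside (not part of the patch): per (head element, state column), ssm_scan_f32_group applies a softplus to dt (left unchanged above the 20.0f cutoff), decays the state by exp(dt' * A), adds B * x * dt', and sums state * C across d_state to produce one output element. A simplified scalar reference of that step, with illustrative names and none of the head/group indexing or the shared-memory reduction:

#include <cmath>

// Scalar reference for one (head element, state column) pair at one token:
//   dt' = softplus(dt) below the cutoff (identity above it, as in the kernel),
//   dA  = exp(dt' * A),  s = s * dA + B * (x * dt'),
// and the kernel's output is the sum of s * C over the d_state columns.
static float ssm_scan_step_ref(float & s, float x, float dt, float A, float B, float C) {
    const float dt_soft_plus = dt <= 20.0f ? log1pf(expf(dt)) : dt;
    const float dA   = expf(dt_soft_plus * A);
    const float x_dt = x * dt_soft_plus;
    s = s * dA + B * x_dt;
    return s * C;  // one term of the dot product reduced across d_state
}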
@@ -112,44 +244,42 @@ void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const struct ggml_tensor * src3 = dst->src[3]; // A
     const struct ggml_tensor * src4 = dst->src[4]; // B
     const struct ggml_tensor * src5 = dst->src[5]; // C
-
-    // const int64_t d_state = src0->ne[0];
-    // const int64_t d_inner = src0->ne[1];
-    // const int64_t l = src1->ne[1];
-    // const int64_t b = src0->ne[2];
+    const struct ggml_tensor * src6 = dst->src[6]; // ids

     const int64_t nc = src0->ne[0]; // d_state
-    const int64_t nr = src0->ne[1]; // d_inner
-    const int64_t n_t = src1->ne[1]; // number of tokens per sequence
-    const int64_t n_s = src0->ne[2]; // number of sequences in the batch
+    const int64_t nr = src0->ne[1]; // head_dim or 1
+    const int64_t nh = src1->ne[1]; // n_head
+    const int64_t ng = src4->ne[1]; // n_group
+    const int64_t n_t = src1->ne[2]; // number of tokens per sequence
+    const int64_t n_s = src1->ne[3]; // number of sequences in the batch
+
+    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

-    GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
+    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*n_s == ggml_nelements(dst));
     GGML_ASSERT(src0->nb[0] == sizeof(float));
     GGML_ASSERT(src1->nb[0] == sizeof(float));
     GGML_ASSERT(src2->nb[0] == sizeof(float));
     GGML_ASSERT(src3->nb[0] == sizeof(float));
     GGML_ASSERT(src4->nb[0] == sizeof(float));
     GGML_ASSERT(src5->nb[0] == sizeof(float));
-    // required for the dot product between s and C
-    GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
-    // required for per-sequence offsets for states
-    GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
-    // required to get correct offset for state destination (i.e. src1->nb[3])
-    GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
+    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));

     const float * src0_d = (const float *) src0->data;
     const float * src1_d = (const float *) src1->data;
     const float * src2_d = (const float *) src2->data;
     const float * src3_d = (const float *) src3->data;
     const float * src4_d = (const float *) src4->data;
     const float * src5_d = (const float *) src5->data;
+    const int32_t * src6_d = (const int32_t *) src6->data;
     float * dst_d = (float *) dst->data;
     cudaStream_t stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src6->type == GGML_TYPE_I32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

-    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
-                      src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
-                      src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
+    ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src6_d, dst_d,
+                      src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2],
+                      src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3],
+                      s_off, nc, nr, nh, ng, n_t, n_s, stream);
}
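Aside (not part of the patch): dst is laid out as the per-token outputs (one float per element of src1) followed by the final per-sequence states at byte offset s_off = ggml_nelements(src1) * sizeof(float), which is what the GGML_ASSERT on ggml_nelements(dst) above checks. A rough host-side sketch of that bookkeeping, using made-up example shapes:

#include <cstdint>
#include <cstdio>

// Illustrative layout of dst: [ y : one float per output element of src1 ][ final states ].
int main() {
    // example shapes only (not taken from the patch)
    const int64_t d_state = 16, head_dim = 1, n_head = 1536, n_tok = 32, n_seq = 2;
    const int64_t y_elems     = head_dim * n_head * n_tok * n_seq;   // == ggml_nelements(src1)
    const int64_t state_elems = d_state * head_dim * n_head * n_seq; // == nc*nr*nh*n_s
    const int64_t s_off       = y_elems * (int64_t) sizeof(float);   // byte offset of the states in dst
    std::printf("dst elements: %lld, states start at byte offset %lld\n",
                (long long) (y_elems + state_elems), (long long) s_off);
    return 0;
}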