@@ -3221,7 +3221,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
-static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
+GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
+    return ggml_is_contiguous(tensor);
+}
+
+GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
     return
@@ -3230,6 +3234,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
         tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
 }
 
+GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return
+        tensor->nb[0] == ggml_type_size(tensor->type) &&
+        tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
+}
+
 GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
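Note, not part of the commit: the three predicates differ only in how many leading dimensions may carry padded strides. `ggml_is_contiguous_0` requires full contiguity, `_1` tolerates a padded row stride (dim 1), and `_2` tolerates padding on dims 1 and 2. A minimal sketch of the shared idea, assuming ggml.h's layout where `nb[i]` is the byte stride of dim `i`; the helper name is hypothetical and block-quantized types are ignored:

#include <stdbool.h>
#include "ggml.h"

// Hypothetical generalization (illustration only): dims 1..n may have
// arbitrary strides, every dim above n must be tightly packed.
static bool is_contiguous_up_to(const struct ggml_tensor * t, int n) {
    if (t->nb[0] != ggml_type_size(t->type)) {
        return false; // dim 0 must always be tightly packed
    }
    for (int i = n + 1; i < GGML_MAX_DIMS; i++) {
        if (t->nb[i] != t->nb[i-1]*t->ne[i-1]) {
            return false; // padding above dim n breaks the guarantee
        }
    }
    return true;
}

// is_contiguous_up_to(t, 1) behaves like ggml_is_contiguous_1(t) above,
// and is_contiguous_up_to(t, 2) like ggml_is_contiguous_2(t).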
@@ -11420,8 +11432,8 @@ static void ggml_compute_forward_gelu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
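These element-wise kernels only need each row to be dense: they step between rows through `nb[1]`, so a padded row stride is harmless, which is exactly what `ggml_is_contiguous_1` guarantees. A sketch of the access pattern, simplified from ggml's row loop (here `ir0`/`ir1` stand for the thread's row range and `nc` for the row length):

for (int64_t i1 = ir0; i1 < ir1; i1++) {
    // each row is dense; the stride between rows (nb[1]) may be padded
    ggml_vec_gelu_f32(nc,
            (float *) ((char *)  dst->data + i1*( dst->nb[1])),
            (float *) ((char *) src0->data + i1*(src0->nb[1])));
}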
@@ -11483,8 +11495,8 @@ static void ggml_compute_forward_gelu_quick_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11546,8 +11558,8 @@ static void ggml_compute_forward_silu_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11658,9 +11670,9 @@ static void ggml_compute_forward_silu_back_f32(
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * grad = dst->src[1];
 
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
-    GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
+    GGML_ASSERT(ggml_is_contiguous_1(grad));
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
     GGML_ASSERT(ggml_are_same_shape(src0, dst));
     GGML_ASSERT(ggml_are_same_shape(src0, grad));
 
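For reference, the per-element math this backward kernel applies, written as a scalar sketch derived from the definition of SiLU (not ggml's vectorized code):

#include <math.h>

// SiLU(x) = x*sigmoid(x), so with s = sigmoid(x):
//   d/dx SiLU(x) = s*(1 + x*(1 - s))
// The backward pass scales the incoming gradient dy by this factor.
static float silu_backward(float x, float dy) {
    const float s = 1.0f/(1.0f + expf(-x));
    return dy*s*(1.0f + x*(1.0f - s));
}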
@@ -14358,7 +14370,7 @@ static void ggml_compute_forward_rope_f32(
     int ir = 0;
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
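`inv_ndims` can go because it only fed the removed `cur_rot` expression; `theta_scale` alone drives the loop, since the RoPE angle for pair `ic` is `pos * freq_base^(-ic/n_dims)` and the loop builds it incrementally. A small self-check under that assumption, not part of the commit:

#include <math.h>
#include <assert.h>

// Verifies that repeated theta_base *= theta_scale reproduces the
// closed form theta(ic) = pos * freq_base^(-ic/n_dims), up to rounding.
static void check_theta_recurrence(float pos, int n_dims, float freq_base) {
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
    float theta_base = pos;
    for (int ic = 0; ic < n_dims; ic += 2) {
        const float closed = pos*powf(freq_base, -(float)ic/n_dims);
        assert(fabsf(theta_base - closed) <= 1e-3f*fabsf(closed) + 1e-6f);
        theta_base *= theta_scale;
    }
}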
@@ -14407,7 +14419,7 @@ static void ggml_compute_forward_rope_f32(
                     const float cos_block_theta = cosf(block_theta);
                     const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                    theta_base *= theta_scale;
+                    theta_base  *= theta_scale;
                     block_theta *= theta_scale;
 
                     const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14442,29 +14454,22 @@ static void ggml_compute_forward_rope_f32(
                     dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;
 
-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
-                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
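With `ib` fixed at 0, the old `i0 = ib*n_dims + ic/2` was always just `ic/2`, which the new code states directly, and `rope_yarn` now receives the rotation index `ic` instead of the derived `cur_rot`. In this neox-style branch each element pairs with the one half a head further in; the lines following the context above look roughly like this (a reconstruction, not shown in the hunk):

const float x0 = src[0];         // src already points at offset i0
const float x1 = src[n_dims/2];  // partner element, n_dims/2 further in

dst_data[0]        = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;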
@@ -14543,7 +14548,7 @@ static void ggml_compute_forward_rope_f16(
     int ir = 0;
 
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    const float inv_ndims = -1.f/n_dims;
+
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
 
@@ -14592,7 +14597,7 @@ static void ggml_compute_forward_rope_f16(
                     const float cos_block_theta = cosf(block_theta);
                     const float sin_block_theta = sinf(block_theta) * sin_sign;
 
-                    theta_base *= theta_scale;
+                    theta_base  *= theta_scale;
                     block_theta *= theta_scale;
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14623,29 +14628,22 @@ static void ggml_compute_forward_rope_f16(
                     dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
                 }
             } else {
-                // TODO: this might be wrong for ne0 != n_dims - need double check
-                // it seems we have to rope just the first n_dims elements and do nothing with the rest
-                // ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
-                theta_base *= freq_scale;
+                // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
                 for (int64_t ic = 0; ic < ne0; ic += 2) {
                     if (ic < n_dims) {
-                        const int64_t ib = 0;
+                        const int64_t i0 = ic/2;
 
-                        // simplified from `(ib * n_dims + ic) * inv_ndims`
-                        float cur_rot = inv_ndims * ic - ib;
-                        float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
+                        const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
 
                         float cos_theta, sin_theta;
                         rope_yarn(
-                            theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
+                            theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
                             &cos_theta, &sin_theta
                         );
-                        sin_theta *= sin_sign;
 
+                        sin_theta *= sin_sign;
                         theta_base *= theta_scale;
 
-                        const int64_t i0 = ib*n_dims + ic/2;
-
                         const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
                         ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
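The f16 path mirrors the f32 changes one for one; the only difference is conversion at the load/store boundary, roughly (a reconstruction using ggml's conversion macros, not shown in the hunk):

const float x0 = GGML_FP16_TO_FP32(src[0]);
const float x1 = GGML_FP16_TO_FP32(src[n_dims/2]);

dst_data[0]        = GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
dst_data[n_dims/2] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);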