@@ -1350,6 +1350,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.set_swa_pattern(4);
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }
+
                 switch (hparams.n_layer) {
                     case 16: type = LLM_TYPE_1B; break;
                     case 32: type = LLM_TYPE_7B; break;
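Editorial note: this hunk reads the sliding-window size from the GGUF metadata and, when it is present and non-zero, switches the model to standard SWA with a pattern of 4. The sketch below is a standalone illustration of the layer layout that `set_swa_pattern(4)` is assumed to produce (three sliding-window layers followed by one full-attention layer, repeating); `layer_uses_swa` is a hypothetical helper written for this note, not llama.cpp code.

```cpp
#include <cstdio>

// Assumption: with a pattern of 4, every 4th layer (il % 4 == 3) keeps
// full attention and the remaining layers use the sliding window.
static bool layer_uses_swa(int il, int pattern = 4) {
    return pattern == 0 || (il % pattern) < (pattern - 1);
}

int main() {
    for (int il = 0; il < 8; ++il) {
        std::printf("layer %d: %s\n", il, layer_uses_swa(il) ? "sliding window" : "full attention");
    }
    return 0;
}
```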
@@ -12233,6 +12241,7 @@ struct llm_build_olmo : public llm_graph_context {
     }
 };
 
+template <bool iswa>
 struct llm_build_olmo2 : public llm_graph_context {
     llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -12248,7 +12257,14 @@ struct llm_build_olmo2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv();
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv();
+        }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
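Editorial note: the new `template <bool iswa>` parameter together with `std::conditional_t` and `if constexpr` selects the attention-input type at compile time, so a single builder covers both the full-context and the sliding-window KV layouts. The following standalone sketch reproduces the same dispatch pattern with placeholder types (`input_full`, `input_swa`, `build` are invented for illustration and stand in for the `llm_graph_input_*` classes and builder methods).

```cpp
#include <cstdio>
#include <memory>
#include <type_traits>

struct input_full { void describe() const { std::printf("full-context KV input\n"); } };
struct input_swa  { void describe() const { std::printf("sliding-window KV input\n"); } };

template <bool iswa>
void build() {
    // Pick the input type at compile time; the branch not taken is discarded.
    using input_t = std::conditional_t<iswa, input_swa, input_full>;
    std::unique_ptr<input_t> inp;
    if constexpr (iswa) {
        inp = std::make_unique<input_swa>();
    } else {
        inp = std::make_unique<input_full>();
    }
    inp->describe();
}

int main() {
    build<true>();
    build<false>();
    return 0;
}
```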
@@ -12281,17 +12297,36 @@ struct llm_build_olmo2 : public llm_graph_context {
                 Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
-                Qcur = ggml_rope_ext(
+                const bool is_swa = hparams.is_swa(il);
+
+                if (is_swa) {
+                    // For sliding-window layers, Olmo3 uses regular RoPE with no YaRN rope scaling.
+                    // This is achieved here by setting freq_scale and attn_factor to 1.
+                    // We also set ext_factor to 0 to avoid a few unnecessary computations.
+                    Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
+                        0.0, 1.0, beta_fast, beta_slow
+                        );
+
+                    Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, 1.0,
+                        0.0, 1.0, beta_fast, beta_slow
+                        );
+                } else {
+                    Qcur = ggml_rope_ext(
                         ctx0, Qcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
 
-                Kcur = ggml_rope_ext(
+                    Kcur = ggml_rope_ext(
                         ctx0, Kcur, inp_pos, nullptr,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                         ext_factor, attn_factor, beta_fast, beta_slow
                         );
+                }
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
@@ -19131,7 +19166,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            } break;
        case LLM_ARCH_OLMO2:
            {
-                llm = std::make_unique<llm_build_olmo2>(*this, params);
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_olmo2<true>>(*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_olmo2<false>>(*this, params);
+                }
            } break;
        case LLM_ARCH_OLMOE:
            {
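Editorial note: this is where the runtime hyperparameter (`hparams.swa_type`) is turned into the compile-time template argument of the builder. The sketch below shows the same runtime-flag to template-argument dispatch with a hypothetical `build_example` type standing in for `llm_build_olmo2`; it is an illustration of the pattern, not llama.cpp API.

```cpp
#include <cstdio>
#include <memory>

// Hypothetical stand-in for the llm_graph_context-derived builders.
struct graph_base { virtual ~graph_base() = default; };

template <bool iswa>
struct build_example : graph_base {
    build_example() { std::printf("graph built with iswa = %s\n", iswa ? "true" : "false"); }
};

// The runtime branch chooses which instantiation to construct; inside each
// instantiation the flag is a compile-time constant.
static std::unique_ptr<graph_base> make_graph(bool use_swa) {
    if (use_swa) {
        return std::make_unique<build_example<true>>();
    }
    return std::make_unique<build_example<false>>();
}

int main() {
    make_graph(true);
    make_graph(false);
    return 0;
}
```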