Commit b8e09f0

model : add grok-2 support (#15539)
* add grok-2 support
* type fix
* type fix
* type fix
* "fix" vocab for invalid sequences
* fix expert tensor mapping and spaces in vocab
* add chat template
* fix norm tensor mapping
* rename layer_out_norm to ffn_post_norm
* ensure ffn_post_norm is mapped
* fix experts merging
* remove erroneous FFN_GATE entry
* concatenate split tensors and add more metadata
* process all expert layers and try cat instead of hstack
* add support for community BPE vocab
* fix expert feed forward length and ffn_down concat
* commit this too
* add ffn_up/gate/down, unsure if sequence is right
* add ffn_gate/down/up to tensor names
* correct residual moe (still not working)
* mess--
* fix embedding scale being applied twice
* add built in chat template
* change beta fast for grok if default value
* remove spm vocab in favor of community bpe vocab
* change attention temp length metadata type to integer
* update attention temp length metadata
* remove comment
* replace M_SQRT2 with std::sqrt(2)
* add yarn metadata, move defaults to hparams
1 parent 6c019cb commit b8e09f0

16 files changed: +275 −90 lines

common/common.h

Lines changed: 3 additions & 3 deletions
```diff
@@ -288,9 +288,9 @@ struct common_params {
     float     rope_freq_base   = 0.0f;  // RoPE base frequency
     float     rope_freq_scale  = 0.0f;  // RoPE frequency scaling factor
     float     yarn_ext_factor  = -1.0f; // YaRN extrapolation mix factor
-    float     yarn_attn_factor = 1.0f;  // YaRN magnitude scaling factor
-    float     yarn_beta_fast   = 32.0f; // YaRN low correction dim
-    float     yarn_beta_slow   = 1.0f;  // YaRN high correction dim
+    float     yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
+    float     yarn_beta_fast   = -1.0f; // YaRN low correction dim
+    float     yarn_beta_slow   = -1.0f; // YaRN high correction dim
     int32_t   yarn_orig_ctx    = 0;     // YaRN original context length

     // offload params
```
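The YaRN defaults in `common_params` change from concrete values to the `-1.0f` sentinel already used by `yarn_ext_factor`, so the loader can distinguish "user never set this" from "user explicitly chose the old default" and fall back to per-model values carried in GGUF metadata (the "add yarn metadata, move defaults to hparams" item in the commit message). A minimal sketch of that resolution pattern, in Python for brevity; the function and argument names are illustrative, not from the patch:

```python
def resolve_yarn_param(cli_value: float, gguf_default: float, fallback: float) -> float:
    """Resolve a YaRN parameter where a negative value means 'unset'.

    cli_value    -- value from the command line (-1.0 if the flag was not given)
    gguf_default -- per-model default read from GGUF metadata (may also be unset)
    fallback     -- the old hard-coded default (e.g. 32.0 for beta_fast)
    """
    if cli_value >= 0.0:
        return cli_value      # an explicit user choice always wins
    if gguf_default >= 0.0:
        return gguf_default   # model-specific default written by the converter
    return fallback           # legacy behaviour for models without the metadata


# e.g. beta_fast: user left it unset, model metadata says 8.0 -> 8.0
print(resolve_yarn_param(-1.0, 8.0, 32.0))
```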

convert_hf_to_gguf.py

Lines changed: 78 additions & 23 deletions
```diff
@@ -735,6 +735,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
             # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
             res = "qwen2"
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
```
```diff
@@ -2682,18 +2685,55 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
             yield (new_name, data_torch)


-@ModelBase.register("GrokForCausalLM")
+@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK

     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()

-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""

```
```diff
@@ -2700,39 +2740,54 @@
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]

             assert bid is not None

             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]

-            self._experts[bid][name] = data_torch
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""

-            if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
+            for bid in range(self.block_count):
+                if len(self._experts[bid]) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
+                        datas: list[Tensor] = []

-                # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
-                    datas: list[Tensor] = []
+                        for xid in range(n_experts):
+                            ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                            if ename not in self._experts[bid]:
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                            tensor_list = self._experts[bid][ename]
+                            datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
+                            del self._experts[bid][ename]

-                    for xid in range(n_experts):
-                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                        datas.append(self._experts[bid][ename])
-                        del self._experts[bid][ename]
+                        data_torch = torch.stack(datas, dim=0)

-                    data_torch = torch.stack(datas, dim=0)
+                        merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"

-                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
-
-                    new_name = self.map_tensor_name(merged_name)
+                        new_name = self.map_tensor_name(merged_name)

-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                        yield (new_name, data_torch)

-        return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors


 @ModelBase.register("DbrxForCausalLM")
```

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -158,6 +158,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
     {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
+    {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
 ]

```
gguf-py/gguf/constants.py

Lines changed: 17 additions & 9 deletions
```diff
@@ -111,6 +111,7 @@ class LLM:
         DECODER_START_TOKEN_ID   = "{arch}.decoder_start_token_id"
         DECODER_BLOCK_COUNT      = "{arch}.decoder_block_count"
         ATTN_LOGIT_SOFTCAPPING   = "{arch}.attn_logit_softcapping"
+        ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING  = "{arch}.final_logit_softcapping"
         SWIN_NORM                = "{arch}.swin_norm"
         RESCALE_EVERY_N_LAYERS   = "{arch}.rescale_every_n_layers"
@@ -146,21 +147,27 @@ class Attention:
         REL_BUCKETS_COUNT      = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW         = "{arch}.attention.sliding_window"
         SCALE                  = "{arch}.attention.scale"
+        OUTPUT_SCALE           = "{arch}.attention.output_scale"
+        TEMPERATURE_LENGTH     = "{arch}.attention.temperature_length"
         KEY_LENGTH_MLA         = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA       = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS       = "{arch}.attention.shared_kv_layers"
         SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"

     class Rope:
-        DIMENSION_COUNT      = "{arch}.rope.dimension_count"
-        DIMENSION_SECTIONS   = "{arch}.rope.dimension_sections"
-        FREQ_BASE            = "{arch}.rope.freq_base"
-        SCALING_TYPE         = "{arch}.rope.scaling.type"
-        SCALING_FACTOR       = "{arch}.rope.scaling.factor"
-        SCALING_ATTN_FACTOR  = "{arch}.rope.scaling.attn_factor"
-        SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
-        SCALING_FINETUNED    = "{arch}.rope.scaling.finetuned"
-        SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
+        DIMENSION_COUNT          = "{arch}.rope.dimension_count"
+        DIMENSION_SECTIONS       = "{arch}.rope.dimension_sections"
+        FREQ_BASE                = "{arch}.rope.freq_base"
+        SCALING_TYPE             = "{arch}.rope.scaling.type"
+        SCALING_FACTOR           = "{arch}.rope.scaling.factor"
+        SCALING_ATTN_FACTOR      = "{arch}.rope.scaling.attn_factor"
+        SCALING_ORIG_CTX_LEN     = "{arch}.rope.scaling.original_context_length"
+        SCALING_FINETUNED        = "{arch}.rope.scaling.finetuned"
+        SCALING_YARN_LOG_MUL     = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR  = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST   = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW   = "{arch}.rope.scaling.yarn_beta_slow"

     class Split:
         LLM_KV_SPLIT_NO = "split.no"
@@ -1114,6 +1121,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_POST_NORM,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.GPTNEOX: [
```
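These `Keys` entries are `str.format` templates over the architecture name, so a single constant serves every arch. A short demonstration, assuming the gguf-py package with this change installed (the arch string is illustrative):

```python
import gguf

arch = "grok"  # illustrative architecture name
print(gguf.Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=arch))
# -> grok.rope.scaling.yarn_beta_fast
print(gguf.Keys.Attention.TEMPERATURE_LENGTH.format(arch=arch))
# -> grok.attention.temperature_length
```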

gguf-py/gguf/gguf_writer.py

Lines changed: 21 additions & 0 deletions
```diff
@@ -733,6 +733,9 @@ def add_logit_scale(self, value: float) -> None:
     def add_attn_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

+    def add_router_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)

@@ -829,6 +832,12 @@ def add_sliding_window(self, value: int) -> None:
     def add_attention_scale(self, value: float) -> None:
         self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)

+    def add_attn_output_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
+    def add_attn_temperature_length(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

@@ -859,6 +868,18 @@ def add_rope_scaling_finetuned(self, value: bool) -> None:
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)

+    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

```
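For reference, a minimal sketch of driving the new writer methods directly, outside the converter. The file name and values are illustrative (not grok-2's real hyperparameters), and in normal use `TextModel` handles the header/KV/tensor writing:

```python
from gguf import GGUFWriter

writer = GGUFWriter("grok-2-metadata-demo.gguf", "grok")  # hypothetical output file
writer.add_architecture()
writer.add_rope_scaling_yarn_beta_fast(8.0)   # illustrative values
writer.add_rope_scaling_yarn_beta_slow(1.0)
writer.add_attn_temperature_length(8192)      # illustrative value

# serialize header, KV metadata, and (here: zero) tensors
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()
writer.close()
```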

gguf-py/gguf/tensor_mapping.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -136,6 +136,7 @@ class TensorNameMap:
             "model.layers.{bid}.norm",                           # mamba-qbert
             "backbone.layers.{bid}.norm",                        # mamba
             "transformer.decoder_layer.{bid}.rms_norm",          # Grok
+            "model.layers.{bid}.pre_attn_norm",                  # grok-2
             "transformer.blocks.{bid}.norm_attn_norm.norm_1",    # dbrx
             "encoder.layers.{bid}.input_layernorm",              # chatglm
             "transformer.layers.{bid}.attn_norm",                # openelm
@@ -278,6 +279,7 @@
             "transformer.layer.{bid}.sa_layer_norm",             # distillbert
             "encoder.layers.{bid}.norm1",                        # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_1",        # Grok
+            "model.layers.{bid}.post_attn_norm",                 # grok-2
             "transformer.blocks.{bid}.norm_attn_norm.norm_2",    # dbrx
         ),

@@ -313,6 +315,7 @@
             "h.{bid}.ln_2",                                      # gpt2
             "model.layers.{bid}.ffn_norm",                       # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2",        # Grok
+            "model.layers.{bid}.pre_moe_norm",                   # grok-2
             "encoder.layers.{bid}.post_attention_layernorm",     # chatglm
             "transformer.layers.{bid}.ffn_norm",                 # openelm
             "model.layers.{bid}.pre_ff_layernorm",               # jamba granite-hybrid
@@ -333,11 +336,12 @@

         # Post feed-forward norm
         MODEL_TENSOR.FFN_POST_NORM: (
-            "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
-            "layers.{bid}.post_feedforward_layernorm",       # embeddinggemma
-            "model.layers.{bid}.post_mlp_layernorm",         # glm-4-0414
+            "model.layers.{bid}.post_feedforward_layernorm",     # gemma2 olmo2
+            "layers.{bid}.post_feedforward_layernorm",           # embeddinggemma
+            "model.layers.{bid}.post_mlp_layernorm",             # glm-4-0414
             "model.layers.layers.{bid}.post_mlp_norm.weight",    # plamo2
             "model.layers.{bid}.feed_forward.up_proj",
+            "model.layers.{bid}.post_moe_norm",                  # grok-2
         ),

         MODEL_TENSOR.FFN_GATE_INP: (
```
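With the mappings above, `TensorNameMap` can translate grok-2's HF-style names into canonical GGUF tensor names. A small sketch, assuming the updated gguf-py; the block count is made up, and the expected output follows from the GGUF tensor-name table (`FFN_POST_NORM` maps to `blk.{bid}.post_ffw_norm`):

```python
import gguf

# build the HF-name -> GGUF-name map for Grok (64 blocks, illustrative)
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.GROK, 64)

name = tmap.get_name("model.layers.0.post_moe_norm.weight", try_suffixes=(".weight", ".bias"))
print(name)  # expected: blk.0.post_ffw_norm.weight
```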
