
mpt-1b fails with mpt_model_load: unknown tensor 'transformer.blocks.0.attn.k_ln.weight' in model file #499

@OlegJakushkin

Description


I want to load mpt-1b-redpajama-200b-dolly.

I converted it to ggml with:

!git clone https://github.com/ggerganov/ggml
!cd ggml && rm -rf ./build && mkdir build ; cd build && cmake .. && make -j32
!cd ggml && cd build && python3 ../examples/mpt/convert-h5-to-ggml.py ../../mpt-1b-redpajama-200b-dolly 1

which logged:

* Loading part: pytorch_model.bin
Processing variable: transformer.blocks.0.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.0.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.0.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.0.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.0.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.1.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.1.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.1.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.1.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.1.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.10.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.10.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.10.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.10.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.10.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.11.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.11.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.11.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.11.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.11.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.12.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.12.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.12.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.12.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.12.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.13.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.13.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.13.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.13.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.13.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.14.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.14.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.14.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.14.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.14.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.15.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.15.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.15.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.15.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.15.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.16.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.16.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.16.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.16.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.16.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.17.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.17.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.17.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.17.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.17.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.18.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.18.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.18.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.18.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.18.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.19.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.19.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.19.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.19.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.19.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.2.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.2.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.2.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.2.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.2.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.20.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.20.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.20.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.20.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.20.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.21.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.21.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.21.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.21.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.21.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.22.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.22.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.22.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.22.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.22.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.23.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.23.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.23.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.23.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.23.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.3.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.3.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.3.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.3.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.3.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.4.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.4.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.4.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.4.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.4.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.5.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.5.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.5.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.5.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.5.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.6.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.6.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.6.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.6.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.6.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.7.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.7.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.7.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.7.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.7.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.8.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.8.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.8.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.8.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.8.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.blocks.9.attn.Wqkv.weight with shape:  (6144, 2048) -> float16
Processing variable: transformer.blocks.9.attn.k_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.attn.out_proj.weight with shape:  (2048, 2048) -> float16
Processing variable: transformer.blocks.9.attn.q_ln.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.ln_1.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.ln_2.weight with shape:  (2048,) -> float32
Processing variable: transformer.blocks.9.mlp.mlp_down.weight with shape:  (2048, 8192) -> float16
Processing variable: transformer.blocks.9.mlp.mlp_up.weight with shape:  (8192, 2048) -> float16
Processing variable: transformer.ln_f.weight with shape:  (2048,) -> float32
Processing variable: transformer.wte.weight with shape:  (50432, 2048) -> float16
Done. Output file: ../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin

To make the conversion run, I had to change two lines in convert-h5-to-ggml.py:

#fout.write(struct.pack("f", hparams["attn_config"]["alibi_bias_max"] or 2048))
fout.write(struct.pack("f", 2048))
#fout.write(struct.pack("f", hparams["attn_config"]["clip_qkv"] or 0.0))
fout.write(struct.pack("f", 0.0))

Now I try to run it with:

!cd ggml && cd build && ./bin/mpt -m  ../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin -p "I believe the meaning of life is" -t 8 -n 64

and I get:

main: seed      = 1693544259
main: n_threads = 8
main: n_batch   = 8
main: n_ctx     = 512
main: n_predict = 64

mpt_model_load: loading model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin' - please wait ...
mpt_model_load: d_model        = 2048
mpt_model_load: max_seq_len    = 2048
mpt_model_load: n_ctx          = 512
mpt_model_load: n_heads        = 16
mpt_model_load: n_layers       = 24
mpt_model_load: n_vocab        = 50432
mpt_model_load: alibi_bias_max = 2048.000000
mpt_model_load: clip_qkv       = 0.000000
mpt_model_load: ftype          = 1
mpt_model_load: qntvr          = 0
mpt_model_load: ggml ctx size = 2597.45 MB
mpt_model_load: memory_size =    96.00 MB, n_mem = 12288
mpt_model_load: unknown tensor 'transformer.blocks.0.attn.k_ln.weight' in model file
main: failed to load model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin'
mpt_model_load: 
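
From the conversion log above, the checkpoint contains per-block attn.q_ln.weight / attn.k_ln.weight tensors (the q/k LayerNorm weights), and it is exactly the first of these that the mpt example reports as unknown. A minimal way to double-check which tensor names the original checkpoint actually contains (a sketch, assuming torch is installed and the script is run from the model directory):

import torch

# list the tensor names in the original Hugging Face checkpoint for block 0
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
for name in sorted(state_dict):
    if name.startswith("transformer.blocks.0."):
        print(name, tuple(state_dict[name].shape))
# among others, this prints transformer.blocks.0.attn.q_ln.weight and
# transformer.blocks.0.attn.k_ln.weight with shape (2048,), matching the log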

Could you please tell me what steps need to be performed to enable loading this model? (It should be quite similar to the regular MPT models.)
