I want to load mpt-1b-redpajama-200b-dolly.
I converted it to ggml with:
!git clone https://github.com/ggerganov/ggml
!cd ggml && rm -rf ./build && mkdir build ; cd build && cmake .. && make -j32
!cd ggml && cd build && python3 ../examples/mpt/convert-h5-to-ggml.py ../../mpt-1b-redpajama-200b-dolly 1
which logged:
* Loading part: pytorch_model.bin
Processing variable: transformer.blocks.0.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.0.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.0.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.0.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.0.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.0.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.0.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.0.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.1.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.1.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.1.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.1.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.1.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.1.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.1.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.1.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.10.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.10.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.10.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.10.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.10.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.10.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.10.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.10.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.11.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.11.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.11.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.11.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.11.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.11.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.11.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.11.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.12.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.12.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.12.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.12.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.12.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.12.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.12.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.12.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.13.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.13.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.13.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.13.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.13.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.13.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.13.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.13.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.14.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.14.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.14.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.14.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.14.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.14.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.14.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.14.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.15.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.15.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.15.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.15.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.15.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.15.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.15.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.15.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.16.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.16.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.16.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.16.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.16.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.16.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.16.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.16.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.17.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.17.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.17.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.17.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.17.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.17.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.17.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.17.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.18.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.18.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.18.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.18.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.18.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.18.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.18.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.18.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.19.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.19.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.19.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.19.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.19.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.19.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.19.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.19.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.2.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.2.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.2.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.2.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.2.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.2.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.2.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.2.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.20.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.20.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.20.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.20.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.20.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.20.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.20.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.20.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.21.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.21.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.21.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.21.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.21.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.21.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.21.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.21.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.22.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.22.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.22.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.22.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.22.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.22.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.22.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.22.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.23.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.23.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.23.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.23.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.23.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.23.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.23.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.23.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.3.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.3.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.3.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.3.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.3.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.3.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.3.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.3.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.4.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.4.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.4.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.4.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.4.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.4.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.4.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.4.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.5.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.5.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.5.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.5.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.5.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.5.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.5.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.5.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.6.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.6.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.6.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.6.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.6.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.6.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.6.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.6.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.7.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.7.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.7.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.7.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.7.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.7.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.7.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.7.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.8.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.8.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.8.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.8.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.8.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.8.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.8.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.8.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.blocks.9.attn.Wqkv.weight with shape: (6144, 2048) -> float16
Processing variable: transformer.blocks.9.attn.k_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.9.attn.out_proj.weight with shape: (2048, 2048) -> float16
Processing variable: transformer.blocks.9.attn.q_ln.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.9.ln_1.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.9.ln_2.weight with shape: (2048,) -> float32
Processing variable: transformer.blocks.9.mlp.mlp_down.weight with shape: (2048, 8192) -> float16
Processing variable: transformer.blocks.9.mlp.mlp_up.weight with shape: (8192, 2048) -> float16
Processing variable: transformer.ln_f.weight with shape: (2048,) -> float32
Processing variable: transformer.wte.weight with shape: (50432, 2048) -> float16
Done. Output file: ../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin
To do that, I changed two lines in convert-h5-to-ggml.py:
#fout.write(struct.pack("f", hparams["attn_config"]["alibi_bias_max"] or 2048))
fout.write(struct.pack("f", 2048))
#fout.write(struct.pack("f", hparams["attn_config"]["clip_qkv"] or 0.0))
fout.write(struct.pack("f", 0.0))
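(A sketch of what I could have written instead of hard-coding the two values. hparams, fout and struct are the objects already used in the script; the 8.0 fallback is my assumption, since 8 seems to be MPT's usual alibi_bias_max default, which would also make the 2048 I hard-coded above wrong, as that value is really the max_seq_len.)

# Read the values from config.json when present, otherwise fall back to assumed defaults.
attn_config = hparams.get("attn_config") or {}
alibi_bias_max = attn_config.get("alibi_bias_max") or 8.0   # assumed MPT default
clip_qkv = attn_config.get("clip_qkv") or 0.0
fout.write(struct.pack("f", float(alibi_bias_max)))
fout.write(struct.pack("f", float(clip_qkv)))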
Now I try to run it with:
!cd ggml && cd build && ./bin/mpt -m ../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin -p "I believe the meaning of life is" -t 8 -n 64
I get:
main: seed = 1693544259
main: n_threads = 8
main: n_batch = 8
main: n_ctx = 512
main: n_predict = 64
mpt_model_load: loading model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin' - please wait ...
mpt_model_load: d_model = 2048
mpt_model_load: max_seq_len = 2048
mpt_model_load: n_ctx = 512
mpt_model_load: n_heads = 16
mpt_model_load: n_layers = 24
mpt_model_load: n_vocab = 50432
mpt_model_load: alibi_bias_max = 2048.000000
mpt_model_load: clip_qkv = 0.000000
mpt_model_load: ftype = 1
mpt_model_load: qntvr = 0
mpt_model_load: ggml ctx size = 2597.45 MB
mpt_model_load: memory_size = 96.00 MB, n_mem = 12288
mpt_model_load: unknown tensor 'transformer.blocks.0.attn.k_ln.weight' in model file
main: failed to load model from '../../mpt-1b-redpajama-200b-dolly/ggml-model-f16.bin'
mpt_model_load:
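For reference, here is a quick sanity check (a hypothetical helper snippet; adjust the path to wherever pytorch_model.bin lives) that lists the tensors the stock mpt example has no slot for:

import torch

# Load the original checkpoint and print the per-block q_ln / k_ln LayerNorm
# weights, which are the names the loader reports as unknown.
sd = torch.load("mpt-1b-redpajama-200b-dolly/pytorch_model.bin", map_location="cpu")
extra = [name for name in sd if ".attn.q_ln." in name or ".attn.k_ln." in name]
print(len(extra), "q/k LayerNorm tensors, e.g.", extra[:2])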
Could you please tell me what steps need to be performed to get the model to load? (It should be quite similar to MPT.)
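My assumption (based only on the tensor names and shapes, not verified against MosaicML's code) is that q_ln and k_ln are LayerNorms with a learned scale applied to the query and key projections before attention, roughly:

import numpy as np

def qk_layernorm(x, weight, eps=1e-5):
    # LayerNorm over the last dimension with a learned scale and no bias,
    # which is what a bare ".weight" tensor of shape (2048,) suggests.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * weight

# After splitting the Wqkv output into q, k, v (each [seq_len, 2048]),
# q and k would additionally go through this before the attention scores:
#   q = qk_layernorm(q, q_ln_weight)
#   k = qk_layernorm(k, k_ln_weight)

If that is right, the mpt example would presumably need the two extra per-block tensors plus this extra normalization step in its graph, but I may be missing something.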