Skip to content

Commit e08bca2

Browse files
authored
Support loading fine-tuned LLaVA models (#80)
1 parent cd3ccb2 commit e08bca2

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

python/sglang/srt/models/llama2.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -303,6 +303,8 @@ def load_weights(
303303
# Skip loading extra bias for GPTQ models.
304304
if name.endswith(".bias") and name not in params_dict:
305305
continue
306+
if name.startswith("model.vision_tower") and name not in params_dict:
307+
continue
306308
param = params_dict[name]
307309
weight_loader = param.weight_loader
308310
weight_loader(param, loaded_weight, shard_id)
@@ -311,6 +313,8 @@ def load_weights(
311313
# Skip loading extra bias for GPTQ models.
312314
if name.endswith(".bias") and name not in params_dict:
313315
continue
316+
if name.startswith("model.vision_tower") and name not in params_dict:
317+
continue
314318
param = params_dict[name]
315319
weight_loader = getattr(param, "weight_loader", default_weight_loader)
316320
weight_loader(param, loaded_weight)

0 commit comments

Comments
 (0)