     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -62,8 +63,9 @@ def __init__(
         num_experts: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         top_k: Optional[int] = None,
-        layer_id: Optional[int] = None,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
@@ -84,13 +86,15 @@ def __init__(
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
 
+        self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
         self.tp_rank = get_tensor_model_parallel_rank()
         self.num_experts = num_experts
+        self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map = None
 
         if enable_flashinfer_cutlass_moe and quant_config is None:
@@ -375,6 +379,45 @@ def weight_loader(
         shard_id: str,
         expert_id: int,
     ) -> None:
+
+        global_expert_location_metadata = get_global_expert_location_metadata()
+        if global_expert_location_metadata is None:
+            self._weight_loader_impl(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=expert_id,
+            )
+            return
+
+        if expert_id >= self.num_experts - self.num_fused_shared_experts:
+            # This is a shared expert.
+            physical_expert_ids = [expert_id]
+        else:
+            physical_expert_ids = (
+                global_expert_location_metadata.logical_to_all_physical(
+                    self.layer_id, expert_id
+                )
+            )
+
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
         expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
         if expert_id == -1:
             return
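
For readers unfamiliar with the expert-location rebalancing this hooks into: the new weight_loader fans a logical expert's checkpoint weights out to every physical slot returned by logical_to_all_physical, while the last num_fused_shared_experts expert ids (the fused shared experts) are loaded once at their own slot. The snippet below is a minimal, self-contained sketch of that dispatch decision only; SimpleMetadata, expand_expert_id, and the mapping values are hypothetical stand-ins for illustration, not sglang APIs.

# Hypothetical, minimal illustration of the fan-out decision in the new weight_loader.
# SimpleMetadata stands in for the object returned by get_global_expert_location_metadata();
# only the shape of logical_to_all_physical() mirrors the PR, the mapping values are made up.
from typing import Dict, List, Tuple


class SimpleMetadata:
    """Toy metadata: maps (layer_id, logical expert id) -> list of physical expert ids."""

    def __init__(self, mapping: Dict[Tuple[int, int], List[int]]):
        self.mapping = mapping

    def logical_to_all_physical(self, layer_id: int, logical_expert_id: int) -> List[int]:
        return self.mapping[(layer_id, logical_expert_id)]


def expand_expert_id(
    metadata: SimpleMetadata,
    layer_id: int,
    expert_id: int,
    num_experts: int,
    num_fused_shared_experts: int,
) -> List[int]:
    # Fused shared experts occupy the last num_fused_shared_experts logical ids and are
    # loaded once at their own slot; routed experts are loaded into every physical replica.
    if expert_id >= num_experts - num_fused_shared_experts:
        return [expert_id]
    return metadata.logical_to_all_physical(layer_id, expert_id)


# Layer 0: logical expert 3 has been replicated onto physical slots 3 and 17; with
# num_experts=8 and one fused shared expert, id 7 bypasses the metadata lookup.
meta = SimpleMetadata({(0, 3): [3, 17]})
print(expand_expert_id(meta, 0, 3, num_experts=8, num_fused_shared_experts=1))  # [3, 17]
print(expand_expert_id(meta, 0, 7, num_experts=8, num_fused_shared_experts=1))  # [7]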