@@ -101,6 +101,14 @@ def get_modality_of_token(self, token: str) -> Optional[Modality]:
101
101
102
102
return None
103
103
104
def get_token_id_by_modality(self, modality: Modality) -> Optional[int]:
    """Look up the token id associated with ``modality``.

    ``Modality.IMAGE`` and ``Modality.MULTI_IMAGES`` both map to
    ``self.image_token_id``; ``Modality.VIDEO`` maps to
    ``self.video_token_id`` and ``Modality.AUDIO`` to
    ``self.audio_token_id``.

    Returns:
        The value of the mapped attribute, or ``None`` when ``modality`` is
        not one of the four keys listed above.
    """
    # The table is built per call so it always reflects the current
    # attribute values on this instance.
    token_id_by_modality = {
        Modality.IMAGE: self.image_token_id,
        Modality.MULTI_IMAGES: self.image_token_id,
        Modality.VIDEO: self.video_token_id,
        Modality.AUDIO: self.audio_token_id,
    }
    return token_id_by_modality.get(modality)
111
+
104
112
def parse_regex (self ):
105
113
if self .image_token_regex is None and self .image_token is not None :
106
114
self .image_token_regex = re .compile (re .escape (self .image_token ))
@@ -608,14 +616,12 @@ def process_and_combine_mm_data(
608
616
609
617
# Add offsets to all items
610
618
for mm_item in all_collected_items :
619
+ mm_token_id = mm_tokens .get_token_id_by_modality (mm_item .modality )
620
+ if mm_token_id is None :
621
+ raise ValueError (f"No token id found for modality: { mm_item .modality } " )
611
622
mm_item .offsets = self .get_mm_items_offset (
612
623
input_ids = input_ids ,
613
- mm_token_id = {
614
- Modality .IMAGE : mm_tokens .image_token_id ,
615
- Modality .MULTI_IMAGES : mm_tokens .image_token_id ,
616
- Modality .VIDEO : mm_tokens .video_token_id ,
617
- Modality .AUDIO : mm_tokens .audio_token_id ,
618
- }.get (mm_item .modality , None ),
624
+ mm_token_id = mm_token_id ,
619
625
)
620
626
621
627
return all_collected_items , input_ids , ret
0 commit comments