@@ -120,6 +120,8 @@ class Llama4Tokenizer(ModelTokenizer, Transform):
            - Model-specific templates that are required whenever the model is prompted, such as the [INST]
              tags in Llama2 and in Mistral
            - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".

            The extra text will still get tokenized as normal text, not as special tokens. Default is None.

@@ -136,6 +138,7 @@ def __init__(
        special_tokens: Optional[dict[str, int]] = None,
        max_seq_len: Optional[int] = None,
        prompt_template: Optional[PromptTemplateInterface] = None,
+        truncation_type: str = "right",
    ):
        self.special_tokens = (
            special_tokens if special_tokens is not None else LLAMA4_SPECIAL_TOKENS
@@ -188,6 +191,8 @@ def __init__(
            r"<\|header_start\|>.*?<\|header_end\|>\n\n"
        )

+        self.truncation_type = truncation_type
+
    def _validate_special_tokens(
        self,
    ):
@@ -420,9 +425,17 @@ def tokenize_messages(

        if self.max_seq_len:
            tokens = truncate(
-                tokens, self.max_seq_len, self.eos_id if add_end_tokens else None
+                tokens=tokens,
+                max_seq_len=self.max_seq_len,
+                eos_id=self.eos_id if add_end_tokens else None,
+                truncation_type=self.truncation_type,
+            )
+            mask = truncate(
+                tokens=mask,
+                max_seq_len=self.max_seq_len,
+                eos_id=True if add_end_tokens else None,
+                truncation_type=self.truncation_type,
            )
-            mask = truncate(mask, self.max_seq_len, True if add_end_tokens else None)

        return tokens, mask
0 commit comments