
Commit c00aa57

add truncation_type for llama4 tokenizer (#2812)

1 parent: b7055f9

torchtune/models/llama4/_tokenizer.py

Lines changed: 15 additions & 2 deletions
@@ -120,6 +120,8 @@ class Llama4Tokenizer(ModelTokenizer, Transform):
             - Model-specific templates that are required whenever the model is prompted, such as the [INST]
               tags in Llama2 and in Mistral
             - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`
+        truncation_type (str): type of truncation to apply, either "left" or "right".
+            Default is "right".
 
             The extra text will still get tokenized as normal text, not as special tokens. Default is None.
 
@@ -136,6 +138,7 @@ def __init__(
         special_tokens: Optional[dict[str, int]] = None,
         max_seq_len: Optional[int] = None,
         prompt_template: Optional[PromptTemplateInterface] = None,
+        truncation_type: str = "right",
     ):
         self.special_tokens = (
             special_tokens if special_tokens is not None else LLAMA4_SPECIAL_TOKENS
@@ -188,6 +191,8 @@ def __init__(
             r"<\|header_start\|>.*?<\|header_end\|>\n\n"
         )
 
+        self.truncation_type = truncation_type
+
     def _validate_special_tokens(
         self,
     ):
@@ -420,9 +425,17 @@ def tokenize_messages(
 
         if self.max_seq_len:
             tokens = truncate(
-                tokens, self.max_seq_len, self.eos_id if add_end_tokens else None
+                tokens=tokens,
+                max_seq_len=self.max_seq_len,
+                eos_id=self.eos_id if add_end_tokens else None,
+                truncation_type=self.truncation_type,
+            )
+            mask = truncate(
+                tokens=mask,
+                max_seq_len=self.max_seq_len,
+                eos_id=True if add_end_tokens else None,
+                truncation_type=self.truncation_type,
             )
-            mask = truncate(mask, self.max_seq_len, True if add_end_tokens else None)
 
         return tokens, mask
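The behavioral change introduced by the new keyword is only in which end of an over-long sequence gets dropped: "right" (the default) keeps the beginning and cuts the end, while "left" keeps the end and cuts the beginning, which is the usual choice when the most recent conversation turns matter most. The sketch below is a minimal illustration of those semantics; it is not torchtune.data.truncate itself, and the detail of forcing the kept sequence to end with eos_id is an assumption made for the example.

from typing import Any, Optional


def truncate_sketch(
    tokens: list[Any],
    max_seq_len: int,
    eos_id: Optional[Any] = None,
    truncation_type: str = "right",
) -> list[Any]:
    """Illustrative helper only -- NOT torchtune.data.truncate.

    "right" keeps the first max_seq_len entries, "left" keeps the last
    max_seq_len entries. If eos_id is given, the kept sequence is forced
    to end with it (the tokenizer passes self.eos_id for the token list
    and True for the boolean mask, so both stay aligned).
    """
    if len(tokens) <= max_seq_len:
        return tokens
    if truncation_type == "left":
        kept = tokens[-max_seq_len:]  # drop the oldest tokens
    elif truncation_type == "right":
        kept = tokens[:max_seq_len]  # drop the newest tokens
    else:
        raise ValueError(
            f"truncation_type must be 'left' or 'right', got {truncation_type!r}"
        )
    if eos_id is not None:
        kept = kept[:-1] + [eos_id]  # keep the sequence EOS-terminated
    return kept


seq = [1, 2, 3, 4, 5, 6]
print(truncate_sketch(seq, 4, eos_id=0, truncation_type="right"))  # [1, 2, 3, 0]
print(truncate_sketch(seq, 4, eos_id=0, truncation_type="left"))   # [3, 4, 5, 0]

Note that tokenize_messages now passes the same self.truncation_type to both truncate calls, so the token list and the boolean mask stay aligned after truncation.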
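For completeness, here is a hedged sketch of how a caller might opt into left truncation after this change. Only the truncation_type parameter and its "right" default come from the diff above; the import location, the path argument, and the Message usage follow the conventions of other torchtune tokenizers and are assumptions for illustration.

# Hypothetical usage sketch; file paths and import locations are assumptions.
from torchtune.data import Message
from torchtune.models.llama4 import Llama4Tokenizer  # assumed re-export of the class edited above

tokenizer = Llama4Tokenizer(
    path="/tmp/llama4/tokenizer.model",  # placeholder path to a tiktoken tokenizer file
    max_seq_len=2048,
    truncation_type="left",  # keep the most recent tokens when the conversation is too long
)

messages = [
    Message(role="user", content="Summarize our conversation so far."),
    Message(role="assistant", content="Sure, here is a short summary."),
]

# Both outputs are truncated with the same truncation_type, so they stay aligned.
tokens, mask = tokenizer.tokenize_messages(messages)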