diff --git a/core/src/tokenization.rs b/core/src/tokenization.rs index fece3a98..71617daf 100644 --- a/core/src/tokenization.rs +++ b/core/src/tokenization.rs @@ -274,7 +274,7 @@ fn prepare_pre_prompt( #[allow(clippy::too_many_arguments)] fn tokenize_input( - inputs: EncodingInput, + mut inputs: EncodingInput, add_special_tokens: bool, max_input_length: usize, truncate_params: Option, @@ -288,9 +288,12 @@ fn tokenize_input( let input_chars = inputs.count_chars(); let limit = max_input_length * MAX_CHAR_MULTIPLIER; if input_chars > limit { - return Err(TextEmbeddingsError::Validation(format!( - "`inputs` must have less than {limit} characters. Given: {input_chars}" - ))); + if truncate_params.is_none() { + return Err(TextEmbeddingsError::Validation(format!( + "`inputs` must have less than {limit} characters. Given: {input_chars}" + ))); + } + inputs.apply_limit(limit); } let encoding = match inputs { @@ -426,6 +429,25 @@ impl EncodingInput { EncodingInput::Ids(v) => v.len(), } } + + fn apply_limit(&mut self, limit: usize) { + let truncate_string = |s: &mut String, limit: usize| { + if s.is_char_boundary(limit) { + s.truncate(limit) + } + }; + + match self { + EncodingInput::Single(s) => { + truncate_string(s, limit); + } + EncodingInput::Dual(s1, s2) => { + truncate_string(s1, limit / 2); + truncate_string(s2, limit / 2); + } + EncodingInput::Ids(_) => {} + } + } } impl From for EncodingInput {