backends/src/lib.rs — 18 changes: 12 additions & 6 deletions
@@ -168,7 +168,7 @@ impl Backend {
             }
         }
         for shape in shapes.iter() {
-            let batch = self.create_warmup_batch(*shape, max_token as u32);
+            let batch = self.create_warmup_batch(*shape, max_token as u32, seq_bucket_size as u32);
             match &self.model_type {
                 ModelType::Classifier => self.predict(batch).await.map(|_| ()),
                 ModelType::Embedding(_) => self.embed(batch).await.map(|_| ()),
Expand All @@ -179,19 +179,25 @@ impl Backend {
}

#[instrument(skip_all)]
pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32) -> Batch {
pub fn create_warmup_batch(&self, shape: (u32, u32), max_token: u32, seq_bucket_size: u32) -> Batch {
let (batch_size, length) = shape;
let min_length = length.saturating_sub(seq_bucket_size).saturating_add(1);
let tmp_length = if min_length < length {
rand::rng().random_range(min_length..length)
} else {
length
};
let mut batched_input_ids = Vec::new();
let mut batched_token_type_ids = Vec::new();
let mut batched_position_ids = Vec::new();
let mut cumulative_seq_lengths = Vec::with_capacity(batch_size as usize + 1);
let mut pooled_indices = Vec::with_capacity(batch_size as usize);
cumulative_seq_lengths.push(0);
let input_ids: Vec<u32> = (0..length)
let input_ids: Vec<u32> = (0..tmp_length)
.map(|_| rand::rng().random_range(0..max_token))
.collect();
let token_type_ids: Vec<u32> = vec![0; length as usize];
let position_ids: Vec<u32> = (0..length).collect();
let token_type_ids: Vec<u32> = vec![0; tmp_length as usize];
let position_ids: Vec<u32> = (0..tmp_length).collect();
let mut current_length = 0;
for batch_id in 0..batch_size {
batched_input_ids.extend(input_ids.iter().cloned());
@@ -206,7 +212,7 @@ impl Backend {
             token_type_ids: batched_token_type_ids,
             position_ids: batched_position_ids,
             cumulative_seq_lengths,
-            max_length: length,
+            max_length: tmp_length,
             pooled_indices,
             raw_indices: vec![],
         }
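The core of the change is the length selection at the top of `create_warmup_batch`: instead of always warming up with exactly `length` tokens, the batch now uses a length drawn from near the top of the current bucket, and the returned `Batch` reports `max_length: tmp_length` so the tensors and the reported maximum stay consistent with the sampled length. The standalone sketch below is not part of the PR; it isolates that sampling logic for illustration (the function name `sample_warmup_length` and the values in `main` are made up) and assumes the rand 0.9 API (`rand::rng()`, `random_range`) that the diff itself calls.

```rust
use rand::Rng;

/// Mirror of the new length selection: pick a warmup sequence length in
/// [length - seq_bucket_size + 1, length - 1], falling back to `length`
/// when that range is empty (e.g. seq_bucket_size == 0 or length <= 1).
fn sample_warmup_length(length: u32, seq_bucket_size: u32) -> u32 {
    let min_length = length.saturating_sub(seq_bucket_size).saturating_add(1);
    if min_length < length {
        // Upper bound is exclusive, matching `random_range(min_length..length)`.
        rand::rng().random_range(min_length..length)
    } else {
        length
    }
}

fn main() {
    // Hypothetical bucket settings: a 512-token bucket with a 128-token
    // bucket step yields a warmup length somewhere in 385..=511.
    println!("{}", sample_warmup_length(512, 128));
    // Degenerate case: with seq_bucket_size == 0 the range is empty,
    // so the full bucket length is used.
    println!("{}", sample_warmup_length(512, 0));
}
```

The same sampled `tmp_length` is then used for `input_ids`, `token_type_ids`, and `position_ids`, so every sequence in the warmup batch has a consistent shape within its bucket.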