Merged

41 commits
c92be89
Add `serde` alias for `gelu_pytorch_tanh` to Gelu
alvarobartt Jul 25, 2025
bf99ea2
Restructure some GTE imports
alvarobartt Jul 25, 2025
3db49ca
Add Gemma3 implementation (WIP)
alvarobartt Jul 28, 2025
7979209
Sort imports and add Gemma3 loading
alvarobartt Jul 28, 2025
9eb10c8
Remove `use_bidirectional_attention` as sliding attention is always b…
alvarobartt Jul 28, 2025
8e1710b
Rename `Gemma3` to `TinyGemma`
alvarobartt Jul 28, 2025
fca9b1c
Use `query_pre_attn_scalar` for attention scaling
alvarobartt Jul 28, 2025
6ef6733
Add `rope_local_base_freq` base for local attention (`sliding_attenti…
alvarobartt Jul 28, 2025
1bdb605
Skip `DType::F16` for TinyGemma (fails on Transformers too)
alvarobartt Jul 28, 2025
6dc1b3f
Use `pad_token_id` instead of `eos_token_id`
alvarobartt Jul 28, 2025
247ed64
Add `TinyGemmaLayerNorm` (without bias)
alvarobartt Jul 28, 2025
3f0ec54
Use `broadcast_add` when applying attention bias masking
alvarobartt Jul 28, 2025
5bf4740
Fix `attention_type` match-clause (it was reversed)
alvarobartt Jul 28, 2025
e950f57
Replace `LayerNorm` with `RMSNorm` in decoder
alvarobartt Jul 28, 2025
ccf7e85
Fix `full_attention` (attend to all tokens)
alvarobartt Jul 28, 2025
6bbc181
Remove unused Gemma3 pooling methods (CLS, Last Token)
alvarobartt Jul 30, 2025
62a3d86
Use right-padding instead of left-padding
alvarobartt Jul 30, 2025
c8e31c3
Rename TinyGemma to Gemma3 (again)
alvarobartt Jul 30, 2025
235fffd
Add CUDA support for Gemma3 in FP32 (non-flash)
alvarobartt Jul 30, 2025
8ea46f5
Optimize attention by fusing KQV projection (see the sketch after this commit list)
alvarobartt Jul 30, 2025
6439582
Update `version` to 1.8.0 (#686)
alvarobartt Aug 5, 2025
4de5e18
Squash merge main into new-model-addition
alvarobartt Aug 5, 2025
1a985ad
Add `position_ids` as `ort::inputs!` for inference
alvarobartt Aug 6, 2025
f4b1172
Conditionally add `position_ids` if required for `ort`
alvarobartt Aug 7, 2025
527ce0a
Adjust HPU warmup: use dummy inputs with shape more close to real sce…
kaixuanliu Aug 8, 2025
ce48fb6
Add `extra_args` to `trufflehog` to exclude unverified results (#696)
alvarobartt Aug 11, 2025
6b415aa
Update GitHub templates & fix mentions to Text Embeddings Inference (…
alvarobartt Aug 11, 2025
b3f136a
Disable Flash Attention with `USE_FLASH_ATTENTION` (#692)
alvarobartt Aug 15, 2025
15b9ce5
Add support for `position_ids` and `past_key_values` in `OrtBackend` …
alvarobartt Aug 20, 2025
b1e48bd
HPU upgrade to Synapse 1.21.3 (#703)
kaixuanliu Aug 20, 2025
5e870c0
Upgrade to IPEX 2.8 (#702)
kaixuanliu Aug 21, 2025
57c811e
Parse `modules.json` to identify default `Dense` modules (#701)
alvarobartt Aug 21, 2025
35c0f77
Add `padding_side` and `pad_token_id` in `OrtBackend` (#705)
alvarobartt Sep 1, 2025
5dcc614
Merge `main` branch into `new-model-addition`
alvarobartt Sep 1, 2025
1b97bd9
Fix conflicts in `backends/candle/src/lib.rs`
alvarobartt Sep 1, 2025
b56a6c2
Merge branch 'main' into new-model-addition
alvarobartt Sep 1, 2025
47a2b05
Place `serde::Deserialize` import under `candle` feature
alvarobartt Sep 1, 2025
893ed3e
Handle `sentence_embedding` among `SessionOutputs`
alvarobartt Sep 1, 2025
725c241
Fix `Gemma3Model::load` for CUDA
alvarobartt Sep 2, 2025
674797b
Set default `DType::Float32` for Gemma3
alvarobartt Sep 2, 2025
dc89bc3
Merge branch 'main' into new-model-addition
alvarobartt Sep 3, 2025
1 change: 1 addition & 0 deletions backends/candle/src/layers/linear.rs
@@ -5,6 +5,7 @@ use serde::Deserialize;
 #[derive(Debug, Deserialize, PartialEq, Clone)]
 #[serde(rename_all = "lowercase")]
 pub enum HiddenAct {
+    #[serde(alias = "gelu_pytorch_tanh")]
     Gelu,
     Relu,
     Silu,
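With this alias, serde accepts the Transformers-style `gelu_pytorch_tanh` spelling alongside the canonical `gelu` when deserializing the activation from `config.json`. A minimal, self-contained sketch — the enum is trimmed to the variants visible in the hunk, and `serde`/`serde_json` as dependencies are an assumption:

use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq, Clone)]
#[serde(rename_all = "lowercase")]
pub enum HiddenAct {
    #[serde(alias = "gelu_pytorch_tanh")]
    Gelu,
    Relu,
    Silu,
}

fn main() {
    // Both the canonical name and the Transformers-style alias map to Gelu.
    let canonical: HiddenAct = serde_json::from_str(r#""gelu""#).unwrap();
    let aliased: HiddenAct = serde_json::from_str(r#""gelu_pytorch_tanh""#).unwrap();
    assert_eq!(canonical, HiddenAct::Gelu);
    assert_eq!(aliased, HiddenAct::Gelu);
}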
29 changes: 26 additions & 3 deletions backends/candle/src/lib.rs
@@ -23,9 +23,9 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
-    GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig,
-    Model, ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
-    Qwen3Config, Qwen3Model,
+    GTEConfig, GTEModel, Gemma3Config, Gemma3Model, JinaBertModel, JinaCodeBertModel, MPNetConfig,
+    MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel, NomicBertModel,
+    NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -95,6 +95,8 @@ enum Config {
     Camembert(BertConfig),
     #[serde(rename(deserialize = "distilbert"))]
     DistilBert(DistilBertConfig),
+    #[serde(rename(deserialize = "gemma3_text"))]
+    Gemma3(Gemma3Config),
     #[serde(alias = "new")]
     Gte(GTEConfig),
     #[serde(rename = "mpnet")]
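The `deserialize` rename keys the new variant off the `model_type` field of a checkpoint's `config.json`. A trimmed-down sketch of that routing — the `tag = "model_type"` attribute and the `Gemma3Config` fields are assumptions here, since neither appears in the hunk:

use serde::Deserialize;

// Hypothetical stand-in: the real Gemma3Config has many more fields.
#[derive(Debug, Deserialize)]
struct Gemma3Config {
    hidden_size: usize,
}

// Internally tagged enum: serde reads "model_type" and deserializes the
// remaining fields into the matching variant's config struct.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type")]
enum Config {
    #[serde(rename(deserialize = "gemma3_text"))]
    Gemma3(Gemma3Config),
}

fn main() {
    // The hidden_size value is illustrative only.
    let raw = r#"{ "model_type": "gemma3_text", "hidden_size": 1152 }"#;
    let Config::Gemma3(config) = serde_json::from_str::<Config>(raw).unwrap();
    println!("hidden_size = {}", config.hidden_size);
}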
@@ -263,6 +265,16 @@ impl CandleBackend {
                     DistilBertModel::load(vb, &config, model_type).s()?,
                 ))
             }
+            (Config::Gemma3(config), Device::Cpu | Device::Metal(_)) => {
+                if dtype != DType::F32 {
+                    Err(BackendError::Start(
+                        "Gemma3 is only supported in fp32 precision".to_string(),
+                    ))
+                } else {
+                    tracing::info!("Starting Gemma3 model on {:?}", device);
+                    Ok(Box::new(Gemma3Model::load(vb, &config, model_type).s()?))
+                }
+            }
             (Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting GTE model on {:?}", device);
                 Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
@@ -381,6 +393,17 @@ impl CandleBackend {
                 }
             }
+            #[cfg(feature = "cuda")]
+            (Config::Gemma3(config), Device::Cuda(_)) => {
+                if dtype != DType::F32 {
+                    Err(BackendError::Start(
+                        "Gemma3 is only supported in fp32 precision".to_string(),
+                    ))
+                } else {
+                    tracing::info!("Starting Gemma3 model on {:?}", device);
+                    Ok(Box::new(Gemma3Model::load(vb, &config, model_type).s()?))
+                }
+            }
             #[cfg(feature = "cuda")]
             (Config::Gte(config), Device::Cuda(_)) => {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
4 changes: 3 additions & 1 deletion backends/candle/src/models/flash_gte.rs
@@ -1,6 +1,8 @@
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
-use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
+use crate::models::gte::{GTEClassificationHead, GTEConfig, GTEMLP};
+use crate::models::{Model, PositionEmbeddingType};
+
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use candle_rotary::apply_rotary_inplace;