diff --git a/Cargo.lock b/Cargo.lock
index a42181cb..0bdc7848 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2115,6 +2115,16 @@ version = "0.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
 
+[[package]]
+name = "matrixmultiply"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
 [[package]]
 name = "memchr"
 version = "2.7.4"
@@ -2299,6 +2309,19 @@ dependencies = [
  "tempfile",
 ]
 
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
 [[package]]
 name = "nohash-hasher"
 version = "0.2.0"
@@ -2341,6 +2364,15 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
 
+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.19"
@@ -2740,6 +2772,33 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "ort"
+version = "2.0.0-rc.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14"
+dependencies = [
+ "half",
+ "js-sys",
+ "ndarray",
+ "ort-sys",
+ "thiserror",
+ "tracing",
+ "web-sys",
+]
+
+[[package]]
+name = "ort-sys"
+version = "2.0.0-rc.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe"
+dependencies = [
+ "flate2",
+ "sha2",
+ "tar",
+ "ureq",
+]
+
 [[package]]
 name = "overload"
 version = "0.1.1"
@@ -3163,6 +3222,12 @@ dependencies = [
  "bitflags 2.6.0",
 ]
 
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
 [[package]]
 name = "rayon"
 version = "1.10.0"
@@ -3935,8 +4000,11 @@ name = "text-embeddings-backend"
 version = "1.4.0"
 dependencies = [
  "clap",
+ "hf-hub",
+ "serde_json",
  "text-embeddings-backend-candle",
  "text-embeddings-backend-core",
+ "text-embeddings-backend-ort",
  "text-embeddings-backend-python",
  "tokio",
  "tracing",
@@ -3981,6 +4049,21 @@ dependencies = [
  "thiserror",
 ]
 
+[[package]]
+name = "text-embeddings-backend-ort"
+version = "1.4.0"
+dependencies = [
+ "anyhow",
+ "ndarray",
+ "nohash-hasher",
+ "ort",
+ "serde",
+ "serde_json",
+ "text-embeddings-backend-core",
+ "thiserror",
+ "tracing",
+]
+
 [[package]]
 name = "text-embeddings-backend-python"
 version = "1.4.0"
diff --git a/Cargo.toml b/Cargo.toml
index 9e1c26cd..a780aa27 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -2,6 +2,7 @@
 members = [
     "backends",
     "backends/candle",
+    "backends/ort",
     "backends/core",
     "backends/python",
     "backends/grpc-client",
diff --git a/Dockerfile b/Dockerfile
index 8c1368c0..a3612d87 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -28,22 +28,9 @@ ARG ACTIONS_CACHE_URL
 ARG ACTIONS_RUNTIME_TOKEN
 ARG SCCACHE_GHA_ENABLED
 
-RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
-| gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \
-  echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | \
-  tee /etc/apt/sources.list.d/oneAPI.list
-
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    intel-oneapi-mkl-devel=2024.0.0-49656 \
-    build-essential \
-    && rm -rf /var/lib/apt/lists/*
-
-RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
-    gcc -shared -fPIC -o libfakeintel.so fakeintel.c
-
 COPY --from=planner /usr/src/recipe.json recipe.json
 
-RUN cargo chef cook --release --features candle --features mkl-dynamic --no-default-features --recipe-path recipe.json && sccache -s
+RUN cargo chef cook --release --features ort --no-default-features --recipe-path recipe.json && sccache -s
 
 COPY backends backends
 COPY core core
@@ -53,7 +40,7 @@ COPY Cargo.lock ./
 
 FROM builder as http-builder
 
-RUN cargo build --release --bin text-embeddings-router -F candle -F mkl-dynamic -F http --no-default-features && sccache -s
+RUN cargo build --release --bin text-embeddings-router -F ort -F http --no-default-features && sccache -s
 
 FROM builder as grpc-builder
 
@@ -65,35 +52,18 @@ RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
 
 COPY proto proto
 
-RUN cargo build --release --bin text-embeddings-router -F grpc -F candle -F mkl-dynamic --no-default-features && sccache -s
+RUN cargo build --release --bin text-embeddings-router -F grpc -F ort --no-default-features && sccache -s
 
 FROM debian:bookworm-slim as base
 
 ENV HUGGINGFACE_HUB_CACHE=/data \
-    PORT=80 \
-    MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
-    RAYON_NUM_THREADS=8 \
-    LD_PRELOAD=/usr/local/libfakeintel.so \
-    LD_LIBRARY_PATH=/usr/local/lib
+    PORT=80
 
 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    libomp-dev \
     ca-certificates \
     libssl-dev \
-    curl \
     && rm -rf /var/lib/apt/lists/*
 
-# Copy a lot of the Intel shared objects because of the mkl_serv_intel_cpu_true patch...
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_lp64.so.2 /usr/local/lib/libmkl_intel_lp64.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_intel_thread.so.2 /usr/local/lib/libmkl_intel_thread.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_core.so.2 /usr/local/lib/libmkl_core.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_def.so.2 /usr/local/lib/libmkl_vml_def.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_def.so.2 /usr/local/lib/libmkl_def.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx2.so.2 /usr/local/lib/libmkl_vml_avx2.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_vml_avx512.so.2 /usr/local/lib/libmkl_vml_avx512.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx2.so.2 /usr/local/lib/libmkl_avx2.so.2
-COPY --from=builder /opt/intel/oneapi/mkl/latest/lib/intel64/libmkl_avx512.so.2 /usr/local/lib/libmkl_avx512.so.2
-COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so
 
 FROM base as grpc
 
diff --git a/Dockerfile-arm64 b/Dockerfile-arm64
deleted file mode 100644
index e170a534..00000000
--- a/Dockerfile-arm64
+++ /dev/null
@@ -1,92 +0,0 @@
-FROM lukemathwalker/cargo-chef:latest-rust-1.75-bookworm AS chef
-
-WORKDIR /usr/src
-
-ENV SCCACHE=0.5.4
-ENV RUSTC_WRAPPER=/usr/local/bin/sccache
-
-# Donwload and configure sccache
-RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
-    chmod +x /usr/local/bin/sccache
-
-FROM chef AS planner
-
-COPY backends backends
-COPY core core
-COPY router router
-COPY Cargo.toml ./
-COPY Cargo.lock ./
-
-RUN cargo chef prepare --recipe-path recipe.json
-
-FROM chef AS builder
-
-ARG GIT_SHA
-ARG DOCKER_LABEL
-
-# sccache specific variables
-ARG ACTIONS_CACHE_URL
-ARG ACTIONS_RUNTIME_TOKEN
-ARG SCCACHE_GHA_ENABLED
-
-RUN echo "int mkl_serv_intel_cpu_true() {return 1;}" > fakeintel.c && \
-    gcc -shared -fPIC -o libfakeintel.so fakeintel.c
-
-COPY --from=planner /usr/src/recipe.json recipe.json
-
-RUN cargo chef cook --release --features candle --no-default-features --recipe-path recipe.json && sccache -s
-
-COPY backends backends
-COPY core core
-COPY router router
-COPY Cargo.toml ./
-COPY Cargo.lock ./
-
-FROM builder as http-builder
-
-RUN cargo build --release --bin text-embeddings-router -F candle -F http --no-default-features && sccache -s
-
-FROM builder as grpc-builder
-
-RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
-    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
-    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
-    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
-    rm -f $PROTOC_ZIP
-
-COPY proto proto
-
-RUN cargo build --release --bin text-embeddings-router -F grpc -F candle --no-default-features && sccache -s
-
-FROM debian:bookworm-slim as base
-
-COPY --from=builder /usr/src/libfakeintel.so /usr/local/libfakeintel.so
-
-ENV HUGGINGFACE_HUB_CACHE=/data \
-    PORT=80 \
-    MKL_ENABLE_INSTRUCTIONS=AVX512_E4 \
-    RAYON_NUM_THREADS=8 \
-    LD_PRELOAD=/usr/local/libfakeintel.so \
-    LD_LIBRARY_PATH=/usr/local/lib
-
-RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
-    libomp-dev \
-    ca-certificates \
-    libssl-dev \
-    curl \
-    && rm -rf /var/lib/apt/lists/*
-
-
-FROM base as grpc
-
-COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
-
-ENTRYPOINT ["text-embeddings-router"]
-CMD ["--json-output"]
-
-FROM base
-
-COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
-
-ENTRYPOINT ["text-embeddings-router"]
-CMD ["--json-output"]
diff --git a/README.md b/README.md
index 288225b7..8cbc2872 100644
--- a/README.md
+++ b/README.md
@@ -72,9 +72,9 @@ Below are some examples of the currently supported models:
 | MTEB Rank | Model Size          | Model Type  | Model ID                                                                                           |
 |-----------|---------------------|-------------|----------------------------------------------------------------------------------------------------|
 | 1         | 7B (Very Expensive) | Mistral     | [Salesforce/SFR-Embedding-2_R](https://hf.co/Salesforce/SFR-Embedding-2_R)                         |
-| 2         | 7B (Very Expensive) | Qwen2       | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct)               |
-| 9         | 1.5B (Expensive)    | Qwen2       | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)           |
-| 15        | 0.4B                | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](Alibaba-NLP/gte-large-en-v1.5)                                     |
+| 2         | 7B (Very Expensive) | Qwen2       | [Alibaba-NLP/gte-Qwen2-7B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-7B-instruct)                |
+| 9         | 1.5B (Expensive)    | Qwen2       | [Alibaba-NLP/gte-Qwen2-1.5B-instruct](https://hf.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)            |
+| 15        | 0.4B                | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5)                        |
 | 20        | 0.3B                | Bert        | [WhereIsAI/UAE-Large-V1](https://hf.co/WhereIsAI/UAE-Large-V1)                                     |
 | 24        | 0.5B                | XLM-RoBERTa | [intfloat/multilingual-e5-large-instruct](https://hf.co/intfloat/multilingual-e5-large-instruct)   |
 | N/A       | 0.1B                | NomicBert   | [nomic-ai/nomic-embed-text-v1](https://hf.co/nomic-ai/nomic-embed-text-v1)                         |
@@ -568,7 +568,7 @@ supported via Docker. As such inference will be CPU bound and most likely pretty
 M1/M2 ARM CPU.
 
 ```
-docker build . -f Dockerfile-arm64 --platform=linux/arm64
+docker build . -f Dockerfile --platform=linux/arm64
 ```
 
 ## Examples
diff --git a/backends/Cargo.toml b/backends/Cargo.toml
index d5659b18..f29283ad 100644
--- a/backends/Cargo.toml
+++ b/backends/Cargo.toml
@@ -7,15 +7,19 @@ homepage.workspace = true
 
 [dependencies]
 clap = { workspace = true, optional = true }
+hf-hub = { workspace = true }
+serde_json = { workspace = true }
 text-embeddings-backend-core = { path = "core" }
 text-embeddings-backend-python = { path = "python", optional = true }
 text-embeddings-backend-candle = { path = "candle", optional = true }
+text-embeddings-backend-ort = { path = "ort", optional = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
 
 [features]
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]
 python = ["dep:text-embeddings-backend-python"]
+ort = ["dep:text-embeddings-backend-ort"]
 candle = ["dep:text-embeddings-backend-candle"]
 cuda = ["text-embeddings-backend-candle?/cuda"]
 metal = ["text-embeddings-backend-candle?/metal"]
diff --git a/backends/ort/Cargo.toml b/backends/ort/Cargo.toml
new file mode 100644
index 00000000..9cbde87e
--- /dev/null
+++ b/backends/ort/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "text-embeddings-backend-ort"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[dependencies]
+anyhow = { workspace = true }
+nohash-hasher = { workspace = true }
+ndarray = "0.15.6"
+ort = { version = "2.0.0-rc.2", default-features = false, features = ["download-binaries", "half", "onednn", "ndarray"] }
+text-embeddings-backend-core = { path = "../core" }
+tracing = { workspace = true }
+thiserror = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
diff --git a/backends/ort/src/lib.rs b/backends/ort/src/lib.rs
new file mode 100644
index 00000000..6951d3b4
--- /dev/null
+++ b/backends/ort/src/lib.rs
@@ -0,0 +1,395 @@
+use ndarray::{s, Axis};
+use nohash_hasher::BuildNoHashHasher;
+use ort::{GraphOptimizationLevel, Session};
+use std::collections::HashMap;
+use std::ops::{Div, Mul};
+use std::path::PathBuf;
+use text_embeddings_backend_core::{
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
+};
+
+pub struct OrtBackend {
+    session: Session,
+    pool: Pool,
+    type_id_name: Option<String>,
+}
+
+impl OrtBackend {
+    pub fn new(
+        model_path: PathBuf,
+        dtype: String,
+        model_type: ModelType,
+    ) -> Result<Self, BackendError> {
+        // Check dtype
+        if &dtype == "float32" {
+        } else {
+            return Err(BackendError::Start(format!(
+                "DType {dtype} is not supported"
+            )));
+        };
+
+        // Check model type
+        let pool = match model_type {
+            ModelType::Classifier => Pool::Cls,
+            ModelType::Embedding(pool) => match pool {
+                Pool::Splade | Pool::LastToken => {
+                    return Err(BackendError::Start(format!(
+                        "Pooling {pool} is not supported for this backend. Use `candle` backend instead."
+                    )));
+                }
+                pool => pool,
+            },
+        };
+
+        // Get model path
+        let onnx_path = {
+            let default_path = model_path.join("model.onnx");
+            match default_path.exists() {
+                true => default_path,
+                false => model_path.join("onnx/model.onnx"),
+            }
+        };
+
+        // Start onnx session
+        let session = Session::builder()
+            .s()?
+            .with_optimization_level(GraphOptimizationLevel::Level3)
+            .s()?
+            .commit_from_file(onnx_path)
+            .s()?;
+
+        // Check if the model requires type tokens
+        let mut type_id_name = None;
+        for input in &session.inputs {
+            if &input.name == "token_type_ids" || &input.name == "input_type" {
+                type_id_name = Some(input.name.clone());
+                break;
+            }
+        }
+
+        Ok(Self {
+            session,
+            pool,
+            type_id_name,
+        })
+    }
+}
+
+impl Backend for OrtBackend {
+    fn max_batch_size(&self) -> Option<usize> {
+        Some(8)
+    }
+
+    fn health(&self) -> Result<(), BackendError> {
+        Ok(())
+    }
+
+    fn is_padded(&self) -> bool {
+        true
+    }
+
+    fn embed(&self, batch: Batch) -> Result<Embeddings, BackendError> {
+        let batch_size = batch.len();
+        let max_length = batch.max_length as usize;
+
+        // Whether at least one of the requests in the batch is padded
+        let mut masking = false;
+
+        let (input_ids, type_ids, input_lengths, attention_mask) = {
+            let elems = batch_size * max_length;
+
+            if batch_size > 1 {
+                // Prepare padded batch
+                let mut input_ids = Vec::with_capacity(elems);
+                let mut type_ids = Vec::with_capacity(elems);
+                let mut attention_mask = Vec::with_capacity(elems);
+                let mut input_lengths = Vec::with_capacity(batch_size);
+
+                for i in 0..batch_size {
+                    let start = batch.cumulative_seq_lengths[i] as usize;
+                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                    let seq_length = (end - start) as u32;
+                    input_lengths.push(seq_length as f32);
+
+                    // Copy values
+                    for j in start..end {
+                        input_ids.push(batch.input_ids[j] as i64);
+                        type_ids.push(batch.token_type_ids[j] as i64);
+                        attention_mask.push(1_i64);
+                    }
+
+                    // Add padding if needed
+                    let padding = batch.max_length - seq_length;
+                    if padding > 0 {
+                        // Set bool to use attention mask
+                        masking = true;
+                        for _ in 0..padding {
+                            input_ids.push(0);
+                            type_ids.push(0);
+                            attention_mask.push(0_i64);
+                        }
+                    }
+                }
+                (input_ids, type_ids, input_lengths, attention_mask)
+            } else {
+                let attention_mask = vec![1_i64; elems];
+
+                (
+                    batch.input_ids.into_iter().map(|v| v as i64).collect(),
+                    batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
+                    vec![batch.max_length as f32],
+                    attention_mask,
+                )
+            }
+        };
+
+        // Create ndarrays
+        let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
+        let attention_mask =
+            ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;
+        let input_lengths = ndarray::Array1::from_vec(input_lengths);
+
+        // Create onnx inputs
+        let inputs = match self.type_id_name.as_ref() {
+            Some(type_id_name) => {
+                // Add type ids to inputs
+                let type_ids =
+                    ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?;
+                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()?
+            }
+            None => {
+                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()]
+                    .e()?
+            }
+        };
+
+        // Run model
+        let outputs = self.session.run(inputs).e()?;
+        // Get last_hidden_state ndarray
+
+        let outputs = outputs
+            .get("last_hidden_state")
+            .or(outputs.get("token_embeddings"))
+            .ok_or(BackendError::Inference(format!(
+                "Unknown output keys: {:?}",
+                self.session.outputs
+            )))?
+            .try_extract_tensor::<f32>()
+            .e()?
+            .to_owned();
+
+        // Final embeddings struct
+        let mut embeddings =
+            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
+
+        let has_pooling_requests = !batch.pooled_indices.is_empty();
+        let has_raw_requests = !batch.raw_indices.is_empty();
+
+        if has_pooling_requests {
+            let mut outputs = outputs.clone();
+
+            // Only use pooled_indices if at least one member of the batch asks for raw embeddings
+            let indices = if has_raw_requests {
+                let indices: Vec<usize> =
+                    batch.pooled_indices.iter().map(|v| *v as usize).collect();
+
+                // Select values in the batch
+                outputs = outputs.select(Axis(0), &indices);
+                Some(indices)
+            } else {
+                None
+            };
+
+            let pooled_embeddings = match self.pool {
+                // CLS pooling
+                Pool::Cls => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
+                // Last token pooling is not supported for this model
+                Pool::LastToken => unreachable!(),
+                // Mean pooling
+                Pool::Mean => {
+                    if masking {
+                        let mut attention_mask = attention_mask;
+
+                        if let Some(indices) = indices {
+                            // Select values in the batch
+                            attention_mask = attention_mask.select(Axis(0), &indices);
+                        };
+
+                        // Cast and reshape
+                        let attention_mask = attention_mask.mapv(|x| x as f32).insert_axis(Axis(2));
+
+                        // Mask padded values
+                        outputs = outputs.mul(attention_mask);
+                        outputs.sum_axis(Axis(1)).div(input_lengths)
+                    } else {
+                        outputs.mean_axis(Axis(1)).unwrap()
+                    }
+                }
+                Pool::Splade => unreachable!(),
+            };
+
+            for (i, e) in batch
+                .pooled_indices
+                .into_iter()
+                .zip(pooled_embeddings.rows())
+            {
+                embeddings.insert(i as usize, Embedding::Pooled(e.to_vec()));
+            }
+        };
+
+        if has_raw_requests {
+            // Reshape outputs
+            let s = outputs.shape().to_vec();
+            let outputs = outputs.into_shape((s[0] * s[1], s[2])).e()?;
+
+            // We need to remove the padding tokens only if batch_size > 1 and there are some
+            // members of the batch that require pooling
+            // or if batch_size > 1 and the members of the batch have different lengths
+            let raw_embeddings = if (masking || has_pooling_requests) && batch_size > 1 {
+                let mut final_indices: Vec<usize> = Vec::with_capacity(batch_size * max_length);
+
+                for i in batch.raw_indices.iter() {
+                    let start = i * batch.max_length;
+                    let i = *i as usize;
+                    let length =
+                        batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i];
+
+                    for j in start..start + length {
+                        // Add indices for the tokens of this specific member of the batch
+                        final_indices.push(j as usize);
+                    }
+                }
+
+                // Select the tokens with final indices
+                outputs.select(Axis(0), &final_indices)
+            } else {
+                outputs
+            };
+
+            // Used for indexing in the raw_embeddings tensor
+            let input_lengths: Vec<usize> = (0..batch_size)
+                .map(|i| {
+                    (batch.cumulative_seq_lengths[i + 1] - batch.cumulative_seq_lengths[i]) as usize
+                })
+                .collect();
+
+            let mut cumulative_length = 0;
+            for i in batch.raw_indices.into_iter() {
+                let length = input_lengths[i as usize];
+                let e = raw_embeddings.slice(s![cumulative_length..cumulative_length + length, ..]);
+                let e = e.rows().into_iter().map(|v| v.to_vec()).collect();
+
+                embeddings.insert(i as usize, Embedding::All(e));
+                cumulative_length += length;
+            }
+        }
+
+        Ok(embeddings)
+    }
+
+    fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
+        let batch_size = batch.len();
+        let max_length = batch.max_length as usize;
+
+        let (input_ids, type_ids, attention_mask) = {
+            let elems = batch_size * max_length;
+
+            if batch_size > 1 {
+                // Prepare padded batch
+                let mut input_ids = Vec::with_capacity(elems);
+                let mut type_ids = Vec::with_capacity(elems);
+                let mut attention_mask = Vec::with_capacity(elems);
+
+                for i in 0..batch_size {
+                    let start = batch.cumulative_seq_lengths[i] as usize;
+                    let end = batch.cumulative_seq_lengths[i + 1] as usize;
+                    let seq_length = (end - start) as u32;
+
+                    // Copy values
+                    for j in start..end {
+                        input_ids.push(batch.input_ids[j] as i64);
+                        type_ids.push(batch.token_type_ids[j] as i64);
+                        attention_mask.push(1_i64);
+                    }
+
+                    // Add padding if needed
+                    let padding = batch.max_length - seq_length;
+                    if padding > 0 {
+                        for _ in 0..padding {
+                            input_ids.push(0);
+                            type_ids.push(0);
+                            attention_mask.push(0_i64);
+                        }
+                    }
+                }
+                (input_ids, type_ids, attention_mask)
+            } else {
+                let attention_mask = vec![1_i64; elems];
+
+                (
+                    batch.input_ids.into_iter().map(|v| v as i64).collect(),
+                    batch.token_type_ids.into_iter().map(|v| v as i64).collect(),
+                    attention_mask,
+                )
+            }
+        };
+
+        // Create ndarrays
+        let input_ids = ndarray::Array2::from_shape_vec((batch_size, max_length), input_ids).e()?;
+        let attention_mask =
+            ndarray::Array2::from_shape_vec((batch_size, max_length), attention_mask).e()?;
+
+        // Create onnx inputs
+        let inputs = match self.type_id_name.as_ref() {
+            Some(type_id_name) => {
+                // Add type ids to inputs
+                let type_ids =
+                    ndarray::Array2::from_shape_vec((batch_size, max_length), type_ids).e()?;
+                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone(), type_id_name => type_ids].e()?
+            }
+            None => {
+                ort::inputs!["input_ids" => input_ids, "attention_mask" => attention_mask.clone()]
+                    .e()?
+            }
+        };
+
+        // Run model
+        let outputs = self.session.run(inputs).e()?;
+        // Get logits ndarray
+        let outputs = outputs["logits"]
+            .try_extract_tensor::<f32>()
+            .e()?
+            .to_owned();
+
+        let mut predictions =
+            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
+        for (i, r) in outputs.rows().into_iter().enumerate() {
+            predictions.insert(i, r.to_vec());
+        }
+
+        Ok(predictions)
+    }
+}
+
+pub trait WrapErr<O> {
+    fn s(self) -> Result<O, BackendError>;
+    fn e(self) -> Result<O, BackendError>;
+}
+
+impl<O> WrapErr<O> for Result<O, ort::Error> {
+    fn s(self) -> Result<O, BackendError> {
+        self.map_err(|e| BackendError::Start(e.to_string()))
+    }
+    fn e(self) -> Result<O, BackendError> {
+        self.map_err(|e| BackendError::Inference(e.to_string()))
+    }
+}
+
+impl<O> WrapErr<O> for Result<O, ndarray::ShapeError> {
+    fn s(self) -> Result<O, BackendError> {
+        self.map_err(|e| BackendError::Start(e.to_string()))
+    }
+    fn e(self) -> Result<O, BackendError> {
+        self.map_err(|e| BackendError::Inference(e.to_string()))
+    }
+}
diff --git a/backends/src/dtype.rs b/backends/src/dtype.rs
index d2c896ce..53193ef7 100644
--- a/backends/src/dtype.rs
+++ b/backends/src/dtype.rs
@@ -12,11 +12,8 @@ pub enum DType {
         all(feature = "candle", not(feature = "accelerate"))
     ))]
     Float16,
-    // Float32 is not available on candle cuda
-    #[cfg(any(feature = "python", feature = "candle"))]
+    #[cfg(any(feature = "python", feature = "candle", feature = "ort"))]
     Float32,
-    // #[cfg(feature = "candle")]
-    // Q6K,
 }
 
 impl fmt::Display for DType {
@@ -28,11 +25,32 @@ impl fmt::Display for DType {
                 all(feature = "candle", not(feature = "accelerate"))
             ))]
             DType::Float16 => write!(f, "float16"),
-            // Float32 is not available on candle cuda
-            #[cfg(any(feature = "python", feature = "candle"))]
+            #[cfg(any(feature = "python", feature = "candle", feature = "ort"))]
             DType::Float32 => write!(f, "float32"),
-            // #[cfg(feature = "candle")]
-            // DType::Q6K => write!(f, "q6k"),
+        }
+    }
+}
+
+#[allow(clippy::derivable_impls)]
+impl Default for DType {
+    fn default() -> Self {
+        #[cfg(any(
+            feature = "accelerate",
+            feature = "mkl",
+            feature = "mkl-dynamic",
+            feature = "ort"
+        ))]
+        {
+            DType::Float32
+        }
+        #[cfg(not(any(
+            feature = "accelerate",
+            feature = "mkl",
+            feature = "mkl-dynamic",
+            feature = "ort"
+        )))]
+        {
+            DType::Float16
+        }
     }
 }
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index 9b5d1762..2ee63279 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -1,5 +1,6 @@
 mod dtype;
 
+use hf_hub::api::tokio::{ApiError, ApiRepo};
 use std::cmp::{max, min};
 use std::path::PathBuf;
 use std::sync::Arc;
@@ -17,6 +18,9 @@ pub use text_embeddings_backend_core::{
 #[cfg(feature = "candle")]
 use text_embeddings_backend_candle::CandleBackend;
 
+#[cfg(feature = "ort")]
+use text_embeddings_backend_ort::OrtBackend;
+
 #[cfg(feature = "python")]
 use text_embeddings_backend_python::PythonBackend;
 
@@ -222,6 +226,13 @@ fn init_backend(
                 .expect("Python Backend management thread failed")?,
             ));
         }
+    } else if cfg!(feature = "ort") {
+        #[cfg(feature = "ort")]
+        return Ok(Box::new(OrtBackend::new(
+            model_path,
+            dtype.to_string(),
+            model_type,
+        )?));
     }
     Err(BackendError::NoBackend)
 }
@@ -286,3 +297,70 @@ enum BackendCommand {
         oneshot::Sender<Result<(Predictions, Duration), BackendError>>,
     ),
 }
+
+pub async fn download_weights(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
+    let model_files = if cfg!(feature = "python") || cfg!(feature = "candle") {
+        match download_safetensors(api).await {
+            Ok(p) => p,
+            Err(_) => {
+                tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
+                tracing::info!("Downloading `pytorch_model.bin`");
+                let p = api.get("pytorch_model.bin").await?;
+                vec![p]
+            }
+        }
+    } else if cfg!(feature = "ort") {
+        tracing::info!("Downloading `model.onnx`");
+        match api.get("model.onnx").await {
+            Ok(p) => vec![p],
+            Err(err) => {
+                tracing::warn!("Could not download `model.onnx`: {err}");
+                tracing::info!("Downloading `onnx/model.onnx`");
+                let p = api.get("onnx/model.onnx").await?;
+                vec![p.parent().unwrap().to_path_buf()]
+            }
+        }
+    } else {
+        unreachable!()
+    };
+    Ok(model_files)
+}
+
+async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
+    // Single file
+    tracing::info!("Downloading `model.safetensors`");
+    match api.get("model.safetensors").await {
+        Ok(p) => return Ok(vec![p]),
+        Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err),
+    };
+
+    // Sharded weights
+    // Download and parse index file
+    tracing::info!("Downloading `model.safetensors.index.json`");
+    let index_file = api.get("model.safetensors.index.json").await?;
+    let index_file_string: String =
+        std::fs::read_to_string(index_file).expect("model.safetensors.index.json is corrupted");
+    let json: serde_json::Value = serde_json::from_str(&index_file_string)
+        .expect("model.safetensors.index.json is corrupted");
+
+    let weight_map = match json.get("weight_map") {
+        Some(serde_json::Value::Object(map)) => map,
+        _ => panic!("model.safetensors.index.json is corrupted"),
+    };
+
+    let mut safetensors_filenames = std::collections::HashSet::new();
+    for value in weight_map.values() {
+        if let Some(file) = value.as_str() {
+            safetensors_filenames.insert(file.to_string());
+        }
+    }
+
+    // Download weight files
+    let mut safetensors_files = Vec::new();
+    for n in safetensors_filenames {
+        tracing::info!("Downloading `{}`", n);
+        safetensors_files.push(api.get(&n).await?);
+    }
+
+    Ok(safetensors_files)
+}
diff --git a/core/src/download.rs b/core/src/download.rs
index 08220752..02ea484d 100644
--- a/core/src/download.rs
+++ b/core/src/download.rs
@@ -1,5 +1,6 @@
 use hf_hub::api::tokio::{ApiError, ApiRepo};
 use std::path::PathBuf;
+use text_embeddings_backend::download_weights;
 use tracing::instrument;
 
 // Old classes used other config names than 'sentence_bert_config.json'
@@ -25,15 +26,7 @@ pub async fn download_artifacts(api: &ApiRepo) -> Result<PathBuf, ApiError> {
     tracing::info!("Downloading `tokenizer.json`");
     api.get("tokenizer.json").await?;
 
-    let model_files = match download_safetensors(api).await {
-        Ok(p) => p,
-        Err(_) => {
-            tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
-            tracing::info!("Downloading `pytorch_model.bin`");
-            let p = api.get("pytorch_model.bin").await?;
-            vec![p]
-        }
-    };
+    let model_files = download_weights(api).await?;
     let model_root = model_files[0].parent().unwrap().to_path_buf();
 
     tracing::info!("Model artifacts downloaded in {:?}", start.elapsed());
@@ -47,45 +40,6 @@ pub async fn download_pool_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
     Ok(pool_config_path)
 }
 
-async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
-    // Single file
-    tracing::info!("Downloading `model.safetensors`");
-    match api.get("model.safetensors").await {
-        Ok(p) => return Ok(vec![p]),
-        Err(err) => tracing::warn!("Could not download `model.safetensors`: {}", err),
-    };
-
-    // Sharded weights
-    // Download and parse index file
-    tracing::info!("Downloading `model.safetensors.index.json`");
-    let index_file = api.get("model.safetensors.index.json").await?;
-    let index_file_string: String =
-        std::fs::read_to_string(index_file).expect("model.safetensors.index.json is corrupted");
-    let json: serde_json::Value = serde_json::from_str(&index_file_string)
-        .expect("model.safetensors.index.json is corrupted");
-
-    let weight_map = match json.get("weight_map") {
-        Some(serde_json::Value::Object(map)) => map,
-        _ => panic!("model.safetensors.index.json is corrupted"),
-    };
-
-    let mut safetensors_filenames = std::collections::HashSet::new();
-    for value in weight_map.values() {
-        if let Some(file) = value.as_str() {
-            safetensors_filenames.insert(file.to_string());
-        }
-    }
-
-    // Download weight files
-    let mut safetensors_files = Vec::new();
-    for n in safetensors_filenames {
-        tracing::info!("Downloading `{}`", n);
-        safetensors_files.push(api.get(&n).await?);
-    }
-
-    Ok(safetensors_files)
-}
-
 #[instrument(skip_all)]
 pub async fn download_st_config(api: &ApiRepo) -> Result<PathBuf, ApiError> {
     // Try default path
diff --git a/router/Cargo.toml b/router/Cargo.toml
index 44e56015..bc42265b 100644
--- a/router/Cargo.toml
+++ b/router/Cargo.toml
@@ -81,6 +81,7 @@ mkl = ["text-embeddings-backend/mkl"]
 mkl-dynamic = ["text-embeddings-backend/mkl-dynamic"]
 accelerate = ["text-embeddings-backend/accelerate"]
 python = ["text-embeddings-backend/python"]
+ort = ["text-embeddings-backend/ort"]
 candle = ["text-embeddings-backend/candle"]
 candle-cuda = ["candle", "text-embeddings-backend/flash-attn"]
 candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1"]
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 57dcecb6..86a0f884 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -235,16 +235,7 @@ pub async fn run(
     );
 
     // Get dtype
-    let dtype = dtype.unwrap_or({
-        #[cfg(any(feature = "accelerate", feature = "mkl", feature = "mkl-dynamic"))]
-        {
-            DType::Float32
-        }
-        #[cfg(not(any(feature = "accelerate", feature = "mkl", feature = "mkl-dynamic")))]
-        {
-            DType::Float16
-        }
-    });
+    let dtype = dtype.unwrap_or_default();
 
     // Create backend
     tracing::info!("Starting model backend");
diff --git a/router/src/main.rs b/router/src/main.rs
index e9cf084f..a3c8d042 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -14,10 +14,10 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
 struct Args {
     /// The name of the model to load.
     /// Can be a MODEL_ID as listed on <https://hf.co/models> like
-    /// `Alibaba-NLP/gte-base-en-v1.5`.
+    /// `BAAI/bge-large-en-v1.5`.
     /// Or it can be a local directory containing the necessary files
     /// as saved by `save_pretrained(...)` methods of transformers
-    #[clap(default_value = "Alibaba-NLP/gte-base-en-v1.5", long, env)]
+    #[clap(default_value = "BAAI/bge-large-en-v1.5", long, env)]
     #[redact(partial)]
     model_id: String,
 
diff --git a/router/tests/test_http_predict.rs b/router/tests/test_http_predict.rs
index 6869f721..850e0874 100644
--- a/router/tests/test_http_predict.rs
+++ b/router/tests/test_http_predict.rs
@@ -16,12 +16,13 @@ pub struct SnapshotPrediction {
 #[tokio::test]
 #[cfg(feature = "http")]
 async fn test_predict() -> Result<()> {
-    start_server(
-        "SamLowe/roberta-base-go_emotions".to_string(),
-        None,
-        DType::Float32,
-    )
-    .await?;
+    let model_id = if cfg!(feature = "ort") {
+        "SamLowe/roberta-base-go_emotions-onnx"
+    } else {
+        "SamLowe/roberta-base-go_emotions"
+    };
+
+    start_server(model_id.to_string(), None, DType::Float32).await?;
 
     let request = json!({
         "inputs": "test"
diff --git a/router/tests/test_http_rerank.rs b/router/tests/test_http_rerank.rs
index 0c5b6079..920eca4d 100644
--- a/router/tests/test_http_rerank.rs
+++ b/router/tests/test_http_rerank.rs
@@ -17,12 +17,7 @@ pub struct SnapshotRank {
 #[tokio::test]
 #[cfg(feature = "http")]
 async fn test_rerank() -> Result<()> {
-    start_server(
-        "BAAI/bge-reranker-base".to_string(),
-        Some("refs/pr/5".to_string()),
-        DType::Float32,
-    )
-    .await?;
+    start_server("BAAI/bge-reranker-base".to_string(), None, DType::Float32).await?;
 
     let request = json!({
         "query": "test",