1 change: 1 addition & 0 deletions router/Cargo.toml
@@ -78,3 +78,4 @@ candle-cuda = ["candle", "text-embeddings-backend/flash-attn"]
candle-cuda-turing = ["candle", "text-embeddings-backend/flash-attn-v1"]
candle-cuda-volta = ["candle", "text-embeddings-backend/cuda"]
static-linking = ["text-embeddings-backend/static-linking"]
google = []
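
Note: the new `google` feature is an empty flag, so it only gates conditional compilation (`#[cfg(feature = "google")]`) in the router. Presumably it is enabled at build time with something like `cargo build --features google` from the `router/` directory; the exact invocation depends on how the release image is built, which this diff does not show.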
144 changes: 140 additions & 4 deletions router/src/http/server.rs
@@ -3,7 +3,7 @@ use crate::http::types::{
    EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse, Input, OpenAICompatEmbedding,
    OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
    PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
-   Sequence, SimpleToken, TokenizeRequest, TokenizeResponse,
+   Sequence, SimpleToken, TokenizeRequest, TokenizeResponse, VertexRequest,
};
use crate::{
    shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -992,6 +992,101 @@ async fn tokenize(
    Ok(Json(TokenizeResponse(tokens)))
}

/// Generate embeddings from a Vertex request
#[utoipa::path(
    post,
    tag = "Text Embeddings Inference",
    path = "/vertex",
    request_body = VertexRequest,
    responses(
        (status = 200, description = "Embeddings", body = EmbedAllResponse),
        (status = 424, description = "Embedding Error", body = ErrorResponse,
            example = json!({"error": "Inference failed", "error_type": "backend"})),
        (status = 429, description = "Model is overloaded", body = ErrorResponse,
            example = json!({"error": "Model is overloaded", "error_type": "overloaded"})),
        (status = 422, description = "Tokenization error", body = ErrorResponse,
            example = json!({"error": "Tokenization error", "error_type": "tokenizer"})),
        (status = 413, description = "Batch size error", body = ErrorResponse,
            example = json!({"error": "Batch size error", "error_type": "validation"})),
    )
)]
#[instrument(skip_all)]
async fn vertex_compatibility(
    infer: Extension<Infer>,
    info: Extension<Info>,
    Json(req): Json<VertexRequest>,
) -> Result<(HeaderMap, Json<EmbedAllResponse>), (StatusCode, Json<ErrorResponse>)> {
    let span = tracing::Span::current();
    let start_time = Instant::now();

    let batch_size = req.instances.len();
    if batch_size > info.max_client_batch_size {
        let message = format!(
            "batch size {batch_size} > maximum allowed batch size {}",
            info.max_client_batch_size
        );
        tracing::error!("{message}");
        let err = ErrorResponse {
            error: message,
            error_type: ErrorType::Validation,
        };
        metrics::increment_counter!("te_request_failure", "err" => "batch_size");
        Err(err)?;
    }

    let mut futures = Vec::with_capacity(batch_size);
    let mut compute_chars = 0;

    for instance in req.instances.iter() {
        let input = instance.inputs.clone();
        compute_chars += input.chars().count();

        let local_infer = infer.clone();
        futures.push(async move {
            let permit = local_infer.acquire_permit().await;
            local_infer
                .embed_all(input, instance.truncate, permit)
                .await
        });
    }
    let results = join_all(futures)
        .await
        .into_iter()
        .collect::<Result<Vec<AllEmbeddingsInferResponse>, TextEmbeddingsError>>()
        .map_err(ErrorResponse::from)?;

    let mut embeddings = Vec::with_capacity(batch_size);
    let mut total_tokenization_time = 0;
    let mut total_queue_time = 0;
    let mut total_inference_time = 0;
    let mut total_compute_tokens = 0;

    for r in results {
        total_tokenization_time += r.metadata.tokenization.as_nanos() as u64;
        total_queue_time += r.metadata.queue.as_nanos() as u64;
        total_inference_time += r.metadata.inference.as_nanos() as u64;
        total_compute_tokens += r.metadata.prompt_tokens;
        embeddings.push(r.results);
    }
    let batch_size = batch_size as u64;

    let response = EmbedAllResponse(embeddings);
    let metadata = ResponseMetadata::new(
        compute_chars,
        total_compute_tokens,
        start_time,
        Duration::from_nanos(total_tokenization_time / batch_size),
        Duration::from_nanos(total_queue_time / batch_size),
        Duration::from_nanos(total_inference_time / batch_size),
    );

    metadata.record_span(&span);
    metadata.record_metrics();
    tracing::info!("Success");

    Ok((HeaderMap::from(metadata), Json(response)))
}
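
As a reading aid (not part of the diff): a minimal sketch of the request body this handler deserializes, based on the `VertexRequest` and `VertexInstance` definitions added in `router/src/http/types.rs` below. It assumes `serde_json` is available; the string values are illustrative.

```rust
use serde_json::json;

fn main() {
    // Mirrors VertexRequest { instances: Vec<VertexInstance> }.
    // `truncate` may be omitted because of #[serde(default)]; it then defaults to false.
    let body = json!({
        "instances": [
            { "inputs": "What is Deep Learning?", "truncate": false },
            { "inputs": "What is a Transformer?" }
        ]
    });
    println!("{body}");
}
```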

/// Prometheus metrics scrape endpoint
#[utoipa::path(
    get,
@@ -1089,9 +1184,32 @@ pub async fn run(
        .allow_headers([http::header::CONTENT_TYPE])
        .allow_origin(allow_origin);

    // Define VertextApiDoc conditionally only if the "google" feature is enabled
    let doc = {
        // avoid `mut` if possible
        #[cfg(feature = "google")]
        {
            use crate::http::types::VertexInstance;

            #[derive(OpenApi)]
            #[openapi(
                paths(vertex_compatibility),
                components(schemas(VertexInstance, VertexRequest))
            )]
            struct VertextApiDoc;

            // limiting mutability to the smallest scope necessary
            let mut doc = ApiDoc::openapi();
            doc.merge(VertextApiDoc::openapi());
            doc
        }
        #[cfg(not(feature = "google"))]
        ApiDoc::openapi()
    };
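
Aside: `utoipa`'s `OpenApi::merge` folds the Vertex-specific paths and schemas into the base `ApiDoc`, so the Swagger UI at `/docs` only advertises `/vertex` in builds with the `google` feature enabled.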

    // Create router
-   let app = Router::new()
-       .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
+   let base_routes = Router::new()
+       .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", doc))
        // Base routes
        .route("/info", get(get_model_info))
        .route("/embed", post(embed))
@@ -1101,6 +1219,8 @@
        .route("/tokenize", post(tokenize))
        // OpenAI compat route
        .route("/embeddings", post(openai_embed))
+       // Vertex compat route
+       .route("/vertex", post(vertex_compatibility))
Comment on lines +1222 to +1223

Reviewer (Contributor): This should not be needed, since we only need the endpoint when the env var is set.

Author (Collaborator): In the TGI case we also always include the endpoint and only conditionally add the api docs. Should we avoid adding it in both cases if the feature is not on?
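
For illustration only (hypothetical, not in this diff): the reviewer's alternative would look roughly like gating the unconditional route behind the feature flag, e.g.:

```rust
// Hypothetical: register /vertex only when built with the `google` feature,
// instead of always exposing it alongside the AIP_PREDICT_ROUTE alias below.
#[cfg(feature = "google")]
let base_routes = base_routes.route("/vertex", post(vertex_compatibility));
```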

        // Base Health route
        .route("/health", get(health))
        // Inference API health route
@@ -1110,8 +1230,10 @@
        // Prometheus metrics route
        .route("/metrics", get(metrics));

+   let mut app = Router::new().merge(base_routes);
+
    // Set default routes
-   let app = match &info.model_type {
+   app = match &info.model_type {
        ModelType::Classifier(_) => {
            app.route("/", post(predict))
                // AWS Sagemaker route
@@ -1129,6 +1251,20 @@
        }
    };

+   #[cfg(feature = "google")]
+   {
+       tracing::info!("Built with `google` feature");
+       tracing::info!(
+           "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
+       );
+       if let Ok(env_predict_route) = std::env::var("AIP_PREDICT_ROUTE") {
+           app = app.route(&env_predict_route, post(vertex_compatibility));
+       }
+       if let Ok(env_health_route) = std::env::var("AIP_HEALTH_ROUTE") {
+           app = app.route(&env_health_route, get(health));
+       }
+   }

    let app = app
        .layer(Extension(infer))
        .layer(Extension(info))
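Context for reviewers: per Google's custom-container serving contract, Vertex AI injects the `AIP_HTTP_PORT`, `AIP_PREDICT_ROUTE`, and `AIP_HEALTH_ROUTE` environment variables when it serves a deployed container, which is why the router reads them here and in `router/src/lib.rs` below.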
15 changes: 15 additions & 0 deletions router/src/http/types.rs
@@ -364,3 +364,18 @@ pub(crate) struct SimpleToken {
#[derive(Serialize, ToSchema)]
#[schema(example = json!([[{"id": 0, "text": "test", "special": false, "start": 0, "stop": 2}]]))]
pub(crate) struct TokenizeResponse(pub Vec<Vec<SimpleToken>>);

#[derive(Clone, Deserialize, ToSchema)]
pub(crate) struct VertexInstance {
    #[schema(example = "What is Deep Learning?")]
    pub inputs: String,
    #[serde(default)]
    #[schema(default = "false", example = "false")]
    pub truncate: bool,
}

#[derive(Deserialize, ToSchema)]
pub(crate) struct VertexRequest {
    #[serde(rename = "instances")]
    pub instances: Vec<VertexInstance>,
}
9 changes: 9 additions & 0 deletions router/src/lib.rs
@@ -239,6 +239,15 @@ pub async fn run(
        docker_label: option_env!("DOCKER_LABEL"),
    };

    // Determine the server port based on the feature and environment variable.
    let port = if cfg!(feature = "google") {
        std::env::var("AIP_HTTP_PORT")
            .map(|aip_http_port| aip_http_port.parse::<u16>().unwrap_or(port))
            .unwrap_or(port)
    } else {
        port
    };
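
In other words, when built with `google`, a set `AIP_HTTP_PORT` takes precedence over the CLI-provided port, while an unset or unparsable value silently falls back to the existing `port`.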

    let addr = match hostname.unwrap_or("0.0.0.0".to_string()).parse() {
        Ok(ip) => SocketAddr::new(ip, port),
        Err(_) => {