Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ serde_json = "1.0.140"
tokio = { version = "1.45.1", features = ["full"] }
tokio-stream = { version = "0.1.17", features = ["sync"] }
url = { version = "2.5.4", features = ["serde"] }
regex = "1.11.1"

chrono = { version = "0.4.41", optional = true }
crossterm = { version = "0.28.1", features = ["event-stream"], optional = true }
Expand Down
2 changes: 2 additions & 0 deletions resources/ts/components/AgentsList.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ export function AgentsList({ agents }: { agents: Array<Agent> }) {
<thead>
<tr>
<th>Name</th>
<th>Model</th>
<th>Issues</th>
<th>Llama.cpp address</th>
<th>Last update</th>
Expand Down Expand Up @@ -54,6 +55,7 @@ export function AgentsList({ agents }: { agents: Array<Agent> }) {
key={agent_id}
>
<td>{status.agent_name}</td>
<td>{status.model}</td>
<td>
{status.error && (
<>
Expand Down
1 change: 1 addition & 0 deletions resources/ts/schemas/Agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { StatusUpdateSchema } from "./StatusUpdate";
export const AgentSchema = z
.object({
agent_id: z.string(),
model: z.string().nullable(),
last_update: z.object({
nanos_since_epoch: z.number(),
secs_since_epoch: z.number(),
Expand Down
1 change: 1 addition & 0 deletions resources/ts/schemas/StatusUpdate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export const StatusUpdateSchema = z
is_unexpected_response_status: z.boolean().nullable(),
slots_idle: z.number(),
slots_processing: z.number(),
model: z.string().nullable(),
})
.strict();

Expand Down
15 changes: 14 additions & 1 deletion src/agent/monitoring_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub struct MonitoringService {
monitoring_interval: Duration,
name: Option<String>,
status_update_tx: Sender<Bytes>,
check_model: bool, // Store the check_model flag
}

impl MonitoringService {
Expand All @@ -32,13 +33,15 @@ impl MonitoringService {
monitoring_interval: Duration,
name: Option<String>,
status_update_tx: Sender<Bytes>,
check_model: bool, // Include the check_model flag
) -> Result<Self> {
Ok(MonitoringService {
external_llamacpp_addr,
llamacpp_client,
monitoring_interval,
name,
status_update_tx,
check_model,
})
}

Expand All @@ -50,6 +53,15 @@ impl MonitoringService {
.filter(|slot| slot.is_processing)
.count();

let model: Option<String> = if self.check_model {
match self.llamacpp_client.get_model().await {
Ok(model) => model,
Err(_) => None,
}
} else {
Some("".to_string())
};

StatusUpdate {
agent_name: self.name.to_owned(),
error: slots_response.error,
Expand All @@ -63,6 +75,7 @@ impl MonitoringService {
is_unexpected_response_status: slots_response.is_unexpected_response_status,
slots_idle: slots_response.slots.len() - slots_processing,
slots_processing,
model,
}
}

Expand Down Expand Up @@ -109,4 +122,4 @@ impl Service for MonitoringService {
/// Reports the number of worker threads this service should run on.
/// The monitoring service is pinned to a single thread (`Some(1)`).
fn threads(&self) -> Option<usize> {
Some(1)
}
}
}
103 changes: 103 additions & 0 deletions src/balancer/proxy_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::time::Duration;
use async_trait::async_trait;
use bytes::Bytes;
use log::error;
use log::info;
use pingora::http::RequestHeader;
use pingora::proxy::ProxyHttp;
use pingora::proxy::Session;
Expand Down Expand Up @@ -41,20 +42,23 @@ pub struct ProxyService {
buffered_request_timeout: Duration,
max_buffered_requests: usize,
rewrite_host_header: bool,
check_model: bool,
slots_endpoint_enable: bool,
upstream_peer_pool: Arc<UpstreamPeerPool>,
}

impl ProxyService {
pub fn new(
rewrite_host_header: bool,
check_model: bool,
slots_endpoint_enable: bool,
upstream_peer_pool: Arc<UpstreamPeerPool>,
buffered_request_timeout: Duration,
max_buffered_requests: usize,
) -> Self {
Self {
rewrite_host_header,
check_model,
slots_endpoint_enable,
upstream_peer_pool,
buffered_request_timeout,
Expand All @@ -73,6 +77,7 @@ impl ProxyHttp for ProxyService {
slot_taken: false,
upstream_peer_pool: self.upstream_peer_pool.clone(),
uses_slots: false,
requested_model: Some("".to_string()),
}
}

Expand Down Expand Up @@ -180,10 +185,108 @@ impl ProxyHttp for ProxyService {
}
"/chat/completions" => true,
"/completion" => true,
"/v1/completions" => true,
"/v1/chat/completions" => true,
_ => false,
};

info!("upstream_peer - {:?} request | rewrite_host_header? {} check_model? {}", session.req_header().method, self.rewrite_host_header, self.check_model);

// Check if the request method is POST and the content type is JSON
if self.check_model && ctx.uses_slots {
info!("Checking model...");
ctx.requested_model = None;
if session.req_header().method == "POST" {
// Check if the content type is application/json
if let Some(content_type) = session.get_header("Content-Type") {
if let Ok(content_type_str) = content_type.to_str() {
if content_type_str.contains("application/json") {
// Enable retry buffering to preserve the request body, reference: https://github.com/cloudflare/pingora/issues/349#issuecomment-2377277028
session.enable_retry_buffering();
session.read_body_or_idle(false).await.unwrap().unwrap();
let request_body = session.get_retry_buffer();

if let Some(body_bytes) = request_body {
match std::str::from_utf8(&body_bytes) {
Ok(_) => {
// The bytes are valid UTF-8, proceed as normal
if let Ok(json_value) = serde_json::from_slice::<serde_json::Value>(&body_bytes) {
if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) {
ctx.requested_model = Some(model.to_string());
info!("Model in request: {:?}", ctx.requested_model);
}
} else {
info!("Failed to parse JSON payload, trying regex extraction");
let body_str = String::from_utf8_lossy(&body_bytes).to_string();
let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap();
if let Some(caps) = re.captures(&body_str) {
if let Some(model) = caps.get(1) {
ctx.requested_model = Some(model.as_str().to_string());
info!("Model via regex: {:?}", ctx.requested_model);
}
} else {
info!("Failed to extract model using regex");
}
}
},
Err(e) => {
// Invalid UTF-8 detected. Truncate to the last valid UTF-8 boundary.
let valid_up_to = e.valid_up_to();
info!("Invalid UTF-8 detected. Truncating from {} bytes to {} bytes.", body_bytes.len(), valid_up_to);

// Create a new `Bytes` slice containing only the valid UTF-8 part.
let valid_body_bytes = body_bytes.slice(0..valid_up_to);

// Now proceed with the (truncated) valid_body_bytes
if let Ok(json_value) = serde_json::from_slice::<serde_json::Value>(&valid_body_bytes) {
if let Some(model) = json_value.get("model").and_then(|v| v.as_str()) {
ctx.requested_model = Some(model.to_string());
info!("Model in request (after truncation): {:?}", ctx.requested_model);
}
} else {
info!("Failed to parse JSON payload (after truncation), trying regex extraction");
let body_str = String::from_utf8_lossy(&valid_body_bytes).to_string();
let re = regex::Regex::new(r#""model"\s*:\s*["']([^"']*)["']"#).unwrap();
if let Some(caps) = re.captures(&body_str) {
if let Some(model) = caps.get(1) {
ctx.requested_model = Some(model.as_str().to_string());
info!("Model via regex (after truncation): {:?}", ctx.requested_model);
}
} else {
info!("Failed to extract model using regex (after truncation)");
}
}
}
}
} else {
info!("Request body is None");
}
}
}
}
}
// abort if model has not been set
if ctx.requested_model == None {
info!("Model missing in request");
session
.respond_error(pingora::http::StatusCode::BAD_REQUEST.as_u16())
.await?;

return Err(Error::new_down(pingora::ErrorType::ConnectRefused));
}
else if ctx.has_peer_supporting_model() == false {
info!("Model {:?} not supported by upstream", ctx.requested_model);
session
.respond_error(pingora::http::StatusCode::NOT_FOUND.as_u16())
.await?;

return Err(Error::new_down(pingora::ErrorType::ConnectRefused));
}
else {
info!("Model {:?}", ctx.requested_model);
}
}

let peer = tokio::select! {
result = async {
loop {
Expand Down
35 changes: 28 additions & 7 deletions src/balancer/request_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;

use anyhow::anyhow;
use log::error;
use log::info;
use pingora::Error;
use pingora::Result;

Expand All @@ -13,6 +14,7 @@ pub struct RequestContext {
pub selected_peer: Option<UpstreamPeer>,
pub upstream_peer_pool: Arc<UpstreamPeerPool>,
pub uses_slots: bool,
pub requested_model: Option<String>,
}

impl RequestContext {
Expand All @@ -30,16 +32,19 @@ impl RequestContext {
}
}

pub fn use_best_peer_and_take_slot(&mut self) -> anyhow::Result<Option<UpstreamPeer>> {
pub fn use_best_peer_and_take_slot(&mut self, model: Option<String>) -> anyhow::Result<Option<UpstreamPeer>> {
if let Some(peer) = self.upstream_peer_pool.with_agents_write(|agents| {
let model_str = model.as_deref().unwrap_or("");
for peer in agents.iter_mut() {
if peer.is_usable() {
peer.take_slot()?;
let is_usable = peer.is_usable();
let is_usable_for_model = peer.is_usable_for_model(model_str);

if is_usable && (model.is_none() || is_usable_for_model) {
info!("Peer {} is usable: {}, usable for model '{}': {}", peer.agent_id, is_usable, model_str, is_usable_for_model);
peer.take_slot()?;
return Ok(Some(peer.clone()));
}
}

Ok(None)
})? {
self.upstream_peer_pool.restore_integrity()?;
Expand All @@ -52,11 +57,26 @@ impl RequestContext {
}
}

/// Returns `true` if at least one registered upstream peer supports the
/// model currently requested by this context (`self.requested_model`).
///
/// An absent model (`None`) is normalized to the empty string, matching
/// the convention used elsewhere in this module when matching peers
/// against a model name.
///
/// If the agent pool cannot be read (e.g. a poisoned lock inside
/// `with_agents_read`), this conservatively returns `false` instead of
/// propagating the error, so callers treat the request as unroutable.
pub fn has_peer_supporting_model(&self) -> bool {
    let model_str = self.requested_model.as_deref().unwrap_or("");

    self.upstream_peer_pool
        .with_agents_read(|agents| Ok(agents.iter().any(|peer| peer.supports_model(model_str))))
        .unwrap_or(false)
}

pub fn select_upstream_peer(&mut self) -> Result<()> {
let result_option_peer = if self.uses_slots && !self.slot_taken {
self.use_best_peer_and_take_slot()
self.use_best_peer_and_take_slot(self.requested_model.clone())
} else {
self.upstream_peer_pool.use_best_peer()
self.upstream_peer_pool.use_best_peer(self.requested_model.clone())
};

self.selected_peer = match result_option_peer {
Expand Down Expand Up @@ -95,6 +115,7 @@ mod tests {
selected_peer: None,
upstream_peer_pool,
uses_slots: true,
requested_model: Some("llama3".to_string()),
}
}

Expand All @@ -105,7 +126,7 @@ mod tests {

pool.register_status_update("test_agent", mock_status_update("test_agent", 0, 0))?;

assert!(ctx.use_best_peer_and_take_slot().unwrap().is_none());
assert!(ctx.use_best_peer_and_take_slot(ctx.requested_model.clone()).unwrap().is_none());

assert!(!ctx.slot_taken);
assert_eq!(ctx.selected_peer, None);
Expand Down
1 change: 1 addition & 0 deletions src/balancer/status_update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub struct StatusUpdate {
pub is_unexpected_response_status: Option<bool>,
pub slots_idle: usize,
pub slots_processing: usize,
pub model: Option<String>,
}

impl StatusUpdate {
Expand Down
1 change: 1 addition & 0 deletions src/balancer/test/mock_status_update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,6 @@ pub fn mock_status_update(
is_unexpected_response_status: Some(false),
slots_idle,
slots_processing,
model: Some("llama3".to_string()),
}
}
Loading