Skip to content

Commit 7cb30b2

Browse files
committed
add message transformer
1 parent fdefccc commit 7cb30b2

File tree

15 files changed

+209
-150
lines changed

15 files changed

+209
-150
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
use anyhow::Result;
2+
use async_trait::async_trait;
3+
4+
use super::transforms_outgoing_message::TransformsOutgoingMessage;
5+
use crate::balancer::inference_client::Message as OutgoingMessage;
6+
7+
#[derive(Clone)]
8+
pub struct IdentityTransformer;
9+
10+
impl IdentityTransformer {
11+
pub fn new() -> Self {
12+
Self {}
13+
}
14+
}
15+
16+
#[async_trait]
17+
impl TransformsOutgoingMessage for IdentityTransformer {
18+
type TransformedMessage = OutgoingMessage;
19+
20+
async fn transform(&self, message: OutgoingMessage) -> Result<Self::TransformedMessage> {
21+
return Ok(message);
22+
}
23+
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
pub mod identity_transformer;
2+
pub mod transforms_outgoing_message;
3+
4+
use async_trait::async_trait;
5+
use tokio::sync::mpsc;
6+
7+
use self::transforms_outgoing_message::TransformsOutgoingMessage;
8+
use crate::balancer::inference_client::Message as OutgoingMessage;
9+
use crate::controls_session::ControlsSession;
10+
11+
#[derive(Clone)]
12+
pub struct ChunkForwardingSessionController<TTransformsOutgoingMessage>
13+
where
14+
TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync,
15+
{
16+
chunk_tx: mpsc::UnboundedSender<String>,
17+
transformer: TTransformsOutgoingMessage,
18+
}
19+
20+
impl<TTransformsOutgoingMessage> ChunkForwardingSessionController<TTransformsOutgoingMessage>
21+
where
22+
TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync,
23+
{
24+
pub fn new(
25+
chunk_tx: mpsc::UnboundedSender<String>,
26+
transformer: TTransformsOutgoingMessage,
27+
) -> Self {
28+
Self {
29+
chunk_tx,
30+
transformer,
31+
}
32+
}
33+
}
34+
35+
#[async_trait]
36+
impl<TTransformsOutgoingMessage> ControlsSession<OutgoingMessage>
37+
for ChunkForwardingSessionController<TTransformsOutgoingMessage>
38+
where
39+
TTransformsOutgoingMessage: Clone + TransformsOutgoingMessage + Send + Sync,
40+
{
41+
async fn send_response(&mut self, message: OutgoingMessage) -> anyhow::Result<()> {
42+
let transformed_message = self.transformer.transform(message).await?;
43+
let stringified_message = serde_json::to_string(&transformed_message)?;
44+
45+
self.chunk_tx.send(stringified_message)?;
46+
47+
Ok(())
48+
}
49+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use anyhow::Result;
2+
use async_trait::async_trait;
3+
use serde::Serialize;
4+
5+
use crate::balancer::inference_client::Message as OutgoingMessage;
6+
7+
#[async_trait]
8+
pub trait TransformsOutgoingMessage {
9+
type TransformedMessage: Serialize;
10+
11+
async fn transform(&self, message: OutgoingMessage) -> Result<Self::TransformedMessage>;
12+
}
File renamed without changes.
File renamed without changes.
File renamed without changes.

src/balancer/inference_service/chunk_forwarding_session_controller.rs

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 16 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,15 @@
1-
use actix_web::http::header;
2-
use log::error;
31
use actix_web::post;
42
use actix_web::error::ErrorBadRequest;
5-
use actix_web::rt;
63
use actix_web::web;
74
use actix_web::Error;
8-
use actix_web::HttpResponse;
95
use actix_web::Responder;
10-
use bytes::Bytes;
11-
use futures::stream::StreamExt;
12-
use nanoid::nanoid;
13-
use tokio::sync::broadcast;
14-
use tokio::sync::mpsc;
15-
use tokio_stream::wrappers::UnboundedReceiverStream;
166

177
use crate::validates::Validates as _;
188
use crate::balancer::inference_service::app_data::AppData;
19-
use crate::balancer::inference_service::chunk_forwarding_session_controller::ChunkForwardingSessionController;
20-
use crate::balancer::request_from_agent::request_from_agent;
219
use crate::request_params::ContinueFromConversationHistoryParams;
2210
use crate::request_params::continue_from_conversation_history_params::tool::tool_params::function_call::parameters_schema::raw_parameters_schema::RawParametersSchema;
23-
use crate::jsonrpc::Error as JsonRpcError;
24-
use crate::controls_session::ControlsSession as _;
25-
use crate::jsonrpc::ErrorEnvelope;
26-
use crate::balancer::inference_service::http_route::api::ws_inference_socket::client::Message as OutgoingMessage;
11+
use crate::balancer::stream_from_agent::stream_from_agent;
12+
use crate::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer;
2713

2814
pub fn register(cfg: &mut web::ServiceConfig) {
2915
cfg.service(respond);
@@ -34,50 +20,18 @@ async fn respond(
3420
app_data: web::Data<AppData>,
3521
params: web::Json<ContinueFromConversationHistoryParams<RawParametersSchema>>,
3622
) -> Result<impl Responder, Error> {
37-
let request_id: String = nanoid!();
38-
let (connection_close_tx, _connection_close_rx) = broadcast::channel(1);
39-
let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
40-
41-
let validated_params = match params.into_inner().validate() {
42-
Ok(validated_params) => validated_params,
43-
Err(validation_error) => {
44-
return Err(ErrorBadRequest(format!(
45-
"Invalid request parameters: {validation_error}"
46-
)));
47-
}
48-
};
49-
50-
rt::spawn(async move {
51-
let mut session_controller = ChunkForwardingSessionController { chunk_tx };
52-
53-
if let Err(err) = request_from_agent(
54-
app_data.buffered_request_manager.clone(),
55-
connection_close_tx,
56-
app_data.inference_service_configuration.clone(),
57-
validated_params,
58-
request_id.clone(),
59-
session_controller.clone(),
60-
)
61-
.await
62-
{
63-
error!("Failed to handle request: {err}");
64-
session_controller
65-
.send_response_safe(OutgoingMessage::Error(ErrorEnvelope {
66-
request_id: request_id.clone(),
67-
error: JsonRpcError {
68-
code: 500,
69-
description: format!("Request {request_id} failed: {err}"),
70-
},
71-
}))
72-
.await;
73-
}
74-
});
75-
76-
let stream = UnboundedReceiverStream::new(chunk_rx)
77-
.map(|chunk: String| Ok::<_, Error>(Bytes::from(format!("{chunk}\n"))));
78-
79-
Ok(HttpResponse::Ok()
80-
.insert_header(header::ContentType::json())
81-
.insert_header((header::CACHE_CONTROL, "no-cache"))
82-
.streaming(stream))
23+
stream_from_agent(
24+
app_data.buffered_request_manager.clone(),
25+
app_data.inference_service_configuration.clone(),
26+
match params.into_inner().validate() {
27+
Ok(validated_params) => validated_params,
28+
Err(validation_error) => {
29+
return Err(ErrorBadRequest(format!(
30+
"Invalid request parameters: {validation_error}"
31+
)));
32+
}
33+
},
34+
IdentityTransformer::new(),
35+
)
36+
.await
8337
}
Lines changed: 9 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,11 @@
11
use actix_web::Error;
2-
use actix_web::HttpResponse;
32
use actix_web::Responder;
4-
use actix_web::http::header;
53
use actix_web::post;
6-
use actix_web::rt;
74
use actix_web::web;
8-
use bytes::Bytes;
9-
use futures::stream::StreamExt;
10-
use log::error;
11-
use nanoid::nanoid;
12-
use tokio::sync::broadcast;
13-
use tokio::sync::mpsc;
14-
use tokio_stream::wrappers::UnboundedReceiverStream;
155

6+
use crate::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer;
167
use crate::balancer::inference_service::app_data::AppData;
17-
use crate::balancer::inference_service::chunk_forwarding_session_controller::ChunkForwardingSessionController;
18-
use crate::balancer::inference_service::http_route::api::ws_inference_socket::client::Message as OutgoingMessage;
19-
use crate::balancer::request_from_agent::request_from_agent;
20-
use crate::controls_session::ControlsSession as _;
21-
use crate::jsonrpc::Error as JsonRpcError;
22-
use crate::jsonrpc::ErrorEnvelope;
8+
use crate::balancer::stream_from_agent::stream_from_agent;
239
use crate::request_params::ContinueFromRawPromptParams;
2410

2511
pub fn register(cfg: &mut web::ServiceConfig) {
@@ -31,41 +17,11 @@ async fn respond(
3117
app_data: web::Data<AppData>,
3218
params: web::Json<ContinueFromRawPromptParams>,
3319
) -> Result<impl Responder, Error> {
34-
let request_id: String = nanoid!();
35-
let (connection_close_tx, _connection_close_rx) = broadcast::channel(1);
36-
let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
37-
38-
rt::spawn(async move {
39-
let mut session_controller = ChunkForwardingSessionController { chunk_tx };
40-
41-
if let Err(err) = request_from_agent(
42-
app_data.buffered_request_manager.clone(),
43-
connection_close_tx,
44-
app_data.inference_service_configuration.clone(),
45-
params.into_inner(),
46-
request_id.clone(),
47-
session_controller.clone(),
48-
)
49-
.await
50-
{
51-
error!("Failed to handle request: {err}");
52-
session_controller
53-
.send_response_safe(OutgoingMessage::Error(ErrorEnvelope {
54-
request_id: request_id.clone(),
55-
error: JsonRpcError {
56-
code: 500,
57-
description: format!("Request {request_id} failed: {err}"),
58-
},
59-
}))
60-
.await;
61-
}
62-
});
63-
64-
let stream = UnboundedReceiverStream::new(chunk_rx)
65-
.map(|chunk: String| Ok::<_, Error>(Bytes::from(format!("{chunk}\n"))));
66-
67-
Ok(HttpResponse::Ok()
68-
.insert_header(header::ContentType::json())
69-
.insert_header((header::CACHE_CONTROL, "no-cache"))
70-
.streaming(stream))
20+
stream_from_agent(
21+
app_data.buffered_request_manager.clone(),
22+
app_data.inference_service_configuration.clone(),
23+
params.into_inner(),
24+
IdentityTransformer::new(),
25+
)
26+
.await
7127
}

src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ use tokio::sync::broadcast;
1515
use tokio::sync::mpsc;
1616
use tokio_stream::wrappers::UnboundedReceiverStream;
1717

18+
use crate::balancer::chunk_forwarding_session_controller::ChunkForwardingSessionController;
19+
use crate::balancer::chunk_forwarding_session_controller::identity_transformer::IdentityTransformer;
20+
use crate::balancer::inference_client::Message as OutgoingMessage;
1821
use crate::balancer::inference_service::app_data::AppData;
19-
use crate::balancer::inference_service::chunk_forwarding_session_controller::ChunkForwardingSessionController;
20-
use crate::balancer::inference_service::http_route::api::ws_inference_socket::client::Message as OutgoingMessage;
2122
use crate::balancer::request_from_agent::request_from_agent;
2223
use crate::controls_session::ControlsSession as _;
2324
use crate::jsonrpc::Error as JsonRpcError;
@@ -67,9 +68,8 @@ async fn respond(
6768

6869
rt::spawn(async move {
6970
let request_id: String = nanoid!();
70-
let mut session_controller = ChunkForwardingSessionController {
71-
chunk_tx: chunk_tx_clone,
72-
};
71+
let mut session_controller =
72+
ChunkForwardingSessionController::new(chunk_tx_clone, IdentityTransformer::new());
7373

7474
if let Err(err) = request_from_agent(
7575
buffered_request_manager_clone,

0 commit comments

Comments
 (0)