Skip to content

Commit c1efd4a

Browse files
committed
separate serializer for openai-like combined response
1 parent f696a90 commit c1efd4a

File tree

1 file changed

+37 / -10 lines changed

1 file changed

+37
-10
lines changed

src/balancer/compatibility/openai_service/http_route/post_chat_completions.rs

Lines changed: 37 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@ use actix_web::Error;
55
use actix_web::HttpResponse;
66
use actix_web::post;
77
use actix_web::web;
8+
use anyhow::anyhow;
89
use async_trait::async_trait;
910
use nanoid::nanoid;
1011
use serde::Deserialize;
@@ -61,16 +62,19 @@ struct OpenAICompletionRequestParams {
6162
}
6263

6364
#[derive(Clone)]
64-
struct OpenAIResponseTransformer {
65+
struct OpenAIStreamingResponseTransformer {
6566
model: String,
6667
system_fingerprint: String,
6768
}
6869

6970
#[async_trait]
70-
impl TransformsOutgoingMessage for OpenAIResponseTransformer {
71+
impl TransformsOutgoingMessage for OpenAIStreamingResponseTransformer {
7172
type TransformedMessage = serde_json::Value;
7273

73-
async fn transform(&self, message: OutgoingMessage) -> anyhow::Result<serde_json::Value> {
74+
async fn transform(
75+
&self,
76+
message: OutgoingMessage,
77+
) -> anyhow::Result<Self::TransformedMessage> {
7478
match message {
7579
OutgoingMessage::Response(ResponseEnvelope {
7680
request_id,
@@ -116,6 +120,31 @@ impl TransformsOutgoingMessage for OpenAIResponseTransformer {
116120
}
117121
}
118122

123+
#[derive(Clone)]
124+
struct OpenAICombinedResponseTransfomer {}
125+
126+
#[async_trait]
127+
impl TransformsOutgoingMessage for OpenAICombinedResponseTransfomer {
128+
type TransformedMessage = String;
129+
130+
async fn transform(
131+
&self,
132+
message: OutgoingMessage,
133+
) -> anyhow::Result<Self::TransformedMessage> {
134+
match message {
135+
OutgoingMessage::Response(ResponseEnvelope {
136+
response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Done),
137+
..
138+
}) => Ok("".to_string()),
139+
OutgoingMessage::Response(ResponseEnvelope {
140+
response: OutgoingResponse::GeneratedToken(GeneratedTokenResult::Token(token)),
141+
..
142+
}) => Ok(token),
143+
_ => Err(anyhow!("Unexpected message type: {:?}", message)),
144+
}
145+
}
146+
}
147+
119148
#[post("/v1/chat_completions")]
120149
async fn respond(
121150
app_data: web::Data<AppData>,
@@ -133,24 +162,22 @@ async fn respond(
133162
tools: vec![],
134163
};
135164

136-
let response_transformer = OpenAIResponseTransformer {
137-
model: openai_params.model.clone(),
138-
system_fingerprint: nanoid!(),
139-
};
140-
141165
if openai_params.stream {
142166
http_stream_from_agent(
143167
app_data.buffered_request_manager.clone(),
144168
app_data.inference_service_configuration.clone(),
145169
paddler_params,
146-
response_transformer,
170+
OpenAIStreamingResponseTransformer {
171+
model: openai_params.model.clone(),
172+
system_fingerprint: nanoid!(),
173+
},
147174
)
148175
} else {
149176
let combined_response = unbounded_stream_from_agent(
150177
app_data.buffered_request_manager.clone(),
151178
app_data.inference_service_configuration.clone(),
152179
paddler_params,
153-
response_transformer,
180+
OpenAICombinedResponseTransfomer {},
154181
)?
155182
.collect::<Vec<String>>()
156183
.await

0 commit comments

Comments
 (0)