Skip to content

Commit da3fa1d

Browse files
committed
generalize agent response handlers
1 parent 5ac8574 commit da3fa1d

12 files changed

+130
-68
lines changed

src/agent/continue_from_conversation_history_request.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ pub struct ContinueFromConversationHistoryRequest {
1616

1717
impl FromRequestParams for ContinueFromConversationHistoryRequest {
1818
type RequestParams = ContinueFromConversationHistoryParams;
19+
type Response = GeneratedTokenEnvelope;
1920

2021
fn from_request_params(
2122
params: Self::RequestParams,
23+
generated_tokens_tx: mpsc::UnboundedSender<Self::Response>,
2224
generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
23-
generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenEnvelope>,
2425
) -> Self {
2526
ContinueFromConversationHistoryRequest {
2627
generate_tokens_stop_rx,

src/agent/continue_from_raw_prompt_request.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ pub struct ContinueFromRawPromptRequest {
1616

1717
impl FromRequestParams for ContinueFromRawPromptRequest {
1818
type RequestParams = ContinueFromRawPromptParams;
19+
type Response = GeneratedTokenEnvelope;
1920

2021
fn from_request_params(
2122
params: Self::RequestParams,
23+
generated_tokens_tx: mpsc::UnboundedSender<Self::Response>,
2224
generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
23-
generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenEnvelope>,
2425
) -> Self {
2526
ContinueFromRawPromptRequest {
2627
generate_tokens_stop_rx,

src/agent/from_request_params.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
use tokio::sync::mpsc;
22

3-
use crate::generated_token_envelope::GeneratedTokenEnvelope;
3+
use crate::agent::jsonrpc::response::Response;
44

55
pub trait FromRequestParams: Send + Sync {
66
type RequestParams;
7+
type Response: Into<Response>;
78

89
fn from_request_params(
9-
request_params: Self::RequestParams,
10-
generate_tokens_stop_rx: mpsc::UnboundedReceiver<()>,
11-
generated_tokens_tx: mpsc::UnboundedSender<GeneratedTokenEnvelope>,
10+
params: Self::RequestParams,
11+
response_tx: mpsc::UnboundedSender<Self::Response>,
12+
stop_rx: mpsc::UnboundedReceiver<()>,
1213
) -> Self;
1314
}

src/agent/generate_embedding_batch_request.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use actix::Message;
22
use anyhow::Result;
33
use tokio::sync::mpsc;
44

5+
use crate::agent::from_request_params::FromRequestParams;
56
use crate::embedding_result::EmbeddingResult;
67
use crate::request_params::GenerateEmbeddingBatchParams;
78

@@ -12,3 +13,20 @@ pub struct GenerateEmbeddingBatchRequest {
1213
pub generated_embedding_tx: mpsc::UnboundedSender<EmbeddingResult>,
1314
pub params: GenerateEmbeddingBatchParams,
1415
}
16+
17+
impl FromRequestParams for GenerateEmbeddingBatchRequest {
18+
type RequestParams = GenerateEmbeddingBatchParams;
19+
type Response = EmbeddingResult;
20+
21+
fn from_request_params(
22+
params: Self::RequestParams,
23+
generated_embedding_tx: mpsc::UnboundedSender<Self::Response>,
24+
generate_embedding_stop_rx: mpsc::UnboundedReceiver<()>,
25+
) -> Self {
26+
GenerateEmbeddingBatchRequest {
27+
generate_embedding_stop_rx,
28+
generated_embedding_tx,
29+
params,
30+
}
31+
}
32+
}

src/agent/jsonrpc/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ mod message;
22
mod notification;
33
pub mod notification_params;
44
mod request;
5-
mod response;
5+
pub mod response;
66

77
pub use self::message::Message;
88
pub use self::notification::Notification;

src/agent/jsonrpc/response.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,27 @@ pub enum Response {
1313
GeneratedToken(GeneratedTokenEnvelope),
1414
ModelMetadata(Option<ModelMetadata>),
1515
}
16+
17+
impl From<Option<ChatTemplate>> for Response {
18+
fn from(chat_template: Option<ChatTemplate>) -> Self {
19+
Response::ChatTemplateOverride(chat_template)
20+
}
21+
}
22+
23+
impl From<EmbeddingResult> for Response {
24+
fn from(embedding_result: EmbeddingResult) -> Self {
25+
Response::Embedding(embedding_result)
26+
}
27+
}
28+
29+
impl From<GeneratedTokenEnvelope> for Response {
30+
fn from(generated_token_envelope: GeneratedTokenEnvelope) -> Self {
31+
Response::GeneratedToken(generated_token_envelope)
32+
}
33+
}
34+
35+
impl From<Option<ModelMetadata>> for Response {
36+
fn from(model_metadata: Option<ModelMetadata>) -> Self {
37+
Response::ModelMetadata(model_metadata)
38+
}
39+
}

src/agent/management_socket_client_service.rs

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use tokio_tungstenite::connect_async;
2020
use tokio_tungstenite::tungstenite::protocol::Message;
2121

2222
use crate::agent_desired_state::AgentDesiredState;
23-
use crate::agent::receive_tokens_stopper_collection::ReceiveTokensStopperCollection;
23+
use crate::agent::receive_stream_stopper_collection::ReceiveStreamStopperCollection;
2424
use crate::agent::jsonrpc::Message as JsonRpcMessage;
2525
use crate::agent::jsonrpc::Notification as JsonRpcNotification;
2626
use crate::agent::jsonrpc::Request as JsonRpcRequest;
@@ -40,7 +40,6 @@ use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonr
4040
use crate::jsonrpc::Error as JsonRpcError;
4141
use crate::jsonrpc::ResponseEnvelope;
4242
use crate::jsonrpc::RequestEnvelope;
43-
use crate::generated_token_envelope::GeneratedTokenEnvelope;
4443
use crate::produces_snapshot::ProducesSnapshot;
4544
use crate::jsonrpc::ErrorEnvelope;
4645
use crate::service::Service;
@@ -53,8 +52,9 @@ struct IncomingMessageContext {
5352
continue_from_conversation_history_request_tx:
5453
mpsc::UnboundedSender<ContinueFromConversationHistoryRequest>,
5554
continue_from_raw_prompt_request_tx: mpsc::UnboundedSender<ContinueFromRawPromptRequest>,
55+
generate_embedding_batch_request_tx: mpsc::UnboundedSender<GenerateEmbeddingBatchRequest>,
5656
model_metadata_holder: Arc<ModelMetadataHolder>,
57-
receive_tokens_stopper_collection: Arc<ReceiveTokensStopperCollection>,
57+
receive_stream_stopper_collection: Arc<ReceiveStreamStopperCollection>,
5858
message_tx: mpsc::UnboundedSender<ManagementJsonRpcMessage>,
5959
}
6060

@@ -67,47 +67,46 @@ pub struct ManagementSocketClientService {
6767
pub generate_embedding_batch_request_tx: mpsc::UnboundedSender<GenerateEmbeddingBatchRequest>,
6868
pub model_metadata_holder: Arc<ModelMetadataHolder>,
6969
pub name: Option<String>,
70-
pub receive_tokens_stopper_collection: Arc<ReceiveTokensStopperCollection>,
70+
pub receive_stream_stopper_collection: Arc<ReceiveStreamStopperCollection>,
7171
pub slot_aggregated_status: Arc<SlotAggregatedStatus>,
7272
pub socket_url: String,
7373
}
7474

7575
impl ManagementSocketClientService {
76-
async fn generate_tokens<TRequest: FromRequestParams + 'static>(
76+
async fn generate_responses<TRequest: FromRequestParams + 'static>(
7777
connection_close_tx: broadcast::Sender<()>,
7878
id: String,
7979
message_tx: mpsc::UnboundedSender<ManagementJsonRpcMessage>,
8080
request_params: TRequest::RequestParams,
81-
receive_tokens_stopper_collection: Arc<ReceiveTokensStopperCollection>,
81+
receive_stream_stopper_collection: Arc<ReceiveStreamStopperCollection>,
8282
request_tx: mpsc::UnboundedSender<TRequest>,
8383
) -> Result<()> {
84-
let (generated_tokens_tx, mut generated_tokens_rx) =
85-
mpsc::unbounded_channel::<GeneratedTokenEnvelope>();
86-
let (generate_tokens_stop_tx, generate_tokens_stop_rx) = mpsc::unbounded_channel::<()>();
84+
let (response_tx, mut response_rx) = mpsc::unbounded_channel::<TRequest::Response>();
85+
let (stop_tx, stop_rx) = mpsc::unbounded_channel::<()>();
8786

88-
let _guard = receive_tokens_stopper_collection
89-
.register_stopper_with_guard(id.clone(), generate_tokens_stop_tx)
90-
.context(format!("Failed to register stopper for request ID: {id}"))?;
87+
let _guard = receive_stream_stopper_collection
88+
.register_stopper_with_guard(id.clone(), stop_tx)
89+
.context(format!("Failed to register stopper for request: {id}"))?;
9190

9291
request_tx.send(TRequest::from_request_params(
9392
request_params,
94-
generate_tokens_stop_rx,
95-
generated_tokens_tx,
93+
response_tx,
94+
stop_rx,
9695
))?;
9796

9897
let mut connection_close_rx = connection_close_tx.subscribe();
9998

10099
loop {
101100
tokio::select! {
102101
_ = connection_close_rx.recv() => break,
103-
generated_token_envelope = generated_tokens_rx.recv() => {
104-
match generated_token_envelope {
105-
Some(generated_token_envelope) => {
102+
response = response_rx.recv() => {
103+
match response {
104+
Some(response) => {
106105
message_tx.send(
107106
ManagementJsonRpcMessage::Response(
108107
ResponseEnvelope {
109108
request_id: id.clone(),
110-
response: JsonRpcResponse::GeneratedToken(generated_token_envelope),
109+
response: response.into(),
111110
}
112111
),
113112
)?;
@@ -128,9 +127,10 @@ impl ManagementSocketClientService {
128127
connection_close_tx,
129128
continue_from_conversation_history_request_tx,
130129
continue_from_raw_prompt_request_tx,
130+
generate_embedding_batch_request_tx,
131131
message_tx,
132132
model_metadata_holder,
133-
receive_tokens_stopper_collection,
133+
receive_stream_stopper_collection,
134134
}: IncomingMessageContext,
135135
deserialized_message: JsonRpcMessage,
136136
) -> Result<()> {
@@ -154,7 +154,7 @@ impl ManagementSocketClientService {
154154
}
155155
JsonRpcMessage::Notification(JsonRpcNotification::StopGeneratingTokens(request_id)) => {
156156
debug!("Received StopGeneratingTokens notification for request ID: {request_id:?}");
157-
receive_tokens_stopper_collection
157+
receive_stream_stopper_collection
158158
.stop(request_id.clone())
159159
.context(format!(
160160
"Failed to stop generating tokens for request ID: {request_id}"
@@ -181,12 +181,12 @@ impl ManagementSocketClientService {
181181
continue_from_conversation_history_params,
182182
),
183183
}) => {
184-
Self::generate_tokens(
184+
Self::generate_responses(
185185
connection_close_tx,
186186
id,
187187
message_tx,
188188
continue_from_conversation_history_params,
189-
receive_tokens_stopper_collection,
189+
receive_stream_stopper_collection,
190190
continue_from_conversation_history_request_tx,
191191
)
192192
.await
@@ -195,20 +195,30 @@ impl ManagementSocketClientService {
195195
id,
196196
request: JsonRpcRequest::ContinueFromRawPrompt(generate_tokens_params),
197197
}) => {
198-
Self::generate_tokens(
198+
Self::generate_responses(
199199
connection_close_tx,
200200
id,
201201
message_tx,
202202
generate_tokens_params,
203-
receive_tokens_stopper_collection,
203+
receive_stream_stopper_collection,
204204
continue_from_raw_prompt_request_tx,
205205
)
206206
.await
207207
}
208208
JsonRpcMessage::Request(RequestEnvelope {
209209
id,
210210
request: JsonRpcRequest::GenerateEmbeddingBatch(generate_embedding_batch_params),
211-
}) => Ok(()),
211+
}) => {
212+
Self::generate_responses(
213+
connection_close_tx,
214+
id,
215+
message_tx,
216+
generate_embedding_batch_params,
217+
receive_stream_stopper_collection,
218+
generate_embedding_batch_request_tx,
219+
)
220+
.await
221+
}
212222
JsonRpcMessage::Request(RequestEnvelope {
213223
id,
214224
request: JsonRpcRequest::GetChatTemplateOverride,
@@ -429,8 +439,9 @@ impl ManagementSocketClientService {
429439
connection_close_tx: connection_close_tx.clone(),
430440
continue_from_conversation_history_request_tx: self.continue_from_conversation_history_request_tx.clone(),
431441
continue_from_raw_prompt_request_tx: self.continue_from_raw_prompt_request_tx.clone(),
442+
generate_embedding_batch_request_tx: self.generate_embedding_batch_request_tx.clone(),
432443
model_metadata_holder: self.model_metadata_holder.clone(),
433-
receive_tokens_stopper_collection: self.receive_tokens_stopper_collection.clone(),
444+
receive_stream_stopper_collection: self.receive_stream_stopper_collection.clone(),
434445
message_tx: message_tx.clone(),
435446
},
436447
msg,

src/agent/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ mod llamacpp_slot;
1111
mod llamacpp_slot_context;
1212
pub mod management_socket_client_service;
1313
pub mod model_metadata_holder;
14-
pub mod receive_tokens_stopper_collection;
15-
mod receive_tokens_stopper_drop_guard;
14+
pub mod receive_stream_stopper_collection;
15+
mod receive_stream_stopper_drop_guard;
1616
pub mod reconciliation_service;

src/agent/receive_tokens_stopper_collection.rs renamed to src/agent/receive_stream_stopper_collection.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,21 @@ use anyhow::Result;
55
use dashmap::DashMap;
66
use tokio::sync::mpsc;
77

8-
use crate::agent::receive_tokens_stopper_drop_guard::ReceiveTokensStopperDropGuard;
8+
use crate::agent::receive_stream_stopper_drop_guard::ReceiveStreamStopperDropGuard;
99

10-
pub struct ReceiveTokensStopperCollection {
11-
receive_tokens_stoppers: DashMap<String, mpsc::UnboundedSender<()>>,
10+
pub struct ReceiveStreamStopperCollection {
11+
receive_stoppers: DashMap<String, mpsc::UnboundedSender<()>>,
1212
}
1313

14-
impl ReceiveTokensStopperCollection {
14+
impl ReceiveStreamStopperCollection {
1515
pub fn new() -> Self {
1616
Self {
17-
receive_tokens_stoppers: DashMap::new(),
17+
receive_stoppers: DashMap::new(),
1818
}
1919
}
2020

2121
pub fn deregister_stopper(&self, request_id: String) -> Result<()> {
22-
if let Some(stopper) = self.receive_tokens_stoppers.remove(&request_id) {
22+
if let Some(stopper) = self.receive_stoppers.remove(&request_id) {
2323
drop(stopper);
2424

2525
Ok(())
@@ -33,13 +33,13 @@ impl ReceiveTokensStopperCollection {
3333
request_id: String,
3434
stopper: mpsc::UnboundedSender<()>,
3535
) -> Result<()> {
36-
if self.receive_tokens_stoppers.contains_key(&request_id) {
36+
if self.receive_stoppers.contains_key(&request_id) {
3737
return Err(anyhow!(
3838
"Stopper for request_id {request_id} already exists"
3939
));
4040
}
4141

42-
self.receive_tokens_stoppers.insert(request_id, stopper);
42+
self.receive_stoppers.insert(request_id, stopper);
4343

4444
Ok(())
4545
}
@@ -48,17 +48,17 @@ impl ReceiveTokensStopperCollection {
4848
self: &Arc<Self>,
4949
request_id: String,
5050
stopper: mpsc::UnboundedSender<()>,
51-
) -> Result<ReceiveTokensStopperDropGuard> {
51+
) -> Result<ReceiveStreamStopperDropGuard> {
5252
self.register_stopper(request_id.clone(), stopper)?;
5353

54-
Ok(ReceiveTokensStopperDropGuard {
55-
receive_tokens_stopper_collection: self.clone(),
54+
Ok(ReceiveStreamStopperDropGuard {
55+
receive_stream_stopper_collection: self.clone(),
5656
request_id,
5757
})
5858
}
5959

6060
pub fn stop(&self, request_id: String) -> Result<()> {
61-
if let Some(stopper) = self.receive_tokens_stoppers.get(&request_id) {
61+
if let Some(stopper) = self.receive_stoppers.get(&request_id) {
6262
stopper.send(())?;
6363

6464
Ok(())

src/agent/receive_tokens_stopper_drop_guard.rs renamed to src/agent/receive_stream_stopper_drop_guard.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@ use std::sync::Arc;
22

33
use log::error;
44

5-
use crate::agent::receive_tokens_stopper_collection::ReceiveTokensStopperCollection;
5+
use crate::agent::receive_stream_stopper_collection::ReceiveStreamStopperCollection;
66

7-
pub struct ReceiveTokensStopperDropGuard {
8-
pub receive_tokens_stopper_collection: Arc<ReceiveTokensStopperCollection>,
7+
pub struct ReceiveStreamStopperDropGuard {
8+
pub receive_stream_stopper_collection: Arc<ReceiveStreamStopperCollection>,
99
pub request_id: String,
1010
}
1111

12-
impl Drop for ReceiveTokensStopperDropGuard {
12+
impl Drop for ReceiveStreamStopperDropGuard {
1313
fn drop(&mut self) {
1414
if let Err(err) = self
15-
.receive_tokens_stopper_collection
15+
.receive_stream_stopper_collection
1616
.deregister_stopper(self.request_id.clone())
1717
{
1818
error!(

0 commit comments

Comments
 (0)