Skip to content

Commit 547d2a9

Browse files
committed
implement 'manages_sender_collection' in generate_tokens_sender_collection
1 parent ee07862 commit 547d2a9

File tree

4 files changed

+15
-47
lines changed

4 files changed

+15
-47
lines changed

src/balancer/agent_controller.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::balancer::agent_controller_update_result::AgentControllerUpdateResult
2424
use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection;
2525
use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection;
2626
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
27+
use crate::balancer::manages_senders::ManagesSenders as _;
2728
use crate::balancer::manages_senders_controller::ManagesSendersController;
2829
use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection;
2930
use crate::balancer::receive_tokens_controller::ReceiveTokensController;
Lines changed: 10 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,27 @@
1-
use anyhow::anyhow;
2-
use anyhow::Result;
1+
use async_trait::async_trait;
32
use dashmap::DashMap;
43
use tokio::sync::mpsc;
54

5+
use crate::balancer::manages_senders::ManagesSenders;
66
use crate::generated_token_envelope::GeneratedTokenEnvelope;
77

88
pub struct GenerateTokensSenderCollection {
9-
generate_tokens_senders: DashMap<String, mpsc::UnboundedSender<GeneratedTokenEnvelope>>,
9+
senders: DashMap<String, mpsc::UnboundedSender<GeneratedTokenEnvelope>>,
1010
}
1111

1212
impl GenerateTokensSenderCollection {
1313
pub fn new() -> Self {
1414
Self {
15-
generate_tokens_senders: DashMap::new(),
15+
senders: DashMap::new(),
1616
}
1717
}
18+
}
1819

19-
pub fn deregister_sender(&self, request_id: String) -> Result<()> {
20-
if let Some(sender) = self.generate_tokens_senders.remove(&request_id) {
21-
drop(sender);
22-
23-
Ok(())
24-
} else {
25-
Err(anyhow!("No sender found for request_id {request_id}"))
26-
}
27-
}
28-
29-
pub async fn forward_generated_token(
30-
&self,
31-
request_id: String,
32-
generated_token_envelope: GeneratedTokenEnvelope,
33-
) -> Result<()> {
34-
if let Some(sender) = self.generate_tokens_senders.get(&request_id) {
35-
sender.send(generated_token_envelope)?;
36-
37-
Ok(())
38-
} else {
39-
Err(anyhow!("No sender found for request_id {request_id}"))
40-
}
41-
}
42-
43-
pub fn register_sender(
44-
&self,
45-
request_id: String,
46-
sender: mpsc::UnboundedSender<GeneratedTokenEnvelope>,
47-
) -> Result<()> {
48-
if self.generate_tokens_senders.contains_key(&request_id) {
49-
return Err(anyhow!("Sender for request_id {request_id} already exists"));
50-
}
51-
52-
self.generate_tokens_senders.insert(request_id, sender);
20+
#[async_trait]
21+
impl ManagesSenders for GenerateTokensSenderCollection {
22+
type Value = GeneratedTokenEnvelope;
5323

54-
Ok(())
24+
fn get_sender_collection(&self) -> &DashMap<String, mpsc::UnboundedSender<Self::Value>> {
25+
&self.senders
5526
}
5627
}

src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use anyhow::Result;
2222
use async_trait::async_trait;
2323
use log::error;
2424
use log::info;
25-
use log::warn;
2625
use serde::Deserialize;
2726
use tokio::sync::broadcast;
2827
use tokio::sync::mpsc;
@@ -260,14 +259,10 @@ impl ControlsWebSocketEndpoint for AgentSocketController {
260259
request_id,
261260
response: AgentJsonRpcResponse::GeneratedToken(generated_token_envelope),
262261
}) => {
263-
if let Err(err) = context
262+
context
264263
.generate_tokens_sender_collection
265-
.forward_generated_token(request_id, generated_token_envelope)
266-
.await
267-
{
268-
// Token might come in after the sender was deregistered
269-
warn!("Error forwarding generated token: {err}");
270-
}
264+
.forward_response_safe(request_id, generated_token_envelope)
265+
.await;
271266

272267
Ok(ContinuationDecision::Continue)
273268
}

src/balancer/receive_tokens_controller.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use log::error;
44
use tokio::sync::mpsc;
55

66
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
7+
use crate::balancer::manages_senders::ManagesSenders as _;
78
use crate::generated_token_envelope::GeneratedTokenEnvelope;
89

910
pub struct ReceiveTokensController {

0 commit comments

Comments
 (0)