Skip to content

Commit ee07862

Browse files
committed
add embedding result
1 parent 984144d commit ee07862

File tree

17 files changed

+151
-64
lines changed

17 files changed

+151
-64
lines changed

src/agent/jsonrpc/response.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@ use serde::Deserialize;
22
use serde::Serialize;
33

44
use crate::chat_template::ChatTemplate;
5+
use crate::embedding_result::EmbeddingResult;
56
use crate::generated_token_envelope::GeneratedTokenEnvelope;
67
use crate::model_metadata::ModelMetadata;
78

89
#[derive(Deserialize, Serialize)]
910
pub enum Response {
1011
ChatTemplateOverride(Option<ChatTemplate>),
12+
Embedding(EmbeddingResult),
1113
GeneratedToken(GeneratedTokenEnvelope),
1214
ModelMetadata(Option<ModelMetadata>),
1315
}

src/balancer/agent_controller.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,17 @@ use crate::atomic_value::AtomicValue;
2222
use crate::balancer::agent_controller_snapshot::AgentControllerSnapshot;
2323
use crate::balancer::agent_controller_update_result::AgentControllerUpdateResult;
2424
use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection;
25+
use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection;
2526
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
2627
use crate::balancer::manages_senders_controller::ManagesSendersController;
2728
use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection;
2829
use crate::balancer::receive_tokens_controller::ReceiveTokensController;
30+
use crate::embedding_result::EmbeddingResult;
2931
use crate::jsonrpc::RequestEnvelope;
3032
use crate::produces_snapshot::ProducesSnapshot;
3133
use crate::request_params::ContinueFromConversationHistoryParams;
3234
use crate::request_params::ContinueFromRawPromptParams;
35+
use crate::request_params::GenerateEmbeddingBatchParams;
3336
use crate::sends_rpc_message::SendsRpcMessage;
3437
use crate::sets_desired_state::SetsDesiredState;
3538
use crate::slot_aggregated_status_snapshot::SlotAggregatedStatusSnapshot;
@@ -42,6 +45,7 @@ pub struct AgentController {
4245
pub download_current: AtomicValue<AtomicUsize>,
4346
pub download_filename: RwLock<Option<String>>,
4447
pub download_total: AtomicValue<AtomicUsize>,
48+
pub embedding_sender_collection: Arc<EmbeddingSenderCollection>,
4549
pub generate_tokens_sender_collection: Arc<GenerateTokensSenderCollection>,
4650
pub id: String,
4751
pub issues: RwLock<BTreeSet<AgentIssue>>,
@@ -59,15 +63,13 @@ impl AgentController {
5963
pub async fn continue_from_conversation_history(
6064
&self,
6165
request_id: String,
62-
continue_from_conversation_history_params: ContinueFromConversationHistoryParams,
66+
params: ContinueFromConversationHistoryParams,
6367
) -> Result<ReceiveTokensController> {
6468
self.receiver_from_message(
6569
request_id.clone(),
6670
AgentJsonRpcMessage::Request(RequestEnvelope {
6771
id: request_id,
68-
request: AgentJsonRpcRequest::ContinueFromConversationHistory(
69-
continue_from_conversation_history_params.clone(),
70-
),
72+
request: AgentJsonRpcRequest::ContinueFromConversationHistory(params.clone()),
7173
}),
7274
)
7375
.await
@@ -76,20 +78,32 @@ impl AgentController {
7678
pub async fn continue_from_raw_prompt(
7779
&self,
7880
request_id: String,
79-
continue_from_raw_prompt_params: ContinueFromRawPromptParams,
81+
params: ContinueFromRawPromptParams,
8082
) -> Result<ReceiveTokensController> {
8183
self.receiver_from_message(
8284
request_id.clone(),
8385
AgentJsonRpcMessage::Request(RequestEnvelope {
8486
id: request_id,
85-
request: AgentJsonRpcRequest::ContinueFromRawPrompt(
86-
continue_from_raw_prompt_params.clone(),
87-
),
87+
request: AgentJsonRpcRequest::ContinueFromRawPrompt(params.clone()),
8888
}),
8989
)
9090
.await
9191
}
9292

93+
pub async fn generate_embedding_batch(
94+
&self,
95+
embedding_rx: mpsc::UnboundedReceiver<EmbeddingResult>,
96+
embedding_tx: mpsc::UnboundedSender<EmbeddingResult>,
97+
request_id: String,
98+
params: GenerateEmbeddingBatchParams,
99+
) -> Result<ManagesSendersController<EmbeddingSenderCollection>> {
100+
Ok(ManagesSendersController {
101+
request_id,
102+
response_rx: embedding_rx,
103+
response_sender_collection: self.embedding_sender_collection.clone(),
104+
})
105+
}
106+
93107
pub async fn get_chat_template_override(
94108
&self,
95109
) -> Result<ManagesSendersController<ChatTemplateOverrideSenderCollection>> {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use async_trait::async_trait;
2+
use dashmap::DashMap;
3+
use tokio::sync::mpsc;
4+
5+
use crate::balancer::manages_senders::ManagesSenders;
6+
use crate::embedding_result::EmbeddingResult;
7+
8+
pub struct EmbeddingSenderCollection {
9+
senders: DashMap<String, mpsc::UnboundedSender<EmbeddingResult>>,
10+
}
11+
12+
impl EmbeddingSenderCollection {
13+
pub fn new() -> Self {
14+
Self {
15+
senders: DashMap::new(),
16+
}
17+
}
18+
}
19+
20+
#[async_trait]
21+
impl ManagesSenders for EmbeddingSenderCollection {
22+
type Value = EmbeddingResult;
23+
24+
fn get_sender_collection(&self) -> &DashMap<String, mpsc::UnboundedSender<Self::Value>> {
25+
&self.senders
26+
}
27+
}

src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,17 @@ use actix_web::web;
77
use actix_web::Error;
88
use actix_web::HttpResponse;
99
use actix_web::Responder;
10-
use async_trait::async_trait;
1110
use bytes::Bytes;
1211
use futures::stream::StreamExt;
13-
use log::error;
1412
use nanoid::nanoid;
1513
use tokio::sync::broadcast;
1614
use tokio::sync::mpsc;
1715
use tokio_stream::wrappers::UnboundedReceiverStream;
1816

1917
use crate::balancer::inference_service::app_data::AppData;
20-
use crate::balancer::inference_service::chunk_forwarding_session_controller::ChunkForwardingSessionController;
21-
use crate::balancer::inference_service::controls_inference_endpoint::ControlsInferenceEndpoint;
2218
use crate::request_params::GenerateEmbeddingBatchParams;
2319

24-
const CHARACTERS_PER_TOKEN_MORE_OR_LESS: usize = 4;
20+
const CHARACTERS_PER_TOKEN_APPROXIMATELY: usize = 3;
2521

2622
pub fn register(cfg: &mut web::ServiceConfig) {
2723
cfg.service(respond);
@@ -48,12 +44,12 @@ async fn respond(
4844
));
4945
}
5046

51-
let request_id: String = nanoid!();
5247
let (connection_close_tx, mut connection_close_rx) = broadcast::channel::<()>(1);
5348
let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
5449

5550
let request_batches = params.chunk_by_input_size(
56-
agent_desired_state.inference_parameters.batch_n_tokens * CHARACTERS_PER_TOKEN_MORE_OR_LESS,
51+
agent_desired_state.inference_parameters.batch_n_tokens
52+
* CHARACTERS_PER_TOKEN_APPROXIMATELY,
5753
);
5854

5955
rt::spawn(async move {});

src/balancer/inference_service/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ impl Service for InferenceService {
8080
.configure(common_http_route::get_health::register)
8181
.configure(http_route::api::post_continue_from_conversation_history::register)
8282
.configure(http_route::api::post_continue_from_raw_prompt::register)
83+
.configure(http_route::api::post_generate_embedding_batch::register)
8384
.configure(http_route::api::ws_inference_socket::register)
8485
})
8586
.shutdown_signal(async move {

src/balancer/management_service/app_data.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::sync::Arc;
33
use crate::balancer::agent_controller_pool::AgentControllerPool;
44
use crate::balancer::buffered_request_manager::BufferedRequestManager;
55
use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection;
6+
use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection;
67
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
78
use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection;
89
use crate::balancer::state_database::StateDatabase;
@@ -13,6 +14,7 @@ pub struct AppData {
1314
pub balancer_applicable_state_holder: Arc<BalancerApplicableStateHolder>,
1415
pub buffered_request_manager: Arc<BufferedRequestManager>,
1516
pub chat_template_override_sender_collection: Arc<ChatTemplateOverrideSenderCollection>,
17+
pub embedding_sender_collection: Arc<EmbeddingSenderCollection>,
1618
pub generate_tokens_sender_collection: Arc<GenerateTokensSenderCollection>,
1719
pub model_metadata_sender_collection: Arc<ModelMetadataSenderCollection>,
1820
pub state_database: Arc<dyn StateDatabase>,

src/balancer/management_service/http_route/api/ws_agent_socket/agent_socket_controller_context.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@ use log::info;
55

66
use crate::balancer::agent_controller_pool::AgentControllerPool;
77
use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection;
8+
use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection;
89
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
910
use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection;
1011
use crate::balancer_applicable_state_holder::BalancerApplicableStateHolder;
1112

1213
pub struct AgentSocketControllerContext {
13-
pub balancer_applicable_state_holder: Arc<BalancerApplicableStateHolder>,
1414
pub agent_controller_pool: Arc<AgentControllerPool>,
1515
pub agent_id: String,
16+
pub balancer_applicable_state_holder: Arc<BalancerApplicableStateHolder>,
1617
pub chat_template_override_sender_collection: Arc<ChatTemplateOverrideSenderCollection>,
18+
pub embedding_sender_collection: Arc<EmbeddingSenderCollection>,
1719
pub generate_tokens_sender_collection: Arc<GenerateTokensSenderCollection>,
1820
pub model_metadata_sender_collection: Arc<ModelMetadataSenderCollection>,
1921
}

src/balancer/management_service/http_route/api/ws_agent_socket/mod.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use crate::balancer::agent_controller::AgentController;
4141
use crate::balancer::agent_controller_pool::AgentControllerPool;
4242
use crate::balancer::agent_controller_update_result::AgentControllerUpdateResult;
4343
use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection;
44+
use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection;
4445
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
4546
use crate::balancer::management_service::app_data::AppData;
4647
use crate::balancer::manages_senders::ManagesSenders as _;
@@ -63,6 +64,7 @@ struct AgentSocketController {
6364
agent_id: String,
6465
balancer_applicable_state_holder: Arc<BalancerApplicableStateHolder>,
6566
chat_template_override_sender_collection: Arc<ChatTemplateOverrideSenderCollection>,
67+
embedding_sender_collection: Arc<EmbeddingSenderCollection>,
6668
generate_tokens_sender_collection: Arc<GenerateTokensSenderCollection>,
6769
model_metadata_sender_collection: Arc<ModelMetadataSenderCollection>,
6870
}
@@ -81,6 +83,7 @@ impl ControlsWebSocketEndpoint for AgentSocketController {
8183
chat_template_override_sender_collection: self
8284
.chat_template_override_sender_collection
8385
.clone(),
86+
embedding_sender_collection: self.embedding_sender_collection.clone(),
8487
generate_tokens_sender_collection: self.generate_tokens_sender_collection.clone(),
8588
model_metadata_sender_collection: self.model_metadata_sender_collection.clone(),
8689
}
@@ -136,6 +139,7 @@ impl ControlsWebSocketEndpoint for AgentSocketController {
136139
download_current: AtomicValue::<AtomicUsize>::new(download_current),
137140
download_filename: RwLock::new(download_filename),
138141
download_total: AtomicValue::<AtomicUsize>::new(download_total),
142+
embedding_sender_collection: context.embedding_sender_collection.clone(),
139143
generate_tokens_sender_collection: context
140144
.generate_tokens_sender_collection
141145
.clone(),
@@ -241,6 +245,17 @@ impl ControlsWebSocketEndpoint for AgentSocketController {
241245

242246
Ok(ContinuationDecision::Continue)
243247
}
248+
ManagementJsonRpcMessage::Response(ResponseEnvelope {
249+
request_id,
250+
response: AgentJsonRpcResponse::Embedding(embedding_result),
251+
}) => {
252+
context
253+
.embedding_sender_collection
254+
.forward_response_safe(request_id, embedding_result)
255+
.await;
256+
257+
Ok(ContinuationDecision::Continue)
258+
}
244259
ManagementJsonRpcMessage::Response(ResponseEnvelope {
245260
request_id,
246261
response: AgentJsonRpcResponse::GeneratedToken(generated_token_envelope),
@@ -310,6 +325,7 @@ async fn respond(
310325
chat_template_override_sender_collection: app_data
311326
.chat_template_override_sender_collection
312327
.clone(),
328+
embedding_sender_collection: app_data.embedding_sender_collection.clone(),
313329
generate_tokens_sender_collection: app_data.generate_tokens_sender_collection.clone(),
314330
model_metadata_sender_collection: app_data.model_metadata_sender_collection.clone(),
315331
};

src/balancer/management_service/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use tokio::sync::broadcast;
1515
use crate::balancer::agent_controller_pool::AgentControllerPool;
1616
use crate::balancer::buffered_request_manager::BufferedRequestManager;
1717
use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection;
18+
use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection;
1819
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
1920
use crate::balancer::http_route as common_http_route;
2021
use crate::balancer::management_service::app_data::AppData;
@@ -33,6 +34,7 @@ pub struct ManagementService {
3334
pub buffered_request_manager: Arc<BufferedRequestManager>,
3435
pub chat_template_override_sender_collection: Arc<ChatTemplateOverrideSenderCollection>,
3536
pub configuration: ManagementServiceConfiguration,
37+
pub embedding_sender_collection: Arc<EmbeddingSenderCollection>,
3638
pub generate_tokens_sender_collection: Arc<GenerateTokensSenderCollection>,
3739
pub model_metadata_sender_collection: Arc<ModelMetadataSenderCollection>,
3840
pub state_database: Arc<dyn StateDatabase>,
@@ -64,6 +66,7 @@ impl Service for ManagementService {
6466
chat_template_override_sender_collection: self
6567
.chat_template_override_sender_collection
6668
.clone(),
69+
embedding_sender_collection: self.embedding_sender_collection.clone(),
6770
generate_tokens_sender_collection: self.generate_tokens_sender_collection.clone(),
6871
model_metadata_sender_collection: self.model_metadata_sender_collection.clone(),
6972
state_database: self.state_database.clone(),

src/balancer/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub mod buffered_request_manager;
1111
mod buffered_request_manager_snapshot;
1212
pub mod chat_template_override_sender_collection;
1313
mod controls_manages_senders_endpoint;
14+
pub mod embedding_sender_collection;
1415
pub mod generate_tokens_sender_collection;
1516
mod http_route;
1617
pub mod inference_service;

0 commit comments

Comments
 (0)