Skip to content

Commit b5c20fa

Browse files
committed
distribute embeddings among agents
1 parent 77f784a commit b5c20fa

File tree

9 files changed

+117
-48
lines changed

src/balancer/agent_controller.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ use crate::balancer::agent_controller_update_result::AgentControllerUpdateResult
2424
use crate::balancer::chat_template_override_sender_collection::ChatTemplateOverrideSenderCollection;
2525
use crate::balancer::embedding_sender_collection::EmbeddingSenderCollection;
2626
use crate::balancer::generate_tokens_sender_collection::GenerateTokensSenderCollection;
27-
use crate::balancer::handles_agent_request::HandlesAgentRequest;
28-
use crate::balancer::manages_senders::ManagesSenders as _;
27+
use crate::balancer::handles_agent_streaming_response::HandlesAgentStreamingResponse;
2928
use crate::balancer::manages_senders::ManagesSenders;
3029
use crate::balancer::manages_senders_controller::ManagesSendersController;
3130
use crate::balancer::model_metadata_sender_collection::ModelMetadataSenderCollection;
@@ -235,7 +234,7 @@ impl AgentController {
235234
}
236235

237236
#[async_trait]
238-
impl HandlesAgentRequest<ContinueFromConversationHistoryParams> for AgentController {
237+
impl HandlesAgentStreamingResponse<ContinueFromConversationHistoryParams> for AgentController {
239238
type SenderCollection = GenerateTokensSenderCollection;
240239

241240
async fn handle(
@@ -256,7 +255,7 @@ impl HandlesAgentRequest<ContinueFromConversationHistoryParams> for AgentControl
256255
}
257256

258257
#[async_trait]
259-
impl HandlesAgentRequest<ContinueFromRawPromptParams> for AgentController {
258+
impl HandlesAgentStreamingResponse<ContinueFromRawPromptParams> for AgentController {
260259
type SenderCollection = GenerateTokensSenderCollection;
261260

262261
async fn handle(
@@ -277,7 +276,7 @@ impl HandlesAgentRequest<ContinueFromRawPromptParams> for AgentController {
277276
}
278277

279278
#[async_trait]
280-
impl HandlesAgentRequest<GenerateEmbeddingBatchParams> for AgentController {
279+
impl HandlesAgentStreamingResponse<GenerateEmbeddingBatchParams> for AgentController {
281280
type SenderCollection = EmbeddingSenderCollection;
282281

283282
async fn handle(

src/balancer/handles_agent_request.rs renamed to src/balancer/handles_agent_streaming_response.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::balancer::manages_senders::ManagesSenders;
66
use crate::balancer::manages_senders_controller::ManagesSendersController;
77

88
#[async_trait]
9-
pub trait HandlesAgentRequest<TParams>
9+
pub trait HandlesAgentStreamingResponse<TParams>
1010
where
1111
TParams: Into<AgentJsonRpcRequest>,
1212
{

src/balancer/inference_service/controls_inference_endpoint.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::agent::jsonrpc::Request as AgentJsonRpcRequest;
1313
use crate::balancer::agent_controller::AgentController;
1414
use crate::balancer::buffered_request_agent_wait_result::BufferedRequestAgentWaitResult;
1515
use crate::balancer::buffered_request_manager::BufferedRequestManager;
16-
use crate::balancer::handles_agent_request::HandlesAgentRequest;
16+
use crate::balancer::handles_agent_streaming_response::HandlesAgentStreamingResponse;
1717
use crate::balancer::inference_service::configuration::Configuration as InferenceServiceConfiguration;
1818
use crate::balancer::inference_service::http_route::api::ws_inference_socket::client::Message as OutgoingMessage;
1919
use crate::balancer::inference_service::http_route::api::ws_inference_socket::client::Response as OutgoingResponse;
@@ -39,8 +39,8 @@ pub trait ControlsInferenceEndpoint {
3939
) -> Result<()>
4040
where
4141
TParams: Debug + Into<AgentJsonRpcRequest> + Send,
42-
AgentController: HandlesAgentRequest<TParams>,
43-
<<AgentController as HandlesAgentRequest<TParams>>::SenderCollection as ManagesSenders>::Value: Into<OutgoingResponse> + StreamableResult,
42+
AgentController: HandlesAgentStreamingResponse<TParams>,
43+
<<AgentController as HandlesAgentStreamingResponse<TParams>>::SenderCollection as ManagesSenders>::Value: Into<OutgoingResponse> + StreamableResult,
4444
{
4545
match Self::wait_for_agent_controller(
4646
buffered_request_manager.clone(),
@@ -55,6 +55,8 @@ pub trait ControlsInferenceEndpoint {
5555
match agent_controller.handle(request_id.clone(), params).await {
5656
Ok(receive_response_controller) => receive_response_controller,
5757
Err(err) => {
58+
error!("Failed to handle request {request_id:?}: {err}");
59+
5860
Self::respond_with_error(
5961
JsonRpcError {
6062
code: 500,

src/balancer/inference_service/http_route/api/post_generate_embedding_batch.rs

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use actix_web::Responder;
1010
use async_trait::async_trait;
1111
use bytes::Bytes;
1212
use futures::stream::StreamExt;
13+
use log::error;
1314
use nanoid::nanoid;
1415
use tokio::sync::broadcast;
1516
use tokio::sync::mpsc;
@@ -26,10 +27,10 @@ pub fn register(cfg: &mut web::ServiceConfig) {
2627
cfg.service(respond);
2728
}
2829

29-
struct GenerateEmbeddingBatchController {}
30+
struct Controller {}
3031

3132
#[async_trait]
32-
impl ControlsInferenceEndpoint for GenerateEmbeddingBatchController {
33+
impl ControlsInferenceEndpoint for Controller {
3334
type SessionController = ChunkForwardingSessionController;
3435
}
3536

@@ -57,12 +58,36 @@ async fn respond(
5758
let (connection_close_tx, mut connection_close_rx) = broadcast::channel::<()>(1);
5859
let (chunk_tx, chunk_rx) = mpsc::unbounded_channel();
5960

60-
let request_batches = params.chunk_by_input_size(
61-
agent_desired_state.inference_parameters.batch_n_tokens
62-
* CHARACTERS_PER_TOKEN_APPROXIMATELY,
63-
);
61+
// Distribute the embeddings evenly across the available agents
62+
params
63+
.chunk_by_input_size(
64+
agent_desired_state.inference_parameters.batch_n_tokens
65+
* CHARACTERS_PER_TOKEN_APPROXIMATELY,
66+
)
67+
.for_each(|batch| {
68+
let buffered_request_manager_clone = app_data.buffered_request_manager.clone();
69+
let chunk_tx_clone = chunk_tx.clone();
70+
let connection_close_tx_clone = connection_close_tx.clone();
71+
let inference_service_configuration_clone =
72+
app_data.inference_service_configuration.clone();
6473

65-
rt::spawn(async move {});
74+
rt::spawn(async move {
75+
if let Err(err) = Controller::request_from_agent(
76+
buffered_request_manager_clone,
77+
connection_close_tx_clone,
78+
inference_service_configuration_clone,
79+
batch,
80+
nanoid!(),
81+
ChunkForwardingSessionController {
82+
chunk_tx: chunk_tx_clone,
83+
},
84+
)
85+
.await
86+
{
87+
error!("Failed to handle request: {err}");
88+
}
89+
});
90+
});
6691

6792
let stream = UnboundedReceiverStream::new(chunk_rx)
6893
.map(|chunk: String| Ok::<_, Error>(Bytes::from(format!("{chunk}\n"))))

src/balancer/inference_service/http_route/api/ws_inference_socket/client/response.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
use serde::Deserialize;
22
use serde::Serialize;
33

4+
use crate::embedding_result::EmbeddingResult;
45
use crate::generated_token_result::GeneratedTokenResult;
56

67
#[derive(Deserialize, Serialize)]
78
pub enum Response {
9+
Embedding(EmbeddingResult),
810
GeneratedToken(GeneratedTokenResult),
911
Timeout,
1012
TooManyBufferedRequests,
1113
}
1214

15+
impl From<EmbeddingResult> for Response {
16+
fn from(result: EmbeddingResult) -> Self {
17+
Response::Embedding(result)
18+
}
19+
}
20+
1321
impl From<GeneratedTokenResult> for Response {
1422
fn from(result: GeneratedTokenResult) -> Self {
1523
Response::GeneratedToken(result)

src/balancer/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ pub mod chat_template_override_sender_collection;
1313
mod controls_manages_senders_endpoint;
1414
pub mod embedding_sender_collection;
1515
pub mod generate_tokens_sender_collection;
16-
mod handles_agent_request;
16+
mod handles_agent_streaming_response;
1717
mod http_route;
1818
pub mod inference_service;
1919
pub mod management_service;

src/embedding_result.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,16 @@ use serde::Deserialize;
22
use serde::Serialize;
33

44
use crate::embedding::Embedding;
5+
use crate::streamable_result::StreamableResult;
56

67
#[derive(Debug, Deserialize, Serialize)]
78
pub enum EmbeddingResult {
89
Done,
910
Embedding(Embedding),
1011
}
12+
13+
impl StreamableResult for EmbeddingResult {
14+
fn is_done(&self) -> bool {
15+
matches!(self, EmbeddingResult::Done)
16+
}
17+
}
src/request_params/generate_embedding_batch_params/chunk_by_input_size_iter.rs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
use super::GenerateEmbeddingBatchParams;
2+
use crate::embedding_input_document::EmbeddingInputDocument;
3+
use crate::embedding_normalization_method::EmbeddingNormalizationMethod;
4+
5+
pub struct ChunkByInputSizeIter<'embedding_batch> {
6+
pub chunk_size: usize,
7+
pub current_index: usize,
8+
pub input_batch: &'embedding_batch [EmbeddingInputDocument],
9+
pub normalization_method: &'embedding_batch EmbeddingNormalizationMethod,
10+
}
11+
12+
impl<'embedding_batch> Iterator for ChunkByInputSizeIter<'embedding_batch> {
13+
type Item = GenerateEmbeddingBatchParams;
14+
15+
fn next(&mut self) -> Option<Self::Item> {
16+
if self.current_index >= self.input_batch.len() {
17+
return None;
18+
}
19+
20+
let mut current_batch = Vec::new();
21+
let mut current_size = 0;
22+
23+
while self.current_index < self.input_batch.len() {
24+
let input = &self.input_batch[self.current_index];
25+
let input_size = input.content.chars().count();
26+
27+
if current_size + input_size > self.chunk_size && !current_batch.is_empty() {
28+
break;
29+
}
30+
31+
current_batch.push(input.clone());
32+
current_size += input_size;
33+
self.current_index += 1;
34+
}
35+
36+
if current_batch.is_empty() {
37+
None
38+
} else {
39+
Some(GenerateEmbeddingBatchParams {
40+
input_batch: current_batch,
41+
normalization_method: self.normalization_method.clone(),
42+
})
43+
}
44+
}
45+
}

src/request_params/generate_embedding_batch_params.rs renamed to src/request_params/generate_embedding_batch_params/mod.rs

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
mod chunk_by_input_size_iter;
2+
13
use serde::Deserialize;
24
use serde::Serialize;
35

6+
use self::chunk_by_input_size_iter::ChunkByInputSizeIter;
47
use crate::embedding_input_document::EmbeddingInputDocument;
58
use crate::embedding_normalization_method::EmbeddingNormalizationMethod;
69

@@ -12,37 +15,16 @@ pub struct GenerateEmbeddingBatchParams {
1215

1316
impl GenerateEmbeddingBatchParams {
1417
/// Input size is the total number of characters in the resulting batches.
15-
pub fn chunk_by_input_size(&self, chunk_size: usize) -> Vec<GenerateEmbeddingBatchParams> {
16-
let mut batches: Vec<GenerateEmbeddingBatchParams> = Vec::new();
17-
let mut current_batch: Vec<EmbeddingInputDocument> = Vec::new();
18-
let mut current_size: usize = 0;
19-
20-
for input in &self.input_batch {
21-
let input_size = input.content.chars().count();
22-
23-
current_batch.push(input.clone());
24-
current_size += input_size;
25-
26-
if current_size + input_size > chunk_size {
27-
if !current_batch.is_empty() {
28-
batches.push(GenerateEmbeddingBatchParams {
29-
input_batch: current_batch.clone(),
30-
normalization_method: self.normalization_method.clone(),
31-
});
32-
}
33-
current_batch.clear();
34-
current_size = 0;
35-
}
18+
pub fn chunk_by_input_size<'embedding>(
19+
&'embedding self,
20+
chunk_size: usize,
21+
) -> ChunkByInputSizeIter<'embedding> {
22+
ChunkByInputSizeIter {
23+
input_batch: &self.input_batch,
24+
normalization_method: &self.normalization_method,
25+
chunk_size,
26+
current_index: 0,
3627
}
37-
38-
if !current_batch.is_empty() {
39-
batches.push(GenerateEmbeddingBatchParams {
40-
input_batch: current_batch,
41-
normalization_method: self.normalization_method.clone(),
42-
});
43-
}
44-
45-
batches
4628
}
4729
}
4830

@@ -70,7 +52,8 @@ mod tests {
7052
normalization_method: EmbeddingNormalizationMethod::None,
7153
};
7254

73-
let batches = params.chunk_by_input_size(10);
55+
let batches = params.chunk_by_input_size(10).collect::<Vec<_>>();
56+
7457
assert_eq!(batches.len(), 2);
7558
assert_eq!(batches[0].input_batch.len(), 2);
7659
assert_eq!(batches[0].input_batch[0].id, "1");

0 commit comments

Comments
 (0)