Commit 5ac8574 — "send embeddings back through channel"

Browse files · committed · 1 parent: 74a6846

File tree

6 files changed: +45 additions, −11 deletions
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
use actix::Message;
22
use anyhow::Result;
3+
use tokio::sync::mpsc;
34

5+
use crate::embedding_result::EmbeddingResult;
46
use crate::request_params::GenerateEmbeddingBatchParams;
57

68
#[derive(Debug, Message)]
79
#[rtype(result = "Result<()>")]
810
pub struct GenerateEmbeddingBatchRequest {
11+
pub generate_embedding_stop_rx: mpsc::UnboundedReceiver<()>,
12+
pub generated_embedding_tx: mpsc::UnboundedSender<EmbeddingResult>,
913
pub params: GenerateEmbeddingBatchParams,
1014
}

src/agent/jsonrpc/request.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@ use serde::Serialize;
33

44
use crate::request_params::ContinueFromConversationHistoryParams;
55
use crate::request_params::ContinueFromRawPromptParams;
6+
use crate::request_params::GenerateEmbeddingBatchParams;
67

78
#[derive(Deserialize, Serialize)]
89
pub enum Request {
910
ContinueFromConversationHistory(ContinueFromConversationHistoryParams),
1011
ContinueFromRawPrompt(ContinueFromRawPromptParams),
12+
GenerateEmbeddingBatch(GenerateEmbeddingBatchParams),
1113
GetChatTemplateOverride,
1214
GetModelMetadata,
1315
}

src/agent/llamacpp_arbiter_service.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use tokio::time::MissedTickBehavior;
1919

2020
use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest;
2121
use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest;
22+
use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest;
2223
use crate::agent::llamacpp_arbiter::LlamaCppArbiter;
2324
use crate::agent::llamacpp_arbiter_handle::LlamaCppArbiterHandle;
2425
use crate::agent::llamacpp_slot::LlamaCppSlot;
@@ -39,6 +40,7 @@ pub struct LlamaCppArbiterService {
3940
mpsc::UnboundedReceiver<ContinueFromConversationHistoryRequest>,
4041
pub continue_from_raw_prompt_request_rx: mpsc::UnboundedReceiver<ContinueFromRawPromptRequest>,
4142
pub desired_slots_total: i32,
43+
pub generate_embedding_batch_request_rx: mpsc::UnboundedReceiver<GenerateEmbeddingBatchRequest>,
4244
pub llamacpp_arbiter_handle: Option<LlamaCppArbiterHandle>,
4345
pub model_metadata_holder: Arc<ModelMetadataHolder>,
4446
pub slot_aggregated_status_manager: Arc<SlotAggregatedStatusManager>,
@@ -143,7 +145,7 @@ impl LlamaCppArbiterService {
143145
Ok(())
144146
}
145147

146-
async fn generate_tokens<TRequest>(
148+
async fn forward_request_to_arbiter<TRequest>(
147149
&mut self,
148150
request: TRequest,
149151
mut shutdown: broadcast::Receiver<()>,
@@ -158,11 +160,11 @@ impl LlamaCppArbiterService {
158160
rt::spawn(async move {
159161
tokio::select! {
160162
_ = shutdown.recv() => {
161-
error!("Shutdown received, stopping ContinueFromRawPromptRequest processing");
163+
error!("Shutdown received, stopping request processing");
162164
}
163165
result = llamacpp_slot_addr.send(request) => {
164166
if let Err(err) = result {
165-
error!("Failed to send ContinueFromRawPromptRequest: {err}");
167+
error!("Failed to forward request to arbiter: {err}");
166168
}
167169
}
168170
}
@@ -222,7 +224,7 @@ impl Service for LlamaCppArbiterService {
222224
continue_from_conversation_history_request = self.continue_from_conversation_history_request_rx.recv() => {
223225
match continue_from_conversation_history_request {
224226
Some(continue_from_conversation_history_request) => {
225-
self.generate_tokens(
227+
self.forward_request_to_arbiter(
226228
continue_from_conversation_history_request,
227229
shutdown.resubscribe(),
228230
).await
@@ -235,7 +237,7 @@ impl Service for LlamaCppArbiterService {
235237
continue_from_raw_prompt_request = self.continue_from_raw_prompt_request_rx.recv() => {
236238
match continue_from_raw_prompt_request {
237239
Some(continue_from_raw_prompt_request) => {
238-
self.generate_tokens(
240+
self.forward_request_to_arbiter(
239241
continue_from_raw_prompt_request,
240242
shutdown.resubscribe(),
241243
).await
@@ -245,6 +247,19 @@ impl Service for LlamaCppArbiterService {
245247
}
246248
}
247249
}
250+
generate_embedding_batch_request = self.generate_embedding_batch_request_rx.recv() => {
251+
match generate_embedding_batch_request {
252+
Some(generate_embedding_batch_request) => {
253+
self.forward_request_to_arbiter(
254+
generate_embedding_batch_request,
255+
shutdown.resubscribe(),
256+
).await
257+
}
258+
None => {
259+
break Err(anyhow!("GenerateEmbeddingBatchRequest channel closed unexpectedly"));
260+
}
261+
}
262+
}
248263
}
249264
}
250265
}

src/agent/llamacpp_slot.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ use crate::agent::llamacpp_slot_context::LlamaCppSlotContext;
3131
use crate::embedding::Embedding;
3232
use crate::embedding_input_tokenized::EmbeddingInputTokenized;
3333
use crate::embedding_normalization_method::EmbeddingNormalizationMethod;
34+
use crate::embedding_result::EmbeddingResult;
3435
use crate::generated_token_envelope::GeneratedTokenEnvelope;
3536
use crate::generated_token_result::GeneratedTokenResult;
3637
use crate::request_params::ContinueFromConversationHistoryParams;
@@ -123,8 +124,8 @@ impl LlamaCppSlot {
123124
&mut self,
124125
batch: &mut LlamaBatch,
125126
current_batch_embeddings: &Vec<&EmbeddingInputTokenized>,
127+
generated_embedding_tx: &mpsc::UnboundedSender<EmbeddingResult>,
126128
normalization_method: &EmbeddingNormalizationMethod,
127-
output: &mut Vec<Embedding>,
128129
) -> Result<()> {
129130
self.llama_context.clear_kv_cache();
130131
self.llama_context.decode(batch)?;
@@ -135,14 +136,14 @@ impl LlamaCppSlot {
135136
.embeddings_seq_ith(index as i32)
136137
.context("Failed to get embeddings")?;
137138

138-
output.push(
139+
generated_embedding_tx.send(EmbeddingResult::Embedding(
139140
Embedding {
140141
embedding: embedding.to_vec(),
141142
normalization_method: EmbeddingNormalizationMethod::None,
142143
source_document_id: embedding_input_tokenized.id.clone(),
143144
}
144145
.normalize(normalization_method)?,
145-
);
146+
))?;
146147
}
147148

148149
batch.clear();
@@ -369,6 +370,8 @@ impl Handler<GenerateEmbeddingBatchRequest> for LlamaCppSlot {
369370
fn handle(
370371
&mut self,
371372
GenerateEmbeddingBatchRequest {
373+
generate_embedding_stop_rx,
374+
generated_embedding_tx,
372375
params:
373376
GenerateEmbeddingBatchParams {
374377
input_batch,
@@ -408,7 +411,6 @@ impl Handler<GenerateEmbeddingBatchRequest> for LlamaCppSlot {
408411

409412
let mut batch = LlamaBatch::new(self.slot_context.inference_parameters.batch_n_tokens, 1);
410413
let mut current_batch_embeddings: Vec<&EmbeddingInputTokenized> = Vec::new();
411-
let mut output = Vec::with_capacity(tokens_lines_list.len());
412414

413415
for embedding_input_tokenized in &tokens_lines_list {
414416
// Flush the batch if the next prompt would exceed our batch size
@@ -418,8 +420,8 @@ impl Handler<GenerateEmbeddingBatchRequest> for LlamaCppSlot {
418420
self.embedding_batch_decode(
419421
&mut batch,
420422
&current_batch_embeddings,
423+
&generated_embedding_tx,
421424
&normalization_method,
422-
&mut output,
423425
)?;
424426

425427
current_batch_embeddings.clear();
@@ -436,8 +438,8 @@ impl Handler<GenerateEmbeddingBatchRequest> for LlamaCppSlot {
436438
self.embedding_batch_decode(
437439
&mut batch,
438440
&current_batch_embeddings,
441+
&generated_embedding_tx,
439442
&normalization_method,
440-
&mut output,
441443
)?;
442444

443445
Ok(())

src/agent/management_socket_client_service.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use crate::agent::from_request_params::FromRequestParams;
3434
use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::Message as ManagementJsonRpcMessage;
3535
use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::Notification as ManagementJsonRpcNotification;
3636
use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest;
37+
use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest;
3738
use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::notification_params::RegisterAgentParams;
3839
use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonrpc::notification_params::UpdateAgentStatusParams;
3940
use crate::jsonrpc::Error as JsonRpcError;
@@ -63,6 +64,7 @@ pub struct ManagementSocketClientService {
6364
pub continue_from_conversation_history_request_tx:
6465
mpsc::UnboundedSender<ContinueFromConversationHistoryRequest>,
6566
pub continue_from_raw_prompt_request_tx: mpsc::UnboundedSender<ContinueFromRawPromptRequest>,
67+
pub generate_embedding_batch_request_tx: mpsc::UnboundedSender<GenerateEmbeddingBatchRequest>,
6668
pub model_metadata_holder: Arc<ModelMetadataHolder>,
6769
pub name: Option<String>,
6870
pub receive_tokens_stopper_collection: Arc<ReceiveTokensStopperCollection>,
@@ -203,6 +205,10 @@ impl ManagementSocketClientService {
203205
)
204206
.await
205207
}
208+
JsonRpcMessage::Request(RequestEnvelope {
209+
id,
210+
request: JsonRpcRequest::GenerateEmbeddingBatch(generate_embedding_batch_params),
211+
}) => Ok(()),
206212
JsonRpcMessage::Request(RequestEnvelope {
207213
id,
208214
request: JsonRpcRequest::GetChatTemplateOverride,

src/cmd/agent.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use super::handler::Handler;
1212
use super::parse_socket_addr;
1313
use crate::agent::continue_from_conversation_history_request::ContinueFromConversationHistoryRequest;
1414
use crate::agent::continue_from_raw_prompt_request::ContinueFromRawPromptRequest;
15+
use crate::agent::generate_embedding_batch_request::GenerateEmbeddingBatchRequest;
1516
use crate::agent::llamacpp_arbiter_service::LlamaCppArbiterService;
1617
use crate::agent::management_socket_client_service::ManagementSocketClientService;
1718
use crate::agent::model_metadata_holder::ModelMetadataHolder;
@@ -47,6 +48,8 @@ impl Handler for Agent {
4748
) = mpsc::unbounded_channel::<ContinueFromConversationHistoryRequest>();
4849
let (continue_from_raw_prompt_request_tx, continue_from_raw_prompt_request_rx) =
4950
mpsc::unbounded_channel::<ContinueFromRawPromptRequest>();
51+
let (generate_embedding_batch_request_tx, generate_embedding_batch_request_rx) =
52+
mpsc::unbounded_channel::<GenerateEmbeddingBatchRequest>();
5053

5154
let agent_applicable_state_holder = Arc::new(AgentApplicableStateHolder::new());
5255
let model_metadata_holder = Arc::new(ModelMetadataHolder::new());
@@ -60,6 +63,7 @@ impl Handler for Agent {
6063
continue_from_conversation_history_request_rx,
6164
continue_from_raw_prompt_request_rx,
6265
desired_slots_total: self.slots,
66+
generate_embedding_batch_request_rx,
6367
llamacpp_arbiter_handle: None,
6468
model_metadata_holder: model_metadata_holder.clone(),
6569
slot_aggregated_status_manager: slot_aggregated_status_manager.clone(),
@@ -70,6 +74,7 @@ impl Handler for Agent {
7074
agent_desired_state_tx,
7175
continue_from_conversation_history_request_tx,
7276
continue_from_raw_prompt_request_tx,
77+
generate_embedding_batch_request_tx,
7378
model_metadata_holder,
7479
name: self.name.clone(),
7580
receive_tokens_stopper_collection: Arc::new(ReceiveTokensStopperCollection::new()),

Comments: 0