@@ -19,6 +19,7 @@ use tokio::time::MissedTickBehavior;
19
19
20
20
use crate :: agent:: continue_from_conversation_history_request:: ContinueFromConversationHistoryRequest ;
21
21
use crate :: agent:: continue_from_raw_prompt_request:: ContinueFromRawPromptRequest ;
22
+ use crate :: agent:: generate_embedding_batch_request:: GenerateEmbeddingBatchRequest ;
22
23
use crate :: agent:: llamacpp_arbiter:: LlamaCppArbiter ;
23
24
use crate :: agent:: llamacpp_arbiter_handle:: LlamaCppArbiterHandle ;
24
25
use crate :: agent:: llamacpp_slot:: LlamaCppSlot ;
@@ -39,6 +40,7 @@ pub struct LlamaCppArbiterService {
39
40
mpsc:: UnboundedReceiver < ContinueFromConversationHistoryRequest > ,
40
41
pub continue_from_raw_prompt_request_rx : mpsc:: UnboundedReceiver < ContinueFromRawPromptRequest > ,
41
42
pub desired_slots_total : i32 ,
43
+ pub generate_embedding_batch_request_rx : mpsc:: UnboundedReceiver < GenerateEmbeddingBatchRequest > ,
42
44
pub llamacpp_arbiter_handle : Option < LlamaCppArbiterHandle > ,
43
45
pub model_metadata_holder : Arc < ModelMetadataHolder > ,
44
46
pub slot_aggregated_status_manager : Arc < SlotAggregatedStatusManager > ,
@@ -143,7 +145,7 @@ impl LlamaCppArbiterService {
143
145
Ok ( ( ) )
144
146
}
145
147
146
- async fn generate_tokens < TRequest > (
148
+ async fn forward_request_to_arbiter < TRequest > (
147
149
& mut self ,
148
150
request : TRequest ,
149
151
mut shutdown : broadcast:: Receiver < ( ) > ,
@@ -158,11 +160,11 @@ impl LlamaCppArbiterService {
158
160
rt:: spawn ( async move {
159
161
tokio:: select! {
160
162
_ = shutdown. recv( ) => {
161
- error!( "Shutdown received, stopping ContinueFromRawPromptRequest processing" ) ;
163
+ error!( "Shutdown received, stopping request processing" ) ;
162
164
}
163
165
result = llamacpp_slot_addr. send( request) => {
164
166
if let Err ( err) = result {
165
- error!( "Failed to send ContinueFromRawPromptRequest : {err}" ) ;
167
+ error!( "Failed to forward request to arbiter : {err}" ) ;
166
168
}
167
169
}
168
170
}
@@ -222,7 +224,7 @@ impl Service for LlamaCppArbiterService {
222
224
continue_from_conversation_history_request = self . continue_from_conversation_history_request_rx. recv( ) => {
223
225
match continue_from_conversation_history_request {
224
226
Some ( continue_from_conversation_history_request) => {
225
- self . generate_tokens (
227
+ self . forward_request_to_arbiter (
226
228
continue_from_conversation_history_request,
227
229
shutdown. resubscribe( ) ,
228
230
) . await
@@ -235,7 +237,7 @@ impl Service for LlamaCppArbiterService {
235
237
continue_from_raw_prompt_request = self . continue_from_raw_prompt_request_rx. recv( ) => {
236
238
match continue_from_raw_prompt_request {
237
239
Some ( continue_from_raw_prompt_request) => {
238
- self . generate_tokens (
240
+ self . forward_request_to_arbiter (
239
241
continue_from_raw_prompt_request,
240
242
shutdown. resubscribe( ) ,
241
243
) . await
@@ -245,6 +247,19 @@ impl Service for LlamaCppArbiterService {
245
247
}
246
248
}
247
249
}
250
+ generate_embedding_batch_request = self . generate_embedding_batch_request_rx. recv( ) => {
251
+ match generate_embedding_batch_request {
252
+ Some ( generate_embedding_batch_request) => {
253
+ self . forward_request_to_arbiter(
254
+ generate_embedding_batch_request,
255
+ shutdown. resubscribe( ) ,
256
+ ) . await
257
+ }
258
+ None => {
259
+ break Err ( anyhow!( "GenerateEmbeddingBatchRequest channel closed unexpectedly" ) ) ;
260
+ }
261
+ }
262
+ }
248
263
}
249
264
}
250
265
}
0 commit comments