@@ -20,7 +20,7 @@ use tokio_tungstenite::connect_async;
20
20
use tokio_tungstenite:: tungstenite:: protocol:: Message ;
21
21
22
22
use crate :: agent_desired_state:: AgentDesiredState ;
23
- use crate :: agent:: receive_tokens_stopper_collection :: ReceiveTokensStopperCollection ;
23
+ use crate :: agent:: receive_stream_stopper_collection :: ReceiveStreamStopperCollection ;
24
24
use crate :: agent:: jsonrpc:: Message as JsonRpcMessage ;
25
25
use crate :: agent:: jsonrpc:: Notification as JsonRpcNotification ;
26
26
use crate :: agent:: jsonrpc:: Request as JsonRpcRequest ;
@@ -40,7 +40,6 @@ use crate::balancer::management_service::http_route::api::ws_agent_socket::jsonr
40
40
use crate :: jsonrpc:: Error as JsonRpcError ;
41
41
use crate :: jsonrpc:: ResponseEnvelope ;
42
42
use crate :: jsonrpc:: RequestEnvelope ;
43
- use crate :: generated_token_envelope:: GeneratedTokenEnvelope ;
44
43
use crate :: produces_snapshot:: ProducesSnapshot ;
45
44
use crate :: jsonrpc:: ErrorEnvelope ;
46
45
use crate :: service:: Service ;
@@ -53,8 +52,9 @@ struct IncomingMessageContext {
53
52
continue_from_conversation_history_request_tx :
54
53
mpsc:: UnboundedSender < ContinueFromConversationHistoryRequest > ,
55
54
continue_from_raw_prompt_request_tx : mpsc:: UnboundedSender < ContinueFromRawPromptRequest > ,
55
+ generate_embedding_batch_request_tx : mpsc:: UnboundedSender < GenerateEmbeddingBatchRequest > ,
56
56
model_metadata_holder : Arc < ModelMetadataHolder > ,
57
- receive_tokens_stopper_collection : Arc < ReceiveTokensStopperCollection > ,
57
+ receive_stream_stopper_collection : Arc < ReceiveStreamStopperCollection > ,
58
58
message_tx : mpsc:: UnboundedSender < ManagementJsonRpcMessage > ,
59
59
}
60
60
@@ -67,47 +67,46 @@ pub struct ManagementSocketClientService {
67
67
pub generate_embedding_batch_request_tx : mpsc:: UnboundedSender < GenerateEmbeddingBatchRequest > ,
68
68
pub model_metadata_holder : Arc < ModelMetadataHolder > ,
69
69
pub name : Option < String > ,
70
- pub receive_tokens_stopper_collection : Arc < ReceiveTokensStopperCollection > ,
70
+ pub receive_stream_stopper_collection : Arc < ReceiveStreamStopperCollection > ,
71
71
pub slot_aggregated_status : Arc < SlotAggregatedStatus > ,
72
72
pub socket_url : String ,
73
73
}
74
74
75
75
impl ManagementSocketClientService {
76
- async fn generate_tokens < TRequest : FromRequestParams + ' static > (
76
+ async fn generate_responses < TRequest : FromRequestParams + ' static > (
77
77
connection_close_tx : broadcast:: Sender < ( ) > ,
78
78
id : String ,
79
79
message_tx : mpsc:: UnboundedSender < ManagementJsonRpcMessage > ,
80
80
request_params : TRequest :: RequestParams ,
81
- receive_tokens_stopper_collection : Arc < ReceiveTokensStopperCollection > ,
81
+ receive_stream_stopper_collection : Arc < ReceiveStreamStopperCollection > ,
82
82
request_tx : mpsc:: UnboundedSender < TRequest > ,
83
83
) -> Result < ( ) > {
84
- let ( generated_tokens_tx, mut generated_tokens_rx) =
85
- mpsc:: unbounded_channel :: < GeneratedTokenEnvelope > ( ) ;
86
- let ( generate_tokens_stop_tx, generate_tokens_stop_rx) = mpsc:: unbounded_channel :: < ( ) > ( ) ;
84
+ let ( response_tx, mut response_rx) = mpsc:: unbounded_channel :: < TRequest :: Response > ( ) ;
85
+ let ( stop_tx, stop_rx) = mpsc:: unbounded_channel :: < ( ) > ( ) ;
87
86
88
- let _guard = receive_tokens_stopper_collection
89
- . register_stopper_with_guard ( id. clone ( ) , generate_tokens_stop_tx )
90
- . context ( format ! ( "Failed to register stopper for request ID : {id}" ) ) ?;
87
+ let _guard = receive_stream_stopper_collection
88
+ . register_stopper_with_guard ( id. clone ( ) , stop_tx )
89
+ . context ( format ! ( "Failed to register stopper for request: {id}" ) ) ?;
91
90
92
91
request_tx. send ( TRequest :: from_request_params (
93
92
request_params,
94
- generate_tokens_stop_rx ,
95
- generated_tokens_tx ,
93
+ response_tx ,
94
+ stop_rx ,
96
95
) ) ?;
97
96
98
97
let mut connection_close_rx = connection_close_tx. subscribe ( ) ;
99
98
100
99
loop {
101
100
tokio:: select! {
102
101
_ = connection_close_rx. recv( ) => break ,
103
- generated_token_envelope = generated_tokens_rx . recv( ) => {
104
- match generated_token_envelope {
105
- Some ( generated_token_envelope ) => {
102
+ response = response_rx . recv( ) => {
103
+ match response {
104
+ Some ( response ) => {
106
105
message_tx. send(
107
106
ManagementJsonRpcMessage :: Response (
108
107
ResponseEnvelope {
109
108
request_id: id. clone( ) ,
110
- response: JsonRpcResponse :: GeneratedToken ( generated_token_envelope ) ,
109
+ response: response . into ( ) ,
111
110
}
112
111
) ,
113
112
) ?;
@@ -128,9 +127,10 @@ impl ManagementSocketClientService {
128
127
connection_close_tx,
129
128
continue_from_conversation_history_request_tx,
130
129
continue_from_raw_prompt_request_tx,
130
+ generate_embedding_batch_request_tx,
131
131
message_tx,
132
132
model_metadata_holder,
133
- receive_tokens_stopper_collection ,
133
+ receive_stream_stopper_collection ,
134
134
} : IncomingMessageContext ,
135
135
deserialized_message : JsonRpcMessage ,
136
136
) -> Result < ( ) > {
@@ -154,7 +154,7 @@ impl ManagementSocketClientService {
154
154
}
155
155
JsonRpcMessage :: Notification ( JsonRpcNotification :: StopGeneratingTokens ( request_id) ) => {
156
156
debug ! ( "Received StopGeneratingTokens notification for request ID: {request_id:?}" ) ;
157
- receive_tokens_stopper_collection
157
+ receive_stream_stopper_collection
158
158
. stop ( request_id. clone ( ) )
159
159
. context ( format ! (
160
160
"Failed to stop generating tokens for request ID: {request_id}"
@@ -181,12 +181,12 @@ impl ManagementSocketClientService {
181
181
continue_from_conversation_history_params,
182
182
) ,
183
183
} ) => {
184
- Self :: generate_tokens (
184
+ Self :: generate_responses (
185
185
connection_close_tx,
186
186
id,
187
187
message_tx,
188
188
continue_from_conversation_history_params,
189
- receive_tokens_stopper_collection ,
189
+ receive_stream_stopper_collection ,
190
190
continue_from_conversation_history_request_tx,
191
191
)
192
192
. await
@@ -195,20 +195,30 @@ impl ManagementSocketClientService {
195
195
id,
196
196
request : JsonRpcRequest :: ContinueFromRawPrompt ( generate_tokens_params) ,
197
197
} ) => {
198
- Self :: generate_tokens (
198
+ Self :: generate_responses (
199
199
connection_close_tx,
200
200
id,
201
201
message_tx,
202
202
generate_tokens_params,
203
- receive_tokens_stopper_collection ,
203
+ receive_stream_stopper_collection ,
204
204
continue_from_raw_prompt_request_tx,
205
205
)
206
206
. await
207
207
}
208
208
JsonRpcMessage :: Request ( RequestEnvelope {
209
209
id,
210
210
request : JsonRpcRequest :: GenerateEmbeddingBatch ( generate_embedding_batch_params) ,
211
- } ) => Ok ( ( ) ) ,
211
+ } ) => {
212
+ Self :: generate_responses (
213
+ connection_close_tx,
214
+ id,
215
+ message_tx,
216
+ generate_embedding_batch_params,
217
+ receive_stream_stopper_collection,
218
+ generate_embedding_batch_request_tx,
219
+ )
220
+ . await
221
+ }
212
222
JsonRpcMessage :: Request ( RequestEnvelope {
213
223
id,
214
224
request : JsonRpcRequest :: GetChatTemplateOverride ,
@@ -429,8 +439,9 @@ impl ManagementSocketClientService {
429
439
connection_close_tx: connection_close_tx. clone( ) ,
430
440
continue_from_conversation_history_request_tx: self . continue_from_conversation_history_request_tx. clone( ) ,
431
441
continue_from_raw_prompt_request_tx: self . continue_from_raw_prompt_request_tx. clone( ) ,
442
+ generate_embedding_batch_request_tx: self . generate_embedding_batch_request_tx. clone( ) ,
432
443
model_metadata_holder: self . model_metadata_holder. clone( ) ,
433
- receive_tokens_stopper_collection : self . receive_tokens_stopper_collection . clone( ) ,
444
+ receive_stream_stopper_collection : self . receive_stream_stopper_collection . clone( ) ,
434
445
message_tx: message_tx. clone( ) ,
435
446
} ,
436
447
msg,
0 commit comments