1 change: 1 addition & 0 deletions router/Cargo.toml
@@ -65,6 +65,7 @@ csv = "1.3.0"
ureq = "=2.9"
pyo3 = { workspace = true }
chrono = "0.4.39"
nvml-wrapper = "0.11.0"


[build-dependencies]
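For reviewers unfamiliar with the crate: nvml-wrapper is the Rust binding for NVIDIA's NVML, and the call this PR builds on, Device::total_energy_consumption(), returns a GPU's cumulative energy counter in millijoules since the driver was last loaded (supported on Volta and newer parts). A minimal standalone probe, assuming a host with an NVIDIA driver; names and error handling here are illustrative:

use nvml_wrapper::Nvml;

fn main() {
    // Init fails cleanly on hosts without an NVIDIA driver, which is
    // why the router below treats NVML as strictly optional.
    let nvml = match Nvml::init() {
        Ok(nvml) => nvml,
        Err(e) => return eprintln!("NVML unavailable: {e}"),
    };
    for i in 0..nvml.device_count().unwrap_or(0) {
        if let Ok(device) = nvml.device_by_index(i) {
            // Cumulative millijoules since the driver was last loaded.
            match device.total_energy_consumption() {
                Ok(mj) => println!("GPU {i}: {mj} mJ"),
                Err(e) => println!("GPU {i}: counter unsupported ({e})"),
            }
        }
    }
}
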
1 change: 1 addition & 0 deletions router/src/chat.rs
@@ -412,6 +412,7 @@ mod tests {
            generated_tokens: 10,
            seed: None,
            finish_reason: FinishReason::Length,
            energy_mj: None,
        }),
    });
    if let ChatEvent::Events(events) = events {
38 changes: 38 additions & 0 deletions router/src/lib.rs
@@ -23,6 +23,10 @@ use tracing::warn;
use utoipa::ToSchema;
use uuid::Uuid;
use validation::Validation;
use nvml_wrapper::Nvml;
use std::sync::OnceLock;

static NVML: OnceLock<Option<Nvml>> = OnceLock::new();

#[allow(clippy::large_enum_variant)]
#[derive(Clone)]
@@ -1468,6 +1472,9 @@ pub(crate) struct Details {
    pub best_of_sequences: Option<Vec<BestOfSequence>>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    pub top_tokens: Vec<Vec<Token>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schema(nullable = true, example = 152)]
    pub energy_mj: Option<u64>,
}

#[derive(Serialize, ToSchema)]
@@ -1498,6 +1505,9 @@ pub(crate) struct StreamDetails {
    pub seed: Option<u64>,
    #[schema(example = 1)]
    pub input_length: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[schema(nullable = true, example = 152)]
    pub energy_mj: Option<u64>,
}
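
Since both structs tag the new field with skip_serializing_if = "Option::is_none", hosts without NVML simply omit energy_mj from the JSON rather than emitting a null. A toy sketch of that serde behavior (DetailsSketch is an illustrative stand-in, not the real struct):

use serde::Serialize;

#[derive(Serialize)]
struct DetailsSketch {
    generated_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    energy_mj: Option<u64>,
}

fn main() {
    let with_gpu = DetailsSketch { generated_tokens: 10, energy_mj: Some(152) };
    let no_gpu = DetailsSketch { generated_tokens: 10, energy_mj: None };
    // -> {"generated_tokens":10,"energy_mj":152}
    println!("{}", serde_json::to_string(&with_gpu).unwrap());
    // -> {"generated_tokens":10}   (the key disappears entirely)
    println!("{}", serde_json::to_string(&no_gpu).unwrap());
}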

#[derive(Serialize, ToSchema, Clone)]
@@ -1546,6 +1556,34 @@ impl Default for ModelsInfo {
}
}

/// Process-wide NVML access; all readings are cumulative counters in millijoules.
pub struct EnergyMonitor;

impl EnergyMonitor {
    /// Initialise NVML at most once; returns None on hosts without an NVIDIA driver.
    fn nvml() -> Option<&'static Nvml> {
        NVML.get_or_init(|| Nvml::init().ok()).as_ref()
    }

    /// Lifetime energy counter of a single GPU, in millijoules.
    pub fn energy_mj(gpu_index: u32) -> Option<u64> {
        let nvml = Self::nvml()?;
        let device = nvml.device_by_index(gpu_index).ok()?;
        device.total_energy_consumption().ok()
    }

    /// Sum of the energy counters of all visible GPUs, in millijoules.
    /// Devices whose counter is unsupported are silently skipped.
    pub fn total_energy_mj() -> Option<u64> {
        let nvml = Self::nvml()?;
        let count = nvml.device_count().ok()?;
        let mut total = 0;
        for i in 0..count {
            if let Ok(device) = nvml.device_by_index(i) {
                if let Ok(energy) = device.total_energy_consumption() {
                    total += energy;
                }
            }
        }
        Some(total)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
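Worth noting: the server code below only ever calls total_energy_mj(); the per-device energy_mj(gpu_index) accessor is exported but unused in this diff. A sketch of the intended before/after sampling pattern, with run_inference as a hypothetical stand-in for the real work:

// Hypothetical helper; `run_inference` stands in for the actual request.
fn measured<T>(run_inference: impl FnOnce() -> T) -> (T, Option<u64>) {
    let before = EnergyMonitor::total_energy_mj();
    let output = run_inference();
    let after = EnergyMonitor::total_energy_mj();
    // The counter is monotonic; saturating_sub guards against a driver
    // reload resetting it between the two samples.
    let delta_mj = match (before, after) {
        (Some(b), Some(a)) => Some(a.saturating_sub(b)),
        _ => None,
    };
    (output, delta_mj)
}

One consequence of summing across all visible GPUs: concurrent requests on the same machine show up in each other's deltas, so the figure is an upper bound per request rather than an exact attribution.
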
17 changes: 16 additions & 1 deletion router/src/server.rs
@@ -26,7 +26,7 @@ use crate::{
    ChatCompletion, ChatCompletionChoice, ChatCompletionChunk, ChatCompletionComplete,
    ChatCompletionDelta, ChatCompletionLogprob, ChatCompletionLogprobs, ChatCompletionTopLogprob,
    ChatRequest, Chunk, CompatGenerateRequest, Completion, CompletionComplete, CompletionFinal,
    CompletionRequest, CompletionType, DeltaToolCall, Function, Prompt, Tool, EnergyMonitor,
};
use crate::{ChatTokenizeResponse, JsonSchemaConfig};
use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
@@ -293,6 +293,7 @@ pub(crate) async fn generate_internal(
    span: tracing::Span,
) -> Result<(HeaderMap, u32, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
    let start_time = Instant::now();
    let start_energy = EnergyMonitor::total_energy_mj();
    metrics::counter!("tgi_request_count").increment(1);

    // Do not log ultra long inputs, like image payloads.
@@ -317,6 +318,12 @@ pub(crate) async fn generate_internal(
        }
        _ => (infer.generate(req).await?, None),
    };

    let end_energy = EnergyMonitor::total_energy_mj();
    let energy_mj = match (start_energy, end_energy) {
        (Some(start), Some(end)) => Some(end.saturating_sub(start)),
        _ => None,
    };

    // Token details
    let input_length = response._input_length;
@@ -354,6 +361,7 @@
                seed: response.generated_text.seed,
                best_of_sequences,
                top_tokens: response.top_tokens,
                energy_mj,
            })
        }
        false => None,
@@ -515,6 +523,7 @@ async fn generate_stream_internal(
    impl Stream<Item = Result<StreamResponse, InferError>>,
) {
    let start_time = Instant::now();
    let start_energy = EnergyMonitor::total_energy_mj();
    metrics::counter!("tgi_request_count").increment(1);

    tracing::debug!("Input: {}", req.inputs);
@@ -590,13 +599,19 @@
            queued,
            top_tokens,
        } => {
            let end_energy = EnergyMonitor::total_energy_mj();
            let energy_mj = match (start_energy, end_energy) {
                (Some(start), Some(end)) => Some(end.saturating_sub(start)),
                _ => None,
            };
            // Token details
            let details = match details {
                true => Some(StreamDetails {
                    finish_reason: generated_text.finish_reason,
                    generated_tokens: generated_text.generated_tokens,
                    seed: generated_text.seed,
                    input_length,
                    energy_mj,
                }),
                false => None,
            };
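
End to end, the measurement surfaces under details in the /generate response when details: true is requested. A hypothetical client-side check (the port and prompt are assumptions; energy_mj is the field added above):

use serde_json::{json, Value};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let body = json!({
        "inputs": "What is deep learning?",
        "parameters": { "max_new_tokens": 20, "details": true }
    });
    let resp: Value = reqwest::Client::new()
        .post("http://localhost:3000/generate") // assumed local TGI router
        .json(&body)
        .send()
        .await?
        .json()
        .await?;
    // An absent key means the router could not initialise NVML.
    match resp["details"]["energy_mj"].as_u64() {
        Some(mj) => println!("request consumed ~{mj} mJ across all GPUs"),
        None => println!("no NVML energy reading available"),
    }
    Ok(())
}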