2 changes: 1 addition & 1 deletion pyproject.toml
@@ -88,7 +88,7 @@ dependencies = [
]

[project.optional-dependencies]
litellm = ["litellm[caching]", "diskcache"]
litellm = ["litellm[caching]>=1.66.0", "diskcache"]
tgi = ["text-generation>=0.7.0"]
optimum = ["optimum==1.12.0"]
quantization = ["bitsandbytes>=0.41.0", "auto-gptq>=0.4.2"]
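The tighter `litellm[caching]>=1.66.0` pin is presumably the floor at which the enum-based cache configuration used below in `litellm_model.py` is available. A minimal sketch of that setup, with import paths taken from the diff:

```python
import litellm
from litellm.caching.caching import Cache, LiteLLMCacheType

# Enum-based cache type replaces the old string literal "disk".
litellm.cache = Cache(type=LiteLLMCacheType.DISK)
```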
165 changes: 122 additions & 43 deletions src/lighteval/models/endpoints/litellm_model.py
@@ -23,7 +23,9 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from json import JSONDecodeError

import requests
from tqdm import tqdm

from lighteval.data import GenerativeTaskDataset
@@ -39,14 +41,15 @@

if is_package_available("litellm"):
import litellm
from litellm import encode
from litellm.caching.caching import Cache
from litellm import encode, supports_reasoning
from litellm.caching.caching import Cache, LiteLLMCacheType
from litellm.utils import ModelResponse as LitellmModelResponse
from litellm.utils import get_max_tokens

logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("LiteLLM").handlers.clear()

litellm.cache = Cache(type="disk")
litellm.cache = Cache(type=LiteLLMCacheType.DISK)
else:
from unittest.mock import Mock

@@ -81,6 +84,18 @@ class LiteLLMModelConfig(ModelConfig):
Maximum number of concurrent API requests to execute in parallel.
Higher values can improve throughput for batch processing but may hit rate limits
or exhaust API quotas faster. Default is 10.
verbose (bool):
Whether to enable verbose logging. Default is False.
max_model_length (int | None):
Maximum context length for the model. If None, infers the model's default max length.
api_max_retry (int):
Maximum number of retries for API requests. Default is 8.
api_retry_sleep (float):
Initial sleep time (in seconds) between retries. Default is 1.0.
api_retry_multiplier (float):
Multiplier for increasing sleep time between retries. Default is 2.0.
timeout (float):
Request timeout in seconds. Default is None (no timeout).
generation_parameters (GenerationParameters, optional, defaults to empty GenerationParameters):
Configuration parameters that control text generation behavior, including
temperature, top_p, max_new_tokens, etc.
@@ -108,6 +123,13 @@ class LiteLLMModelConfig(ModelConfig):
base_url: str | None = None
api_key: str | None = None
concurrent_requests: int = 10
verbose: bool = False
max_model_length: int | None = None

api_max_retry: int = 8
api_retry_sleep: float = 1.0
api_retry_multiplier: float = 2.0
timeout: float | None = None
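A minimal usage sketch for the new configuration knobs, assuming the existing `ModelConfig` fields (such as the model identifier) keep their current names; all values are illustrative:

```python
from lighteval.models.endpoints.litellm_model import LiteLLMModelConfig

config = LiteLLMModelConfig(
    model_name="openrouter/meta-llama/llama-3.1-8b-instruct",  # hypothetical model id
    concurrent_requests=10,
    verbose=False,
    max_model_length=None,     # None -> resolved via litellm, then OpenRouter, then the default
    api_max_retry=8,           # up to 8 attempts per request
    api_retry_sleep=1.0,       # first back-off of 1 s
    api_retry_multiplier=2.0,  # 1 s, 2 s, 4 s, ... capped at 64 s in __call_api
    timeout=120.0,             # per-request timeout in seconds; None means no timeout
)
```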


@requires("litellm")
@@ -125,15 +147,17 @@ def __init__(self, config: LiteLLMModelConfig) -> None:
self.api_key = config.api_key
self.generation_parameters = config.generation_parameters
self.concurrent_requests = config.concurrent_requests
self._max_length = config.max_model_length

self.API_MAX_RETRY = 5
self.API_RETRY_SLEEP = 3
self.API_RETRY_MULTIPLIER = 2
self.API_MAX_RETRY = config.api_max_retry
self.API_RETRY_SLEEP = config.api_retry_sleep
self.API_RETRY_MULTIPLIER = config.api_retry_multiplier
self.timeout = config.timeout

self._tokenizer = encode
self.pairwise_tokenization = False
litellm.drop_params = True
litellm.set_verbose = False
litellm.verbose = config.verbose
self.prompt_manager = PromptManager(
use_chat_template=True, tokenizer=self.tokenizer, system_prompt=config.system_prompt
)
@@ -149,58 +173,65 @@ def _prepare_stop_sequence(self, stop_sequence):
stop_sequence = [s for s in stop_sequence if s and s.strip()]
return stop_sequence

def _prepare_max_new_tokens(self, max_new_tokens):
def _prepare_max_new_tokens(self, max_new_tokens) -> int | None:
"""Calculate completion tokens based on max_new_tokens."""
if not max_new_tokens or max_new_tokens <= 0:
return None

if "o1" in self.model:
if supports_reasoning(self.model):
# We need to allow more tokens to include reasoning tokens
max_new_tokens = min(max_new_tokens * 10, 32000)
max_new_tokens = min(max_new_tokens * 10, self.max_length)

logger.warning(
f"Reasoning model detected, increasing max_new_tokens to {max_new_tokens} to allow for reasoning tokens",
)

return max_new_tokens
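As a worked example of the new budget calculation, assuming `supports_reasoning` reports True and `max_length` resolves to a 128k context (both values hypothetical):

```python
max_new_tokens = 2048
max_length = 128_000

# Ten times the requested budget, never exceeding the model's context length.
budget = min(max_new_tokens * 10, max_length)  # -> 20480
```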

def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence): # noqa: C901
"""Make API call with retries."""
response = LitellmModelResponse()
for attempt in range(self.API_MAX_RETRY):
try:
stop_sequence = self._prepare_stop_sequence(stop_sequence)
max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)

if return_logits and not self.provider == "openai":
logger.warning("Returning logits is not supported for this provider, ignoring.")

# Prepare kwargs for completion call
kwargs = {
"model": self.model,
"messages": prompt,
"logprobs": return_logits if self.provider == "openai" else None,
"base_url": self.base_url,
"n": num_samples,
"caching": True,
"api_key": self.api_key,
}
stop_sequence = self._prepare_stop_sequence(stop_sequence)
max_new_tokens = self._prepare_max_new_tokens(max_new_tokens)

if return_logits and not self.provider == "openai":
logger.warning("Returning logits is not supported for this provider, ignoring.")

# Prepare kwargs for completion call
kwargs = {
"model": self.model,
"messages": prompt,
"response_format": {"type": "text"},
"max_tokens": max_new_tokens,
"logprobs": return_logits if self.provider == "openai" else None,
"stop": stop_sequence,
"base_url": self.base_url,
"api_key": self.api_key,
"n": num_samples,
"caching": True,
"timeout": self.timeout,
}

if num_samples > 1 and self.generation_parameters.temperature == 0:
raise ValueError(
"num_samples > 1 but temperature is set to 0, this will not sample different outputs."
)

if "o1" in self.model:
logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
else:
kwargs.update(self.generation_parameters.to_litellm_dict())
if "o1" in self.model:
logger.warning("O1 models do not support temperature, top_p, stop sequence. Disabling.")
else:
kwargs.update(self.generation_parameters.to_litellm_dict())

if kwargs.get("max_completion_tokens", None) is None:
kwargs["max_completion_tokens"] = max_new_tokens
if kwargs.get("max_completion_tokens", None) is None:
kwargs["max_completion_tokens"] = max_new_tokens

for attempt in range(self.API_MAX_RETRY):
try:
response = litellm.completion(**kwargs)
content = response.choices[0].message.content

# If response is empty, retry without caching (maybe the error is recoverable and solved with a retry)
if response.choices[0].message.content is None:
kwargs["caching"] = False
if not content:
logger.info("Response is empty, retrying without caching")
kwargs["caching"] = False
response = litellm.completion(**kwargs)
content = response.choices[0].message.content

return response
except litellm.BadRequestError as e:
if "message" in e.__dict__:
@@ -211,7 +242,9 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, stop_sequence):
logger.warning(f"{error_string}. Returning empty response.")
return LitellmModelResponse()
except Exception as e:
wait_time = min(64, self.API_RETRY_SLEEP * (2**attempt)) # Exponential backoff with max 64s
wait_time = min(
64, self.API_RETRY_SLEEP * (self.API_RETRY_MULTIPLIER**attempt)
) # Exponential backoff with max 64s
logger.warning(
f"Error in API call: {e}, waiting {wait_time} seconds before retry {attempt + 1}/{self.API_MAX_RETRY}"
)
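With the defaults (`api_retry_sleep=1.0`, `api_retry_multiplier=2.0`, `api_max_retry=8`), the retry waits now follow a capped geometric schedule; a small sketch of the resulting sequence:

```python
sleep, multiplier, max_retry = 1.0, 2.0, 8

waits = [min(64, sleep * multiplier**attempt) for attempt in range(max_retry)]
# -> [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 64.0]
```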
@@ -259,6 +292,38 @@ def __call_api_parallel(

return results

def estimate_context_length(self) -> int:
def fallback():
logger.warning("Failed to fetch model endpoint info from OpenRouter, returning default max length.")
return self._DEFAULT_MAX_LENGTH

# If the model is used through openrouter, the actual model name comes after the prefix
model_name = self.model.removeprefix("openrouter/")
endpoint_info_response = requests.get(
f"https://openrouter.ai/api/v1/models/{model_name}/endpoints",
headers={},
)
if endpoint_info_response.ok:
try:
endpoint_info = endpoint_info_response.json()
context_lengths = {
endpoint["provider_name"]: endpoint["context_length"]
for endpoint in endpoint_info["data"]["endpoints"]
}

if self.provider in context_lengths:
return context_lengths[self.provider]

min_length = min(context_lengths.values())
logger.warning(
f"Estimating model context length as the minimum context length from available OpenRouter providers: {min_length}"
)
return min_length
except (KeyError, TypeError, ValueError, JSONDecodeError):
return fallback()

return fallback()
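For reference, the parsing above assumes an OpenRouter payload shaped roughly as follows; field names are taken from the code, while the providers and values are invented for illustration:

```python
endpoint_info = {
    "data": {
        "endpoints": [
            {"provider_name": "ProviderA", "context_length": 131072},
            {"provider_name": "ProviderB", "context_length": 32768},
        ]
    }
}

context_lengths = {e["provider_name"]: e["context_length"] for e in endpoint_info["data"]["endpoints"]}
fallback_length = min(context_lengths.values())  # -> 32768, used when the configured provider is not listed
```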

@cached(SamplingMethod.GENERATIVE)
def greedy_until(
self,
@@ -322,7 +387,21 @@ def add_special_tokens(self) -> bool:
@property
def max_length(self) -> int:
"""Return the maximum sequence length of the model."""
return 4096
if self._max_length is not None:
return self._max_length

try:
max_tokens = get_max_tokens(self.model)
except Exception:
logger.error(
f"Unable to get the maximum sequence length for model {self.model} from litellm. Fetching information from OpenRouter instead."
)
max_tokens = self.estimate_context_length()

# Avoid future requests
self._max_length = max_tokens

return max_tokens
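The resolution order is therefore: an explicit `max_model_length` from the config, then litellm's `get_max_tokens`, then the OpenRouter estimate, and finally the class default (`_DEFAULT_MAX_LENGTH`, referenced in `estimate_context_length` above). A hedged sketch of the litellm lookup step; the exact value returned depends on litellm's model metadata:

```python
from litellm.utils import get_max_tokens

# Illustrative only: takes a model name and returns that model's token limit from
# litellm's model map, or fails for unmapped models (hence the try/except above).
limit = get_max_tokens("gpt-4o")
```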

@cached(SamplingMethod.LOGPROBS)
def loglikelihood(self, docs: list[Doc]) -> list[ModelResponse]: