diff --git a/docs/agents.md b/docs/agents.md index e9efadc0ff..4fee9c8241 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -103,6 +103,90 @@ You can also pass messages from previous runs to continue a conversation or prov ### Additional Configuration +#### Usage Limits + +PydanticAI offers a [`settings.UsageLimits`][pydantic_ai.settings.UsageLimits] structure to help you limit your +usage (tokens and/or requests) on model runs. + +You can apply these settings by passing the `usage_limits` argument to the `run{_sync,_stream}` functions. + +Consider the following example, where we limit the number of response tokens: + +```py +from pydantic_ai import Agent +from pydantic_ai.exceptions import UsageLimitExceeded +from pydantic_ai.settings import UsageLimits + +agent = Agent('claude-3-5-sonnet-latest') + +result_sync = agent.run_sync( + 'What is the capital of Italy? Answer with just the city.', + usage_limits=UsageLimits(response_tokens_limit=10), +) +print(result_sync.data) +#> Rome +print(result_sync.usage()) +""" +Usage(requests=1, request_tokens=62, response_tokens=1, total_tokens=63, details=None) +""" + +try: + result_sync = agent.run_sync( + 'What is the capital of Italy? Answer with a paragraph.', + usage_limits=UsageLimits(response_tokens_limit=10), + ) +except UsageLimitExceeded as e: + print(e) + #> Exceeded the response_tokens_limit of 10 (response_tokens=32) +``` + +Restricting the number of requests can be useful in preventing infinite loops or excessive tool calling: + +```py +from typing_extensions import TypedDict + +from pydantic_ai import Agent, ModelRetry +from pydantic_ai.exceptions import UsageLimitExceeded +from pydantic_ai.settings import UsageLimits + + +class NeverResultType(TypedDict): + """ + Never ever coerce data to this type. + """ + + never_use_this: str + + +agent = Agent( + 'claude-3-5-sonnet-latest', + result_type=NeverResultType, + system_prompt='Any time you get a response, call the `infinite_retry_tool` to produce another response.', +) + + +@agent.tool_plain(retries=5) # (1)! +def infinite_retry_tool() -> int: + raise ModelRetry('Please try again.') + + +try: + result_sync = agent.run_sync( + 'Begin infinite retry loop!', usage_limits=UsageLimits(request_limit=3) # (2)! + ) +except UsageLimitExceeded as e: + print(e) + #> The next request would exceed the request_limit of 3 +``` + +1. This tool has the ability to retry 5 times before erroring, simulating a tool that might get stuck in a loop. +2. This run will error after 3 requests, preventing the infinite tool calling. + +!!! note + This is especially relevant if you've registered many tools. `request_limit` can be used to prevent the model from choosing to make too many of these calls. + +#### Model (Run) Settings + PydanticAI offers a [`settings.ModelSettings`][pydantic_ai.settings.ModelSettings] structure to help you fine tune your requests. This structure allows you to configure common parameters that influence the model's behavior, such as `temperature`, `max_tokens`, `timeout`, and more. 
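For readers of the docs added above, a minimal sketch (not part of this patch) of passing `usage_limits` and `model_settings` together on a single run; the model name and the particular settings values are illustrative only:

```py
from pydantic_ai import Agent
from pydantic_ai.settings import ModelSettings, UsageLimits

agent = Agent('claude-3-5-sonnet-latest')

# Cap both the number of model requests and the total token spend for this run,
# while also tuning the request itself via ModelSettings.
result = agent.run_sync(
    'What is the capital of Italy? Answer with just the city.',
    model_settings=ModelSettings(temperature=0.0, max_tokens=64),
    usage_limits=UsageLimits(request_limit=1, total_tokens_limit=200),
)
print(result.data)
```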
diff --git a/docs/api/models/ollama.md b/docs/api/models/ollama.md index 8443fa04e8..557711165b 100644 --- a/docs/api/models/ollama.md +++ b/docs/api/models/ollama.md @@ -32,7 +32,9 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.data) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None) +""" +Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None) +""" ``` ## Example using a remote server @@ -60,7 +62,9 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.data) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None) +""" +Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None) +""" ``` 1. The name of the model running on the remote server diff --git a/docs/api/settings.md b/docs/api/settings.md index c7847c9f78..336021a1a5 100644 --- a/docs/api/settings.md +++ b/docs/api/settings.md @@ -5,3 +5,4 @@ inherited_members: true members: - ModelSettings + - UsageLimits diff --git a/docs/results.md b/docs/results.md index f3f9700ff5..292ed11f56 100644 --- a/docs/results.md +++ b/docs/results.md @@ -19,7 +19,9 @@ result = agent.run_sync('Where were the olympics held in 2012?') print(result.data) #> city='London' country='United Kingdom' print(result.usage()) -#> Usage(request_tokens=57, response_tokens=8, total_tokens=65, details=None) +""" +Usage(requests=1, request_tokens=57, response_tokens=8, total_tokens=65, details=None) +""" ``` _(This example is complete, it can be run "as is")_ diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index 119699afc6..95564c06f8 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -1,8 +1,18 @@ from importlib.metadata import version from .agent import Agent -from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError +from .exceptions import AgentRunError, ModelRetry, UnexpectedModelBehavior, UsageLimitExceeded, UserError from .tools import RunContext, Tool -__all__ = 'Agent', 'Tool', 'RunContext', 'ModelRetry', 'UnexpectedModelBehavior', 'UserError', '__version__' +__all__ = ( + 'Agent', + 'RunContext', + 'Tool', + 'AgentRunError', + 'ModelRetry', + 'UnexpectedModelBehavior', + 'UsageLimitExceeded', + 'UserError', + '__version__', +) __version__ = version('pydantic_ai_slim') diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 72c76eb755..ec055e1595 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -22,7 +22,7 @@ result, ) from .result import ResultData -from .settings import ModelSettings, merge_model_settings +from .settings import ModelSettings, UsageLimits, merge_model_settings from .tools import ( AgentDeps, RunContext, @@ -191,6 +191,7 @@ async def run( model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, infer_name: bool = True, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt in async mode. @@ -211,8 +212,9 @@ async def run( message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. 
- infer_name: Whether to try to infer the agent name from the call frame if it's not set. model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. @@ -237,12 +239,14 @@ async def run( for tool in self._function_tools.values(): tool.current_retry = 0 - usage = result.Usage() - + usage = result.Usage(requests=0) model_settings = merge_model_settings(self.model_settings, model_settings) + usage_limits = usage_limits or UsageLimits() run_step = 0 while True: + usage_limits.check_before_request(usage) + run_step += 1 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step): agent_model = await self._prepare_model(model_used, deps, messages) @@ -254,6 +258,8 @@ async def run( messages.append(model_response) usage += request_usage + usage.requests += 1 + usage_limits.check_tokens(request_usage) with _logfire.span('handle model response', run_step=run_step) as handle_span: final_result, tool_responses = await self._handle_model_response(model_response, deps, messages) @@ -284,6 +290,7 @@ def run_sync( model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, infer_name: bool = True, ) -> result.RunResult[ResultData]: """Run the agent with a user prompt synchronously. @@ -308,8 +315,9 @@ async def main(): message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. @@ -322,8 +330,9 @@ async def main(): message_history=message_history, model=model, deps=deps, - infer_name=False, model_settings=model_settings, + usage_limits=usage_limits, + infer_name=False, ) ) @@ -336,6 +345,7 @@ async def run_stream( model: models.Model | models.KnownModelName | None = None, deps: AgentDeps = None, model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, infer_name: bool = True, ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -357,8 +367,9 @@ async def main(): message_history: History of the conversation so far. model: Optional model to use for this run, required if `model` was not set when creating the agent. deps: Optional dependencies to use for this run. - infer_name: Whether to try to infer the agent name from the call frame if it's not set. model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. Returns: The result of the run. 
@@ -387,16 +398,19 @@ async def main(): usage = result.Usage() model_settings = merge_model_settings(self.model_settings, model_settings) + usage_limits = usage_limits or UsageLimits() run_step = 0 while True: run_step += 1 + usage_limits.check_before_request(usage) with _logfire.span('preparing model and tools {run_step=}', run_step=run_step): agent_model = await self._prepare_model(model_used, deps, messages) with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span: async with agent_model.request_stream(messages, model_settings) as model_response: + usage.requests += 1 model_req_span.set_attribute('response_type', model_response.__class__.__name__) # We want to end the "model request" span here, but we can't exit the context manager # in the traditional way @@ -435,6 +449,7 @@ async def on_complete(): messages, new_message_index, usage, + usage_limits, result_stream, self._result_schema, deps, @@ -456,7 +471,9 @@ async def on_complete(): tool_responses_str = ' '.join(r.part_kind for r in tool_responses) handle_span.message = f'handle model response -> {tool_responses_str}' # the model_response should have been fully streamed by now, we can add its usage - usage += model_response.usage() + model_response_usage = model_response.usage() + usage += model_response_usage + usage_limits.check_tokens(usage) @contextmanager def override( diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 493f00d37f..b2daea661e 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -2,7 +2,7 @@ import json -__all__ = 'ModelRetry', 'UserError', 'UnexpectedModelBehavior' +__all__ = 'ModelRetry', 'UserError', 'AgentRunError', 'UnexpectedModelBehavior', 'UsageLimitExceeded' class ModelRetry(Exception): @@ -30,7 +30,25 @@ def __init__(self, message: str): super().__init__(message) -class UnexpectedModelBehavior(RuntimeError): +class AgentRunError(RuntimeError): + """Base class for errors occurring during an agent run.""" + + message: str + """The error message.""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + def __str__(self) -> str: + return self.message + + +class UsageLimitExceeded(AgentRunError): + """Error raised when a Model's usage exceeds the specified limits.""" + + +class UnexpectedModelBehavior(AgentRunError): """Error caused by unexpected Model behavior, e.g. 
an unexpected response code.""" message: str diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 325de9ffbc..8b367ab7a1 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -282,11 +282,11 @@ def _map_message(messages: list[ModelMessage]) -> tuple[str, list[MessageParam]] MessageParam( role='user', content=[ - ToolUseBlockParam( - id=_guard_tool_call_id(t=part, model_source='Anthropic'), - input=part.model_response(), - name=part.tool_name, - type='tool_use', + ToolResultBlockParam( + tool_use_id=_guard_tool_call_id(t=part, model_source='Anthropic'), + type='tool_result', + content=part.model_response(), + is_error=True, ), ], ) diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 077b29c69a..4d1c602bc5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -237,7 +237,7 @@ def get(self, *, final: bool = False) -> ModelResponse: return ModelResponse(calls, timestamp=self._timestamp) def usage(self) -> result.Usage: - return result.Usage() + return _estimate_usage([self.get()]) def timestamp(self) -> datetime: return self._timestamp @@ -255,24 +255,24 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage: if isinstance(message, ModelRequest): for part in message.parts: if isinstance(part, (SystemPromptPart, UserPromptPart)): - request_tokens += _string_usage(part.content) + request_tokens += _estimate_string_usage(part.content) elif isinstance(part, ToolReturnPart): - request_tokens += _string_usage(part.model_response_str()) + request_tokens += _estimate_string_usage(part.model_response_str()) elif isinstance(part, RetryPromptPart): - request_tokens += _string_usage(part.model_response()) + request_tokens += _estimate_string_usage(part.model_response()) else: assert_never(part) elif isinstance(message, ModelResponse): for part in message.parts: if isinstance(part, TextPart): - response_tokens += _string_usage(part.content) + response_tokens += _estimate_string_usage(part.content) elif isinstance(part, ToolCallPart): call = part if isinstance(call.args, ArgsJson): args_str = call.args.args_json else: args_str = pydantic_core.to_json(call.args.args_dict).decode() - response_tokens += 1 + _string_usage(args_str) + response_tokens += 1 + _estimate_string_usage(args_str) else: assert_never(part) else: @@ -282,5 +282,5 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> result.Usage: ) -def _string_usage(content: str) -> int: +def _estimate_string_usage(content: str) -> int: return len(re.split(r'[\s",.:]+', content)) diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index f953d8fabe..0f81d7eb7c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -31,6 +31,7 @@ StreamStructuredResponse, StreamTextResponse, ) +from .function import _estimate_string_usage, _estimate_usage # pyright: ignore[reportPrivateUsage] @dataclass @@ -132,14 +133,16 @@ class TestAgentModel(AgentModel): async def request( self, messages: list[ModelMessage], model_settings: ModelSettings | None ) -> tuple[ModelResponse, Usage]: - return self._request(messages, model_settings), Usage() + model_response = self._request(messages, model_settings) + usage = _estimate_usage([*messages, model_response]) + return model_response, usage 
@asynccontextmanager async def request_stream( self, messages: list[ModelMessage], model_settings: ModelSettings | None ) -> AsyncIterator[EitherStreamedResponse]: msg = self._request(messages, model_settings) - usage = Usage() + usage = _estimate_usage(messages) # TODO: Rework this once we make StreamTextResponse more general texts: list[str] = [] @@ -228,7 +231,10 @@ def __post_init__(self): self._iter = iter(words) async def __anext__(self) -> None: - self._buffer.append(_utils.sync_anext(self._iter)) + next_str = _utils.sync_anext(self._iter) + response_tokens = _estimate_string_usage(next_str) + self._usage += Usage(response_tokens=response_tokens, total_tokens=response_tokens) + self._buffer.append(next_str) def get(self, *, final: bool = False) -> Iterable[str]: yield from self._buffer diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 26921e410a..5a19cf9c0f 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -9,6 +9,7 @@ import logfire_api from . import _result, _utils, exceptions, messages as _messages, models +from .settings import UsageLimits from .tools import AgentDeps __all__ = ( @@ -34,6 +35,8 @@ class Usage: You'll need to look up the documentation of the model you're using to convert usage to monetary costs. """ + requests: int = 0 + """Number of requests made.""" request_tokens: int | None = None """Tokens used in processing requests.""" response_tokens: int | None = None @@ -49,7 +52,7 @@ def __add__(self, other: Usage) -> Usage: This is provided so it's trivial to sum usage information from multiple requests and runs. """ counts: dict[str, int] = {} - for f in 'request_tokens', 'response_tokens', 'total_tokens': + for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens': self_value = getattr(self, f) other_value = getattr(other, f) if self_value is not None or other_value is not None: @@ -118,6 +121,7 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat usage_so_far: Usage """Usage of the run up until the last request.""" + _usage_limits: UsageLimits | None _stream_response: models.EitherStreamedResponse _result_schema: _result.ResultSchema[ResultData] | None _deps: AgentDeps @@ -173,11 +177,15 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = Debouncing is particularly important for long structured responses to reduce the overhead of performing validation as each token is received. 
""" + usage_checking_stream = _get_usage_checking_stream_response( + self._stream_response, self._usage_limits, self.usage + ) + with _logfire.span('response stream text') as lf_span: if isinstance(self._stream_response, models.StreamStructuredResponse): raise exceptions.UserError('stream_text() can only be used with text responses') if delta: - async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: + async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: async for _ in group_iter: yield ''.join(self._stream_response.get()) final_delta = ''.join(self._stream_response.get(final=True)) @@ -188,7 +196,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None = # yielding at each step chunks: list[str] = [] combined = '' - async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: + async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: async for _ in group_iter: new = False for chunk in self._stream_response.get(): @@ -225,6 +233,10 @@ async def stream_structured( Returns: An async iterable of the structured response message and whether that is the last message. """ + usage_checking_stream = _get_usage_checking_stream_response( + self._stream_response, self._usage_limits, self.usage + ) + with _logfire.span('response stream structured') as lf_span: if isinstance(self._stream_response, models.StreamTextResponse): raise exceptions.UserError('stream_structured() can only be used with structured responses') @@ -235,7 +247,7 @@ async def stream_structured( if isinstance(item, _messages.ToolCallPart) and item.has_content(): yield msg, False break - async with _utils.group_by_temporal(self._stream_response, debounce_by) as group_iter: + async with _utils.group_by_temporal(usage_checking_stream, debounce_by) as group_iter: async for _ in group_iter: msg = self._stream_response.get() for item in msg.parts: @@ -249,8 +261,13 @@ async def stream_structured( async def get_data(self) -> ResultData: """Stream the whole response, validate and return it.""" - async for _ in self._stream_response: + usage_checking_stream = _get_usage_checking_stream_response( + self._stream_response, self._usage_limits, self.usage + ) + + async for _ in usage_checking_stream: pass + if isinstance(self._stream_response, models.StreamTextResponse): text = ''.join(self._stream_response.get(final=True)) text = await self._validate_text_result(text) @@ -312,3 +329,18 @@ async def _marked_completed(self, message: _messages.ModelResponse) -> None: self.is_complete = True self._all_messages.append(message) await self._on_complete() + + +def _get_usage_checking_stream_response( + stream_response: AsyncIterator[ResultData], limits: UsageLimits | None, get_usage: Callable[[], Usage] +) -> AsyncIterator[ResultData]: + if limits is not None and limits.has_token_limits(): + + async def _usage_checking_iterator(): + async for item in stream_response: + limits.check_tokens(get_usage()) + yield item + + return _usage_checking_iterator() + else: + return stream_response diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py index d3e2d429c2..e4c3202f9a 100644 --- a/pydantic_ai_slim/pydantic_ai/settings.py +++ b/pydantic_ai_slim/pydantic_ai/settings.py @@ -1,8 +1,16 @@ from __future__ import annotations +from dataclasses import dataclass +from typing import TYPE_CHECKING + from httpx import Timeout from typing_extensions import TypedDict +from .exceptions 
import UsageLimitExceeded + +if TYPE_CHECKING: + from .result import Usage + class ModelSettings(TypedDict, total=False): """Settings to configure an LLM. @@ -70,3 +78,60 @@ def merge_model_settings(base: ModelSettings | None, overrides: ModelSettings | return base | overrides else: return base or overrides + + +@dataclass +class UsageLimits: + """Limits on model usage. + + The request count is tracked by pydantic_ai, and the request limit is checked before each request to the model. + Token counts are provided in responses from the model, and the token limits are checked after each response. + + Each of the limits can be set to `None` to disable that limit. + """ + + request_limit: int | None = 50 + """The maximum number of requests allowed to the model.""" + request_tokens_limit: int | None = None + """The maximum number of tokens allowed in requests to the model.""" + response_tokens_limit: int | None = None + """The maximum number of tokens allowed in responses from the model.""" + total_tokens_limit: int | None = None + """The maximum number of tokens allowed in requests and responses combined.""" + + def has_token_limits(self) -> bool: + """Returns `True` if this instance places any limits on token counts. + + If this returns `False`, the `check_tokens` method will never raise an error. + + This is useful because if we have token limits, we need to check them after receiving each streamed message. + If there are no limits, we can skip that processing in the streaming response iterator. + """ + return any( + limit is not None + for limit in (self.request_tokens_limit, self.response_tokens_limit, self.total_tokens_limit) + ) + + def check_before_request(self, usage: Usage) -> None: + """Raises a `UsageLimitExceeded` exception if the next request would exceed the request_limit.""" + request_limit = self.request_limit + if request_limit is not None and usage.requests >= request_limit: + raise UsageLimitExceeded(f'The next request would exceed the request_limit of {request_limit}') + + def check_tokens(self, usage: Usage) -> None: + """Raises a `UsageLimitExceeded` exception if the usage exceeds any of the token limits.""" + request_tokens = usage.request_tokens or 0 + if self.request_tokens_limit is not None and request_tokens > self.request_tokens_limit: + raise UsageLimitExceeded( + f'Exceeded the request_tokens_limit of {self.request_tokens_limit} ({request_tokens=})' + ) + + response_tokens = usage.response_tokens or 0 + if self.response_tokens_limit is not None and response_tokens > self.response_tokens_limit: + raise UsageLimitExceeded( + f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})' + ) + + total_tokens = request_tokens + response_tokens + if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit: + raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})') diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index d73da5a73b..1fad2628e7 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -91,14 +91,14 @@ async def test_sync_request_text_response(allow_model_requests: None): result = await agent.run('hello') assert result.data == 'world' - assert result.usage() == snapshot(Usage(request_tokens=5, response_tokens=10, total_tokens=15)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15)) # reset the index so we get the same response again mock_client.index 
= 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.data == 'world' - assert result.usage() == snapshot(Usage(request_tokens=5, response_tokens=10, total_tokens=15)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=5, response_tokens=10, total_tokens=15)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -120,7 +120,7 @@ async def test_async_request_text_response(allow_model_requests: None): result = await agent.run('hello') assert result.data == 'world' - assert result.usage() == snapshot(Usage(request_tokens=3, response_tokens=5, total_tokens=8)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=3, response_tokens=5, total_tokens=8)) async def test_request_structured_response(allow_model_requests: None): diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index ef6f4b262c..0cc14644b3 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -396,7 +396,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient): ModelResponse.from_text(content='Hello world', timestamp=IsNow(tz=timezone.utc)), ] ) - assert result.usage() == snapshot(Usage(request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) result = await agent.run('Hello', message_history=result.new_messages()) assert result.data == 'Hello world' @@ -517,7 +517,7 @@ async def get_location(loc_name: str) -> str: ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)), ] ) - assert result.usage() == snapshot(Usage(request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9)) async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None): @@ -572,12 +572,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot(['Hello ', 'Hello world']) - assert result.usage() == snapshot(Usage(request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)] assert chunks == snapshot(['', 'Hello ', 'world']) - assert result.usage() == snapshot(Usage(request_tokens=2, response_tokens=4, total_tokens=6)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6)) async def test_stream_text_no_data(get_gemini_client: GetGeminiClient): @@ -609,7 +609,7 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient): async with agent.run_stream('Hello') as result: chunks = [chunk async for chunk in result.stream(debounce_by=None)] assert chunks == snapshot([(1, 2), (1, 2), (1, 2)]) - assert result.usage() == snapshot(Usage(request_tokens=1, response_tokens=2, total_tokens=3)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3)) async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient): @@ -652,7 +652,7 @@ async def bar(y: str) -> str: async with 
agent.run_stream('Hello') as result: response = await result.get_data() assert response == snapshot((1, 2)) - assert result.usage() == snapshot(Usage(request_tokens=3, response_tokens=6, total_tokens=9)) + assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 07983e2e6d..93a7cf6376 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -128,14 +128,14 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.data == 'world' - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.data == 'world' - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -411,7 +411,7 @@ async def test_stream_structured(allow_model_requests: None): ) assert result.is_complete - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]), diff --git a/tests/models/test_model_function.py b/tests/models/test_model_function.py index 540d42596a..8ebfba7a9a 100644 --- a/tests/models/test_model_function.py +++ b/tests/models/test_model_function.py @@ -399,7 +399,7 @@ async def test_stream_text(): ModelResponse.from_text(content='hello world', timestamp=IsNow(tz=timezone.utc)), ] ) - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) class Foo(BaseModel): @@ -420,7 +420,14 @@ async def stream_structured_function( agent = Agent(FunctionModel(stream_function=stream_structured_function), result_type=Foo) async with agent.run_stream('') as result: assert await result.get_data() == snapshot(Foo(x=1)) - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot( + Usage( + requests=1, + request_tokens=50, + response_tokens=4, + total_tokens=54, + ) + ) async def test_pass_neither(): diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py index bfebab6022..c608c656fd 100644 --- a/tests/models/test_ollama.py +++ b/tests/models/test_ollama.py @@ -44,14 +44,14 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.data == 'world' - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.data == 'world' - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 094d339074..722f51f78d 100644 --- a/tests/models/test_openai.py +++ 
b/tests/models/test_openai.py @@ -137,14 +137,14 @@ async def test_request_simple_success(allow_model_requests: None): result = await agent.run('hello') assert result.data == 'world' - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) # reset the index so we get the same response again mock_client.index = 0 # type: ignore result = await agent.run('hello', message_history=result.new_messages()) assert result.data == 'world' - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot(Usage(requests=1)) assert result.all_messages() == snapshot( [ ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]), @@ -166,7 +166,7 @@ async def test_request_simple_usage(allow_model_requests: None): result = await agent.run('Hello') assert result.data == 'world' - assert result.usage() == snapshot(Usage(request_tokens=2, response_tokens=1, total_tokens=3)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3)) async def test_request_structured_response(allow_model_requests: None): @@ -323,7 +323,13 @@ async def get_location(loc_name: str) -> str: ] ) assert result.usage() == snapshot( - Usage(request_tokens=5, response_tokens=3, total_tokens=9, details={'cached_tokens': 3}) + Usage( + requests=3, + request_tokens=5, + response_tokens=3, + total_tokens=9, + details={'cached_tokens': 3}, + ) ) @@ -358,7 +364,7 @@ async def test_stream_text(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) async def test_stream_text_finish_reason(allow_model_requests: None): @@ -425,7 +431,7 @@ async def test_stream_structured(allow_model_requests: None): ] ) assert result.is_complete - assert result.usage() == snapshot(Usage(request_tokens=20, response_tokens=10, total_tokens=30)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30)) # double check usage matches stream count assert result.usage().response_tokens == len(stream) @@ -482,4 +488,4 @@ async def test_no_delta(allow_model_requests: None): assert not result.is_complete assert [c async for c in result.stream(debounce_by=None)] == snapshot(['hello ', 'hello world']) assert result.is_complete - assert result.usage() == snapshot(Usage(request_tokens=6, response_tokens=3, total_tokens=9)) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) diff --git a/tests/test_agent.py b/tests/test_agent.py index 7b724e5211..b77072cce0 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -505,7 +505,7 @@ async def ret_a(x: str) -> str: ], _new_message_index=4, data='{"ret_a":"a-apple"}', - _usage=Usage(), + _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), ) ) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] @@ -547,7 +547,7 @@ async def ret_a(x: str) -> str: ], _new_message_index=4, data='{"ret_a":"a-apple"}', - _usage=Usage(), + _usage=Usage(requests=1, request_tokens=55, response_tokens=13, total_tokens=68, details=None), ) ) @@ -646,7 +646,7 @@ async def ret_a(x: str) -> str: ), ], 
_new_message_index=5, - _usage=Usage(), + _usage=Usage(requests=1, request_tokens=59, response_tokens=13, total_tokens=72, details=None), ) ) new_msg_part_kinds = [(m.kind, [p.part_kind for p in m.parts]) for m in result2.all_messages()] diff --git a/tests/test_examples.py b/tests/test_examples.py index 64fca2f7d4..9fcb1fab0f 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -213,6 +213,12 @@ async def async_http_request(url: str, **kwargs: Any) -> httpx.Response: } ), ), + 'What is the capital of Italy? Answer with just the city.': 'Rome', + 'What is the capital of Italy? Answer with a paragraph.': ( + 'The capital of Italy is Rome (Roma, in Italian), which has been a cultural and political center for centuries.' + 'Rome is known for its rich history, stunning architecture, and delicious cuisine.' + ), + 'Begin infinite retry loop!': ToolCallPart(tool_name='infinite_retry_tool', args=ArgsDict({})), } @@ -241,6 +247,8 @@ async def model_logic(messages: list[ModelMessage], info: AgentInfo) -> ModelRes and m.content.startswith("No user found with name 'Joh") ): return ModelResponse(parts=[ToolCallPart(tool_name='get_user_by_name', args=ArgsDict({'name': 'John Doe'}))]) + elif isinstance(m, RetryPromptPart) and m.tool_name == 'infinite_retry_tool': + return ModelResponse(parts=[ToolCallPart(tool_name='infinite_retry_tool', args=ArgsDict({}))]) elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_user_by_name': args = { 'message': 'Hello John, would you be free for coffee sometime next week? Let me know what works for you!', diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 89505d0bd9..85ec8f1e4f 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -161,7 +161,9 @@ async def my_ret(x: int) -> str: }, ] ), - 'usage': IsJson({'request_tokens': None, 'response_tokens': None, 'total_tokens': None, 'details': None}), + 'usage': IsJson( + {'requests': 2, 'request_tokens': 103, 'response_tokens': 12, 'total_tokens': 115, 'details': None} + ), 'logfire.json_schema': IsJson( { 'type': 'object', diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 7e593f9bde..2323f0c138 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -55,10 +55,25 @@ async def ret_a(x: str) -> str: ), ] ) + assert result.usage() == snapshot( + Usage( + requests=2, + request_tokens=103, + response_tokens=5, + total_tokens=108, + ) + ) response = await result.get_data() assert response == snapshot('{"ret_a":"a-apple"}') assert result.is_complete - assert result.usage() == snapshot(Usage()) + assert result.usage() == snapshot( + Usage( + requests=2, + request_tokens=103, + response_tokens=11, + total_tokens=114, + ) + ) assert result.timestamp() == IsNow(tz=timezone.utc) assert result.all_messages() == snapshot( [ diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py new file mode 100644 index 0000000000..926f29a836 --- /dev/null +++ b/tests/test_usage_limits.py @@ -0,0 +1,99 @@ +import re +from datetime import timezone + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai import Agent, UsageLimitExceeded +from pydantic_ai.messages import ArgsDict, ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart, UserPromptPart +from pydantic_ai.models.test import TestModel +from pydantic_ai.result import Usage +from pydantic_ai.settings import UsageLimits + +from .conftest import IsNow + +pytestmark = pytest.mark.anyio + + +def test_request_token_limit(set_event_loop: None) -> None: + test_agent = Agent(TestModel()) + + 
with pytest.raises( + UsageLimitExceeded, match=re.escape('Exceeded the request_tokens_limit of 5 (request_tokens=59)') + ): + test_agent.run_sync( + 'Hello, this prompt exceeds the request tokens limit.', usage_limits=UsageLimits(request_tokens_limit=5) + ) + + +def test_response_token_limit(set_event_loop: None) -> None: + test_agent = Agent( + TestModel(custom_result_text='Unfortunately, this response exceeds the response tokens limit by a few!') + ) + + with pytest.raises( + UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 5 (response_tokens=11)') + ): + test_agent.run_sync('Hello', usage_limits=UsageLimits(response_tokens_limit=5)) + + +def test_total_token_limit(set_event_loop: None) -> None: + test_agent = Agent(TestModel(custom_result_text='This utilizes 4 tokens!')) + + with pytest.raises(UsageLimitExceeded, match=re.escape('Exceeded the total_tokens_limit of 50 (total_tokens=55)')): + test_agent.run_sync('Hello', usage_limits=UsageLimits(total_tokens_limit=50)) + + +def test_retry_limit(set_event_loop: None) -> None: + test_agent = Agent(TestModel()) + + @test_agent.tool_plain + async def foo(x: str) -> str: + return x + + @test_agent.tool_plain + async def bar(y: str) -> str: + return y + + with pytest.raises(UsageLimitExceeded, match=re.escape('The next request would exceed the request_limit of 1')): + test_agent.run_sync('Hello', usage_limits=UsageLimits(request_limit=1)) + + +async def test_streamed_text_limits() -> None: + m = TestModel() + + test_agent = Agent(m) + assert test_agent.name is None + + @test_agent.tool_plain + async def ret_a(x: str) -> str: + return f'{x}-apple' + + async with test_agent.run_stream('Hello', usage_limits=UsageLimits(response_tokens_limit=10)) as result: + assert test_agent.name == 'test_agent' + assert not result.is_structured + assert not result.is_complete + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[ToolCallPart(tool_name='ret_a', args=ArgsDict(args_dict={'x': 'a'}))], + timestamp=IsNow(tz=timezone.utc), + ), + ModelRequest( + parts=[ToolReturnPart(tool_name='ret_a', content='a-apple', timestamp=IsNow(tz=timezone.utc))] + ), + ] + ) + assert result.usage() == snapshot( + Usage( + requests=2, + request_tokens=103, + response_tokens=5, + total_tokens=108, + ) + ) + with pytest.raises( + UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)') + ): + await result.get_data()
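As a closing illustration of the check semantics these tests exercise, here is a standalone sketch against the new `UsageLimits` API (an assumed usage example, not code from the patch): the request limit is enforced before a request is made, while token limits are checked against accumulated usage.

```py
from pydantic_ai import UsageLimitExceeded
from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits

limits = UsageLimits(request_limit=3, total_tokens_limit=100)

# check_before_request raises once the request count has already reached the limit.
try:
    limits.check_before_request(Usage(requests=3))
except UsageLimitExceeded as e:
    print(e)
    #> The next request would exceed the request_limit of 3

# check_tokens compares request + response tokens against the configured token limits.
try:
    limits.check_tokens(Usage(request_tokens=80, response_tokens=40))
except UsageLimitExceeded as e:
    print(e)
    #> Exceeded the total_tokens_limit of 100 (total_tokens=120)
```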