diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 4b8d572046..3a0959ab9a 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -184,6 +184,7 @@ async def run(
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt in async mode.
@@ -206,6 +207,7 @@ async def run(
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
@@ -226,20 +228,19 @@ async def run(
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages
 
             for tool in self._function_tools.values():
                 tool.current_retry = 0
 
-            usage = result.Usage(requests=0)
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()
 
             run_step = 0
             while True:
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)
 
                 run_step += 1
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
@@ -251,9 +252,8 @@ async def run(
                     model_req_span.set_attribute('usage', request_usage)
 
                 messages.append(model_response)
-                usage += request_usage
-                usage.requests += 1
-                usage_limits.check_tokens(request_usage)
+                run_context.usage.incr(request_usage, requests=1)
+                usage_limits.check_tokens(run_context.usage)
 
                 with _logfire.span('handle model response', run_step=run_step) as handle_span:
                     final_result, tool_responses = await self._handle_model_response(model_response, run_context)
@@ -266,10 +266,10 @@ async def run(
                     if final_result is not None:
                         result_data = final_result.data
                         run_span.set_attribute('all_messages', messages)
-                        run_span.set_attribute('usage', usage)
+                        run_span.set_attribute('usage', run_context.usage)
                         handle_span.set_attribute('result', result_data)
                         handle_span.message = 'handle model response -> final result'
-                        return result.RunResult(messages, new_message_index, result_data, usage)
+                        return result.RunResult(messages, new_message_index, result_data, run_context.usage)
                     else:
                         # continue the conversation
                         handle_span.set_attribute('tool_responses', tool_responses)
@@ -285,6 +285,7 @@ def run_sync(
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> result.RunResult[ResultData]:
         """Run the agent with a user prompt synchronously.
@@ -311,6 +312,7 @@ async def main():
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
@@ -326,6 +328,7 @@ async def main():
                 deps=deps,
                 model_settings=model_settings,
                 usage_limits=usage_limits,
+                usage=usage,
                 infer_name=False,
             )
         )
@@ -340,6 +343,7 @@ async def run_stream(
         deps: AgentDeps = None,
         model_settings: ModelSettings | None = None,
         usage_limits: UsageLimits | None = None,
+        usage: result.Usage | None = None,
         infer_name: bool = True,
     ) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
         """Run the agent with a user prompt in async mode, returning a streamed response.
@@ -363,6 +367,7 @@ async def main():
             deps: Optional dependencies to use for this run.
             model_settings: Optional settings to use for this model's request.
             usage_limits: Optional limits on model request count or token usage.
+            usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
             infer_name: Whether to try to infer the agent name from the call frame if it's not set.
 
         Returns:
@@ -385,28 +390,27 @@ async def main():
             model_name=model_used.name(),
             agent_name=self.name or 'agent',
         ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+            run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             run_context.messages = messages
 
             for tool in self._function_tools.values():
                 tool.current_retry = 0
 
-            usage = result.Usage()
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()
 
             run_step = 0
             while True:
                 run_step += 1
-                usage_limits.check_before_request(usage)
+                usage_limits.check_before_request(run_context.usage)
 
                 with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
                     agent_model = await self._prepare_model(run_context)
 
                 with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
-                        usage.requests += 1
+                        run_context.usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
                         # We want to end the "model request" span here, but we can't exit the context manager
                         # in the traditional way
@@ -442,7 +446,6 @@ async def on_complete():
                         yield result.StreamedRunResult(
                             messages,
                             new_message_index,
-                            usage,
                             usage_limits,
                             result_stream,
                             self._result_schema,
@@ -466,8 +469,8 @@ async def on_complete():
                         handle_span.message = f'handle model response -> {tool_responses_str}'
                         # the model_response should have been fully streamed by now, we can add its usage
                         model_response_usage = model_response.usage()
-                        usage += model_response_usage
-                        usage_limits.check_tokens(usage)
+                        run_context.usage.incr(model_response_usage)
+                        usage_limits.check_tokens(run_context.usage)
 
     @contextmanager
     def override(
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index 9cba5f0a4e..14606ea7dc 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -2,6 +2,7 @@
 
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator, Awaitable, Callable
+from copy import copy
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Generic, Union, cast
@@ -63,25 +64,33 @@ class Usage:
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""
 
-    def __add__(self, other: Usage) -> Usage:
-        """Add two Usages together.
+    def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
+        """Increment the usage in place.
 
-        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        Args:
+            incr_usage: The usage to increment by.
+            requests: The number of requests to increment by in addition to `incr_usage.requests`.
         """
-        counts: dict[str, int] = {}
+        self.requests += requests
         for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
             self_value = getattr(self, f)
-            other_value = getattr(other, f)
+            other_value = getattr(incr_usage, f)
             if self_value is not None or other_value is not None:
-                counts[f] = (self_value or 0) + (other_value or 0)
+                setattr(self, f, (self_value or 0) + (other_value or 0))
 
-        details = self.details.copy() if self.details is not None else None
-        if other.details is not None:
-            details = details or {}
-            for key, value in other.details.items():
-                details[key] = details.get(key, 0) + value
+        if incr_usage.details:
+            self.details = self.details or {}
+            for key, value in incr_usage.details.items():
+                self.details[key] = self.details.get(key, 0) + value
 
-        return Usage(**counts, details=details or None)
+    def __add__(self, other: Usage) -> Usage:
+        """Add two Usages together.
+
+        This is provided so it's trivial to sum usage information from multiple requests and runs.
+        """
+        new_usage = copy(self)
+        new_usage.incr(other)
+        return new_usage
 
 
 @dataclass
@@ -136,8 +145,6 @@ def usage(self) -> Usage:
 class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
     """Result of a streamed run that returns structured data via a tool call."""
 
-    usage_so_far: Usage
-    """Usage of the run up until the last request."""
     _usage_limits: UsageLimits | None
     _stream_response: models.EitherStreamedResponse
     _result_schema: _result.ResultSchema[ResultData] | None
@@ -306,7 +313,7 @@ def usage(self) -> Usage:
         !!! note
             This won't return the full usage until the stream is finished.
         """
-        return self.usage_so_far + self._stream_response.usage()
+        return self._run_ctx.usage + self._stream_response.usage()
 
     def timestamp(self) -> datetime:
         """Get the timestamp of the response."""
diff --git a/pydantic_ai_slim/pydantic_ai/settings.py b/pydantic_ai_slim/pydantic_ai/settings.py
index ce0c7db4c8..9fcadf5a53 100644
--- a/pydantic_ai_slim/pydantic_ai/settings.py
+++ b/pydantic_ai_slim/pydantic_ai/settings.py
@@ -136,6 +136,6 @@ def check_tokens(self, usage: Usage) -> None:
                 f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
             )
 
-        total_tokens = request_tokens + response_tokens
+        total_tokens = usage.total_tokens or 0
         if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
             raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index 72a66f5b49..34015b4f2b 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -4,7 +4,7 @@
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast
 
 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
@@ -13,6 +13,9 @@
 from . import _pydantic, _utils, messages as _messages, models
 from .exceptions import ModelRetry, UnexpectedModelBehavior
 
+if TYPE_CHECKING:
+    from .result import Usage
+
 __all__ = (
     'AgentDeps',
     'RunContext',
@@ -45,6 +48,8 @@ class RunContext(Generic[AgentDeps]):
     """Name of the tool being called."""
     model: models.Model
     """The model used in this run."""
+    usage: Usage
+    """LLM usage associated with the run."""
 
     def replace_with(
         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py
index 926f29a836..37b4514299 100644
--- a/tests/test_usage_limits.py
+++ b/tests/test_usage_limits.py
@@ -1,11 +1,20 @@
+import functools
+import operator
 import re
 from datetime import timezone
 
 import pytest
 from inline_snapshot import snapshot
 
-from pydantic_ai import Agent, UsageLimitExceeded
-from pydantic_ai.messages import ArgsDict, ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart, UserPromptPart
+from pydantic_ai import Agent, RunContext, UsageLimitExceeded
+from pydantic_ai.messages import (
+    ArgsDict,
+    ModelRequest,
+    ModelResponse,
+    ToolCallPart,
+    ToolReturnPart,
+    UserPromptPart,
+)
 from pydantic_ai.models.test import TestModel
 from pydantic_ai.result import Usage
 from pydantic_ai.settings import UsageLimits
@@ -97,3 +106,67 @@ async def ret_a(x: str) -> str:
         UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)')
     ):
         await result.get_data()
+
+
+def test_usage_so_far(set_event_loop: None) -> None:
+    test_agent = Agent(TestModel())
+
+    with pytest.raises(
+        UsageLimitExceeded, match=re.escape('Exceeded the total_tokens_limit of 105 (total_tokens=163)')
+    ):
+        test_agent.run_sync(
+            'Hello, this prompt exceeds the request tokens limit.',
+            usage_limits=UsageLimits(total_tokens_limit=105),
+            usage=Usage(total_tokens=100),
+        )
+
+
+async def test_multi_agent_usage_no_incr():
+    delegate_agent = Agent(TestModel(), result_type=int)
+
+    controller_agent1 = Agent(TestModel())
+    run_1_usages: list[Usage] = []
+
+    @controller_agent1.tool
+    async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int:
+        delegate_result = await delegate_agent.run(sentence)
+        delegate_usage = delegate_result.usage()
+        run_1_usages.append(delegate_usage)
+        assert delegate_usage == snapshot(Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55))
+        return delegate_result.data
+
+    result1 = await controller_agent1.run('foobar')
+    assert result1.data == snapshot('{"delegate_to_other_agent1":0}')
+    run_1_usages.append(result1.usage())
+    assert result1.usage() == snapshot(Usage(requests=2, request_tokens=103, response_tokens=13, total_tokens=116))
+
+    controller_agent2 = Agent(TestModel())
+
+    @controller_agent2.tool
+    async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int:
+        delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
+        delegate_usage = delegate_result.usage()
+        assert delegate_usage == snapshot(Usage(requests=2, request_tokens=102, response_tokens=9, total_tokens=111))
+        return delegate_result.data
+
+    result2 = await controller_agent2.run('foobar')
+    assert result2.data == snapshot('{"delegate_to_other_agent2":0}')
+    assert result2.usage() == snapshot(Usage(requests=3, request_tokens=154, response_tokens=17, total_tokens=171))
+
+    # confirm the usage from result2 is the sum of the usage from result1
+    assert result2.usage() == functools.reduce(operator.add, run_1_usages)
+
+
+async def test_multi_agent_usage_sync():
+    """As in `test_multi_agent_usage_no_incr`, but with a sync tool."""
+    controller_agent = Agent(TestModel())
+
+    @controller_agent.tool
+    def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
+        new_usage = Usage(requests=5, request_tokens=2, response_tokens=3, total_tokens=4)
+        ctx.usage.incr(new_usage)
+        return 0
+
+    result = await controller_agent.run('foobar')
+    assert result.data == snapshot('{"delegate_to_other_agent":0}')
+    assert result.usage() == snapshot(Usage(requests=7, request_tokens=105, response_tokens=16, total_tokens=120))
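
Note for reviewers (not part of the patch): a minimal sketch of the pattern this change enables, distilled from `test_multi_agent_usage_no_incr` above. Names like `controller_agent` and `delegate` are illustrative. Passing `ctx.usage` into a delegated run makes the delegate increment the controller run's `Usage` in place, so the controller's final `result.usage()` and its `UsageLimits` checks cover both agents:

```python
import asyncio

from pydantic_ai import Agent, RunContext
from pydantic_ai.models.test import TestModel
from pydantic_ai.settings import UsageLimits

# Illustrative setup; any model works the same way.
controller_agent = Agent(TestModel())
delegate_agent = Agent(TestModel(), result_type=int)


@controller_agent.tool
async def delegate(ctx: RunContext[None], sentence: str) -> int:
    # Sharing ctx.usage means the delegate's requests and tokens are
    # counted against the controller run's Usage in place.
    delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
    return delegate_result.data


async def main() -> None:
    # The controller's limit checks now see the combined usage of both agents.
    result = await controller_agent.run('foobar', usage_limits=UsageLimits(total_tokens_limit=500))
    print(result.usage())  # includes the delegate agent's requests and tokens


if __name__ == '__main__':
    asyncio.run(main())
```

Similarly, a quick sketch of the reworked `Usage` arithmetic, with arbitrary example numbers: `incr` mutates in place (its keyword-only `requests` argument adds extra requests on top of `incr_usage.requests`), while `__add__` now copies and delegates to `incr`, so summing usages stays non-destructive:

```python
from pydantic_ai.result import Usage

a = Usage(requests=1, request_tokens=50, response_tokens=5, total_tokens=55)
b = Usage(requests=1, request_tokens=52, response_tokens=8, total_tokens=60, details={'cached': 2})

# __add__ copies `a` and increments the copy, so `a` is unchanged.
total = a + b
assert total == Usage(requests=2, request_tokens=102, response_tokens=13, total_tokens=115, details={'cached': 2})

# incr mutates in place; requests=1 adds one request on top of b.requests.
a.incr(b, requests=1)
assert a.requests == 3
```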