Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ async def run(
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: result.Usage | None = None,
infer_name: bool = True,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt in async mode.
Expand All @@ -206,6 +207,7 @@ async def run(
deps: Optional dependencies to use for this run.
model_settings: Optional settings to use for this model's request.
usage_limits: Optional limits on model request count or token usage.
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.

Returns:
Expand All @@ -226,20 +228,19 @@ async def run(
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used)
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
messages = await self._prepare_messages(user_prompt, message_history, run_context)
run_context.messages = messages

for tool in self._function_tools.values():
tool.current_retry = 0

usage = result.Usage(requests=0)
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
usage_limits.check_before_request(usage)
usage_limits.check_before_request(run_context.usage)

run_step += 1
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
Expand All @@ -251,9 +252,8 @@ async def run(
model_req_span.set_attribute('usage', request_usage)

messages.append(model_response)
usage += request_usage
usage.requests += 1
usage_limits.check_tokens(request_usage)
run_context.usage.incr(request_usage, requests=1)
usage_limits.check_tokens(run_context.usage)

with _logfire.span('handle model response', run_step=run_step) as handle_span:
final_result, tool_responses = await self._handle_model_response(model_response, run_context)
Expand All @@ -266,10 +266,10 @@ async def run(
if final_result is not None:
result_data = final_result.data
run_span.set_attribute('all_messages', messages)
run_span.set_attribute('usage', usage)
run_span.set_attribute('usage', run_context.usage)
handle_span.set_attribute('result', result_data)
handle_span.message = 'handle model response -> final result'
return result.RunResult(messages, new_message_index, result_data, usage)
return result.RunResult(messages, new_message_index, result_data, run_context.usage)
else:
# continue the conversation
handle_span.set_attribute('tool_responses', tool_responses)
Expand All @@ -285,6 +285,7 @@ def run_sync(
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: result.Usage | None = None,
infer_name: bool = True,
) -> result.RunResult[ResultData]:
"""Run the agent with a user prompt synchronously.
Expand All @@ -311,6 +312,7 @@ async def main():
deps: Optional dependencies to use for this run.
model_settings: Optional settings to use for this model's request.
usage_limits: Optional limits on model request count or token usage.
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.

Returns:
Expand All @@ -326,6 +328,7 @@ async def main():
deps=deps,
model_settings=model_settings,
usage_limits=usage_limits,
usage=usage,
infer_name=False,
)
)
Expand All @@ -340,6 +343,7 @@ async def run_stream(
deps: AgentDeps = None,
model_settings: ModelSettings | None = None,
usage_limits: UsageLimits | None = None,
usage: result.Usage | None = None,
infer_name: bool = True,
) -> AsyncIterator[result.StreamedRunResult[AgentDeps, ResultData]]:
"""Run the agent with a user prompt in async mode, returning a streamed response.
Expand All @@ -363,6 +367,7 @@ async def main():
deps: Optional dependencies to use for this run.
model_settings: Optional settings to use for this model's request.
usage_limits: Optional limits on model request count or token usage.
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.

Returns:
Expand All @@ -385,28 +390,27 @@ async def main():
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used)
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
messages = await self._prepare_messages(user_prompt, message_history, run_context)
run_context.messages = messages

for tool in self._function_tools.values():
tool.current_retry = 0

usage = result.Usage()
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
run_step += 1
usage_limits.check_before_request(usage)
usage_limits.check_before_request(run_context.usage)

with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
async with agent_model.request_stream(messages, model_settings) as model_response:
usage.requests += 1
run_context.usage.requests += 1
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
# We want to end the "model request" span here, but we can't exit the context manager
# in the traditional way
Expand Down Expand Up @@ -442,7 +446,6 @@ async def on_complete():
yield result.StreamedRunResult(
messages,
new_message_index,
usage,
usage_limits,
result_stream,
self._result_schema,
Expand All @@ -466,8 +469,8 @@ async def on_complete():
handle_span.message = f'handle model response -> {tool_responses_str}'
# the model_response should have been fully streamed by now, we can add its usage
model_response_usage = model_response.usage()
usage += model_response_usage
usage_limits.check_tokens(usage)
run_context.usage.incr(model_response_usage)
usage_limits.check_tokens(run_context.usage)

@contextmanager
def override(
Expand Down
37 changes: 22 additions & 15 deletions pydantic_ai_slim/pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Awaitable, Callable
from copy import copy
from dataclasses import dataclass, field
from datetime import datetime
from typing import Generic, Union, cast
Expand Down Expand Up @@ -63,25 +64,33 @@ class Usage:
details: dict[str, int] | None = None
"""Any extra details returned by the model."""

def __add__(self, other: Usage) -> Usage:
"""Add two Usages together.
def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
"""Increment the usage in place.

This is provided so it's trivial to sum usage information from multiple requests and runs.
Args:
incr_usage: The usage to increment by.
requests: The number of requests to increment by in addition to `incr_usage.requests`.
"""
counts: dict[str, int] = {}
self.requests += requests
for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
self_value = getattr(self, f)
other_value = getattr(other, f)
other_value = getattr(incr_usage, f)
if self_value is not None or other_value is not None:
counts[f] = (self_value or 0) + (other_value or 0)
setattr(self, f, (self_value or 0) + (other_value or 0))

details = self.details.copy() if self.details is not None else None
if other.details is not None:
details = details or {}
for key, value in other.details.items():
details[key] = details.get(key, 0) + value
if incr_usage.details:
self.details = self.details or {}
for key, value in incr_usage.details.items():
self.details[key] = self.details.get(key, 0) + value

return Usage(**counts, details=details or None)
def __add__(self, other: Usage) -> Usage:
"""Add two Usages together.

This is provided so it's trivial to sum usage information from multiple requests and runs.
"""
new_usage = copy(self)
new_usage.incr(other)
return new_usage


@dataclass
Expand Down Expand Up @@ -136,8 +145,6 @@ def usage(self) -> Usage:
class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultData]):
"""Result of a streamed run that returns structured data via a tool call."""

usage_so_far: Usage
"""Usage of the run up until the last request."""
_usage_limits: UsageLimits | None
_stream_response: models.EitherStreamedResponse
_result_schema: _result.ResultSchema[ResultData] | None
Expand Down Expand Up @@ -306,7 +313,7 @@ def usage(self) -> Usage:
!!! note
This won't return the full usage until the stream is finished.
"""
return self.usage_so_far + self._stream_response.usage()
return self._run_ctx.usage + self._stream_response.usage()

def timestamp(self) -> datetime:
"""Get the timestamp of the response."""
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,6 @@ def check_tokens(self, usage: Usage) -> None:
f'Exceeded the response_tokens_limit of {self.response_tokens_limit} ({response_tokens=})'
)

total_tokens = request_tokens + response_tokens
total_tokens = usage.total_tokens or 0
if self.total_tokens_limit is not None and total_tokens > self.total_tokens_limit:
raise UsageLimitExceeded(f'Exceeded the total_tokens_limit of {self.total_tokens_limit} ({total_tokens=})')
7 changes: 6 additions & 1 deletion pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, TypeVar, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union, cast

from pydantic import ValidationError
from pydantic_core import SchemaValidator
Expand All @@ -13,6 +13,9 @@
from . import _pydantic, _utils, messages as _messages, models
from .exceptions import ModelRetry, UnexpectedModelBehavior

if TYPE_CHECKING:
from .result import Usage

__all__ = (
'AgentDeps',
'RunContext',
Expand Down Expand Up @@ -45,6 +48,8 @@ class RunContext(Generic[AgentDeps]):
"""Name of the tool being called."""
model: models.Model
"""The model used in this run."""
usage: Usage
"""LLM usage associated with the run."""

def replace_with(
self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
Expand Down
77 changes: 75 additions & 2 deletions tests/test_usage_limits.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import functools
import operator
import re
from datetime import timezone

import pytest
from inline_snapshot import snapshot

from pydantic_ai import Agent, UsageLimitExceeded
from pydantic_ai.messages import ArgsDict, ModelRequest, ModelResponse, ToolCallPart, ToolReturnPart, UserPromptPart
from pydantic_ai import Agent, RunContext, UsageLimitExceeded
from pydantic_ai.messages import (
ArgsDict,
ModelRequest,
ModelResponse,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
)
from pydantic_ai.models.test import TestModel
from pydantic_ai.result import Usage
from pydantic_ai.settings import UsageLimits
Expand Down Expand Up @@ -97,3 +106,67 @@ async def ret_a(x: str) -> str:
UsageLimitExceeded, match=re.escape('Exceeded the response_tokens_limit of 10 (response_tokens=11)')
):
await result.get_data()


def test_usage_so_far(set_event_loop: None) -> None:
test_agent = Agent(TestModel())

with pytest.raises(
UsageLimitExceeded, match=re.escape('Exceeded the total_tokens_limit of 105 (total_tokens=163)')
):
test_agent.run_sync(
'Hello, this prompt exceeds the request tokens limit.',
usage_limits=UsageLimits(total_tokens_limit=105),
usage=Usage(total_tokens=100),
)


async def test_multi_agent_usage_no_incr():
delegate_agent = Agent(TestModel(), result_type=int)

controller_agent1 = Agent(TestModel())
run_1_usages: list[Usage] = []

@controller_agent1.tool
async def delegate_to_other_agent1(ctx: RunContext[None], sentence: str) -> int:
delegate_result = await delegate_agent.run(sentence)
delegate_usage = delegate_result.usage()
run_1_usages.append(delegate_usage)
assert delegate_usage == snapshot(Usage(requests=1, request_tokens=51, response_tokens=4, total_tokens=55))
return delegate_result.data

result1 = await controller_agent1.run('foobar')
assert result1.data == snapshot('{"delegate_to_other_agent1":0}')
run_1_usages.append(result1.usage())
assert result1.usage() == snapshot(Usage(requests=2, request_tokens=103, response_tokens=13, total_tokens=116))

controller_agent2 = Agent(TestModel())

@controller_agent2.tool
async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int:
delegate_result = await delegate_agent.run(sentence, usage=ctx.usage)
delegate_usage = delegate_result.usage()
assert delegate_usage == snapshot(Usage(requests=2, request_tokens=102, response_tokens=9, total_tokens=111))
return delegate_result.data

result2 = await controller_agent2.run('foobar')
assert result2.data == snapshot('{"delegate_to_other_agent2":0}')
assert result2.usage() == snapshot(Usage(requests=3, request_tokens=154, response_tokens=17, total_tokens=171))

# confirm the usage from result2 is the sum of the usage from result1
assert result2.usage() == functools.reduce(operator.add, run_1_usages)


async def test_multi_agent_usage_sync():
"""As in `test_multi_agent_usage_async`, with a sync tool."""
controller_agent = Agent(TestModel())

@controller_agent.tool
def delegate_to_other_agent(ctx: RunContext[None], sentence: str) -> int:
new_usage = Usage(requests=5, request_tokens=2, response_tokens=3, total_tokens=4)
ctx.usage.incr(new_usage)
return 0

result = await controller_agent.run('foobar')
assert result.data == snapshot('{"delegate_to_other_agent":0}')
assert result.usage() == snapshot(Usage(requests=7, request_tokens=105, response_tokens=16, total_tokens=120))
Loading