100 changes: 53 additions & 47 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -9,7 +9,6 @@
from types import FrameType
from typing import Any, Callable, Generic, Literal, cast, final, overload

import logfire_api
from typing_extensions import assert_never

from . import (
@@ -19,12 +18,14 @@
exceptions,
messages as _messages,
models,
observe,
result,
)
from .result import ResultData
from .settings import ModelSettings, UsageLimits, merge_model_settings
from .tools import (
AgentDeps,
ModelSelectionMode,
RunContext,
Tool,
ToolDefinition,
@@ -37,8 +38,6 @@

__all__ = ('Agent',)

_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.
@@ -219,20 +218,15 @@ async def run(
"""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
model_used, mode_selection = await self._get_model(model)
model_used, model_selection_mode = await self._get_model(model)

deps = self._get_deps(deps)
new_message_index = len(message_history) if message_history else 0
run_context = RunContext(deps, [], None, model_used, user_prompt, model_selection_mode)
# noinspection PyTypeChecker
Contributor comment: Why do we need this here?

observer: observe.AbstractAgentObserver[AgentDeps, ResultData] = observe.get_observer(self, run_context)

with _logfire.span(
'{agent_name} run {prompt=}',
prompt=user_prompt,
agent=self,
mode_selection=mode_selection,
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used)
with observer.run(stream=False) as run_span:
messages = await self._prepare_messages(user_prompt, message_history, run_context)
self.last_run_messages = run_context.messages = messages

@@ -243,15 +237,14 @@
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
usage_limits.check_before_request(usage)

run_step += 1
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
run_context.run_step += 1
with observer.prepare_model():
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request', run_step=run_step) as model_req_span:
with observer.model_request() as model_req_span:
model_response, request_usage = await agent_model.request(messages, model_settings)
model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('usage', request_usage)
@@ -261,8 +254,10 @@
usage.requests += 1
usage_limits.check_tokens(request_usage)

with _logfire.span('handle model response', run_step=run_step) as handle_span:
final_result, tool_responses = await self._handle_model_response(model_response, run_context)
with observer.handle_model_response() as handle_span:
final_result, tool_responses = await self._handle_model_response(
model_response, run_context, observer
)

if tool_responses:
# Add parts to the conversation as a new message
@@ -378,20 +373,15 @@ async def main():
# f_back because `asynccontextmanager` adds one frame
if frame := inspect.currentframe(): # pragma: no branch
self._infer_name(frame.f_back)
model_used, mode_selection = await self._get_model(model)
model_used, model_selection_mode = await self._get_model(model)

deps = self._get_deps(deps)
new_message_index = len(message_history) if message_history else 0
run_context = RunContext(deps, [], None, model_used, user_prompt, model_selection_mode)
# noinspection PyTypeChecker
observer: observe.AbstractAgentObserver[AgentDeps, ResultData] = observe.get_observer(self, run_context)

with _logfire.span(
'{agent_name} run stream {prompt=}',
prompt=user_prompt,
agent=self,
mode_selection=mode_selection,
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used)
with observer.run(stream=True) as run_span:
messages = await self._prepare_messages(user_prompt, message_history, run_context)
self.last_run_messages = run_context.messages = messages

@@ -402,24 +392,25 @@
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
run_step += 1
run_context.run_step += 1
usage_limits.check_before_request(usage)

with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
with observer.prepare_model():
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
with observer.model_request() as model_req_span:
async with agent_model.request_stream(messages, model_settings) as model_response:
usage.requests += 1
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
# We want to end the "model request" span here, but we can't exit the context manager
# in the traditional way
model_req_span.__exit__(None, None, None)

with _logfire.span('handle model response') as handle_span:
maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
with observer.handle_model_response() as handle_span:
maybe_final_result = await self._handle_streamed_model_response(
model_response, run_context, observer
)

# Check if we got a final result
if isinstance(maybe_final_result, _MarkFinalResult):
@@ -439,7 +430,10 @@ async def on_complete():
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
]
parts = await self._process_function_tools(
tool_calls, result_tool_name, run_context
tool_calls,
result_tool_name,
run_context,
observer,
)
if parts:
messages.append(_messages.ModelRequest(parts))
@@ -784,7 +778,9 @@ def _register_tool(self, tool: Tool[AgentDeps]) -> None:

self._function_tools[tool.name] = tool

async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
async def _get_model(
self, model: models.Model | models.KnownModelName | None
) -> tuple[models.Model, ModelSelectionMode]:
"""Create a model configured for this agent.

Args:
@@ -802,18 +798,18 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
'(Even when `override(model=...)` is customizing the model that will actually be called)'
)
model_ = some_model.value
mode_selection = 'override-model'
model_selection_mode: ModelSelectionMode = 'override'
elif model is not None:
model_ = models.infer_model(model)
mode_selection = 'custom'
model_selection_mode = 'custom'
elif self.model is not None:
# noinspection PyTypeChecker
model_ = self.model = models.infer_model(self.model)
mode_selection = 'from-agent'
model_selection_mode = 'from-agent'
else:
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

return model_, mode_selection
return model_, model_selection_mode

async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
"""Build tools and create an agent model."""
@@ -847,7 +843,10 @@ async def _prepare_messages(
return messages

async def _handle_model_response(
self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
self,
model_response: _messages.ModelResponse,
run_context: RunContext[AgentDeps],
observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
"""Process a non-streamed response from the model.

@@ -869,7 +868,7 @@ async def _handle_model_response(
# This accounts for cases like anthropic returns that might contain a text response
# and a tool call response, where the text response just indicates the tool call will happen.
if tool_calls:
return await self._handle_structured_response(tool_calls, run_context)
return await self._handle_structured_response(tool_calls, run_context, observer)
elif texts:
text = '\n\n'.join(texts)
return await self._handle_text_response(text, run_context)
@@ -897,7 +896,10 @@ async def _handle_text_response(
return None, [response]

async def _handle_structured_response(
self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
self,
tool_calls: list[_messages.ToolCallPart],
run_context: RunContext[AgentDeps],
observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
"""Handle a structured response containing tool calls from the model for non-streaming responses."""
assert tool_calls, 'Expected at least one tool call'
@@ -919,7 +921,9 @@ async def _handle_structured_response(
final_result = _MarkFinalResult(result_data, call.tool_name)

# Then build the other request parts based on end strategy
parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
parts += await self._process_function_tools(
tool_calls, final_result and final_result.tool_name, run_context, observer
)

return final_result, parts

@@ -928,6 +932,7 @@ async def _process_function_tools(
tool_calls: list[_messages.ToolCallPart],
result_tool_name: str | None,
run_context: RunContext[AgentDeps],
observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
) -> list[_messages.ModelRequestPart]:
"""Process function (non-result) tool calls in parallel.

@@ -977,7 +982,7 @@

# Run all tool tasks in parallel
if tasks:
with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
with observer.run_tools(tasks):
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
parts.extend(task_results)
return parts
@@ -986,6 +991,7 @@ async def _handle_streamed_model_response(
self,
model_response: models.EitherStreamedResponse,
run_context: RunContext[AgentDeps],
observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
) -> (
_MarkFinalResult[models.EitherStreamedResponse]
| tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -1045,7 +1051,7 @@ async def _handle_streamed_model_response(
else:
parts.append(self._unknown_tool(call.tool_name, run_context))

with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
with observer.run_tools(tasks):
task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
parts.extend(task_results)
return model_response_msg, parts
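
Taken together, the call sites in this diff pin down the surface of the new observer abstraction: observe.get_observer(agent, run_context) builds one observer per run, observer.run(stream=...) wraps the whole run and yields the run_span, and prepare_model(), model_request(), handle_model_response(), and run_tools(tasks) wrap the individual steps. A minimal sketch of the interface, reconstructed from those call sites alone (the actual observe module may differ):

from abc import ABC, abstractmethod
from contextlib import AbstractContextManager
from typing import Any, Generic, TypeVar

AgentDeps = TypeVar('AgentDeps')
ResultData = TypeVar('ResultData')


# Reconstructed from this diff's call sites only; not the actual observe module.
class AbstractAgentObserver(ABC, Generic[AgentDeps, ResultData]):
    @abstractmethod
    def run(self, *, stream: bool) -> AbstractContextManager[Any]:
        """Wraps the whole agent run; the yielded object is used as run_span."""

    @abstractmethod
    def prepare_model(self) -> AbstractContextManager[None]:
        """Wraps building tools and creating the agent model for one step."""

    @abstractmethod
    def model_request(self) -> AbstractContextManager[Any]:
        """Wraps one model request; the yielded span must support
        set_attribute(name, value) and an explicit __exit__(None, None, None)."""

    @abstractmethod
    def handle_model_response(self) -> AbstractContextManager[Any]:
        """Wraps processing of one model response."""

    @abstractmethod
    def run_tools(self, tasks: list[Any]) -> AbstractContextManager[None]:
        """Wraps parallel execution of the function-tool tasks."""

Note the asymmetry the streamed run forces on model_request: its span object must tolerate a manual __exit__ before the enclosing with block unwinds, because the response stream outlives the request span.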