diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index fa1188bdbd..793fe92cbb 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -9,7 +9,6 @@
 from types import FrameType
 from typing import Any, Callable, Generic, Literal, cast, final, overload
 
-import logfire_api
 from typing_extensions import assert_never
 
 from . import (
@@ -19,12 +18,14 @@
     exceptions,
     messages as _messages,
     models,
+    observe,
     result,
 )
 from .result import ResultData
 from .settings import ModelSettings, UsageLimits, merge_model_settings
 from .tools import (
     AgentDeps,
+    ModelSelectionMode,
     RunContext,
     Tool,
     ToolDefinition,
@@ -37,8 +38,6 @@
 
 __all__ = ('Agent',)
 
-_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')
-
 NoneType = type(None)
 EndStrategy = Literal['early', 'exhaustive']
 """The strategy for handling multiple tool calls when a final result is found.
@@ -219,20 +218,15 @@ async def run(
         """
         if infer_name and self.name is None:
             self._infer_name(inspect.currentframe())
-        model_used, mode_selection = await self._get_model(model)
+        model_used, model_selection_mode = await self._get_model(model)
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        run_context = RunContext(deps, [], None, model_used, user_prompt, model_selection_mode)
+        # noinspection PyTypeChecker
+        observer: observe.AbstractAgentObserver[AgentDeps, ResultData] = observe.get_observer(self, run_context)
 
-        with _logfire.span(
-            '{agent_name} run {prompt=}',
-            prompt=user_prompt,
-            agent=self,
-            mode_selection=mode_selection,
-            model_name=model_used.name(),
-            agent_name=self.name or 'agent',
-        ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+        with observer.run(stream=False) as run_span:
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             self.last_run_messages = run_context.messages = messages
 
@@ -243,15 +237,14 @@ async def run(
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()
 
-            run_step = 0
             while True:
                 usage_limits.check_before_request(usage)
 
-                run_step += 1
-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                run_context.run_step += 1
+                with observer.prepare_model():
                     agent_model = await self._prepare_model(run_context)
 
-                with _logfire.span('model request', run_step=run_step) as model_req_span:
+                with observer.model_request() as model_req_span:
                     model_response, request_usage = await agent_model.request(messages, model_settings)
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('usage', request_usage)
@@ -261,8 +254,10 @@ async def run(
                 usage.requests += 1
                 usage_limits.check_tokens(request_usage)
 
-                with _logfire.span('handle model response', run_step=run_step) as handle_span:
-                    final_result, tool_responses = await self._handle_model_response(model_response, run_context)
+                with observer.handle_model_response() as handle_span:
+                    final_result, tool_responses = await self._handle_model_response(
+                        model_response, run_context, observer
+                    )
 
                     if tool_responses:
                         # Add parts to the conversation as a new message
@@ -378,20 +373,15 @@ async def main():
         # f_back because `asynccontextmanager` adds one frame
         if frame := inspect.currentframe():  # pragma: no branch
             self._infer_name(frame.f_back)
-        model_used, mode_selection = await self._get_model(model)
+        model_used, model_selection_mode = await self._get_model(model)
 
         deps = self._get_deps(deps)
         new_message_index = len(message_history) if message_history else 0
+        run_context = RunContext(deps, [], None, model_used, user_prompt, model_selection_mode)
+        # noinspection PyTypeChecker
+        observer: observe.AbstractAgentObserver[AgentDeps, ResultData] = observe.get_observer(self, run_context)
 
-        with _logfire.span(
-            '{agent_name} run stream {prompt=}',
-            prompt=user_prompt,
-            agent=self,
-            mode_selection=mode_selection,
-            model_name=model_used.name(),
-            agent_name=self.name or 'agent',
-        ) as run_span:
-            run_context = RunContext(deps, 0, [], None, model_used)
+        with observer.run(stream=True) as run_span:
             messages = await self._prepare_messages(user_prompt, message_history, run_context)
             self.last_run_messages = run_context.messages = messages
 
@@ -402,15 +392,14 @@ async def main():
             model_settings = merge_model_settings(self.model_settings, model_settings)
             usage_limits = usage_limits or UsageLimits()
 
-            run_step = 0
             while True:
-                run_step += 1
+                run_context.run_step += 1
                 usage_limits.check_before_request(usage)
 
-                with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
+                with observer.prepare_model():
                     agent_model = await self._prepare_model(run_context)
 
-                with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
+                with observer.model_request() as model_req_span:
                     async with agent_model.request_stream(messages, model_settings) as model_response:
                         usage.requests += 1
                         model_req_span.set_attribute('response_type', model_response.__class__.__name__)
@@ -418,8 +407,10 @@ async def main():
                         # in the traditional way
                         model_req_span.__exit__(None, None, None)
 
-                        with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
+                        with observer.handle_model_response() as handle_span:
+                            maybe_final_result = await self._handle_streamed_model_response(
+                                model_response, run_context, observer
+                            )
 
                             # Check if we got a final result
                             if isinstance(maybe_final_result, _MarkFinalResult):
@@ -439,7 +430,10 @@ async def on_complete():
                             part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
                         ]
                         parts = await self._process_function_tools(
-                            tool_calls, result_tool_name, run_context
+                            tool_calls,
+                            result_tool_name,
+                            run_context,
+                            observer,
                         )
                         if parts:
                             messages.append(_messages.ModelRequest(parts))
@@ -784,7 +778,9 @@ def _register_tool(self, tool: Tool[AgentDeps]) -> None:
 
         self._function_tools[tool.name] = tool
 
-    async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
+    async def _get_model(
+        self, model: models.Model | models.KnownModelName | None
+    ) -> tuple[models.Model, ModelSelectionMode]:
         """Create a model configured for this agent.
 
         Args:
@@ -802,18 +798,18 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
                     '(Even when `override(model=...)` is customizing the model that will actually be called)'
                 )
             model_ = some_model.value
-            mode_selection = 'override-model'
+            model_selection_mode: ModelSelectionMode = 'override'
         elif model is not None:
             model_ = models.infer_model(model)
-            mode_selection = 'custom'
+            model_selection_mode = 'custom'
         elif self.model is not None:
             # noinspection PyTypeChecker
             model_ = self.model = models.infer_model(self.model)
-            mode_selection = 'from-agent'
+            model_selection_mode = 'from-agent'
         else:
             raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')
 
-        return model_, mode_selection
+        return model_, model_selection_mode
 
     async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
         """Build tools and create an agent model."""
@@ -847,7 +843,10 @@ async def _prepare_messages(
         return messages
 
     async def _handle_model_response(
-        self, model_response: _messages.ModelResponse, run_context: RunContext[AgentDeps]
+        self,
+        model_response: _messages.ModelResponse,
+        run_context: RunContext[AgentDeps],
+        observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
     ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Process a non-streamed response from the model.
 
@@ -869,7 +868,7 @@ async def _handle_model_response(
         # This accounts for cases like anthropic returns that might contain a text response
         # and a tool call response, where the text response just indicates the tool call will happen.
         if tool_calls:
-            return await self._handle_structured_response(tool_calls, run_context)
+            return await self._handle_structured_response(tool_calls, run_context, observer)
         elif texts:
             text = '\n\n'.join(texts)
             return await self._handle_text_response(text, run_context)
@@ -897,7 +896,10 @@ async def _handle_text_response(
             return None, [response]
 
     async def _handle_structured_response(
-        self, tool_calls: list[_messages.ToolCallPart], run_context: RunContext[AgentDeps]
+        self,
+        tool_calls: list[_messages.ToolCallPart],
+        run_context: RunContext[AgentDeps],
+        observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
     ) -> tuple[_MarkFinalResult[ResultData] | None, list[_messages.ModelRequestPart]]:
         """Handle a structured response containing tool calls from the model for non-streaming responses."""
         assert tool_calls, 'Expected at least one tool call'
@@ -919,7 +921,9 @@ async def _handle_structured_response(
                     final_result = _MarkFinalResult(result_data, call.tool_name)
 
         # Then build the other request parts based on end strategy
-        parts += await self._process_function_tools(tool_calls, final_result and final_result.tool_name, run_context)
+        parts += await self._process_function_tools(
+            tool_calls, final_result and final_result.tool_name, run_context, observer
+        )
 
         return final_result, parts
 
@@ -928,6 +932,7 @@ async def _process_function_tools(
         tool_calls: list[_messages.ToolCallPart],
         result_tool_name: str | None,
         run_context: RunContext[AgentDeps],
+        observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
     ) -> list[_messages.ModelRequestPart]:
         """Process function (non-result) tool calls in parallel.
 
@@ -977,7 +982,7 @@ async def _process_function_tools(
 
         # Run all tool tasks in parallel
         if tasks:
-            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
+            with observer.run_tools(tasks):
                 task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
                 parts.extend(task_results)
         return parts
@@ -986,6 +991,7 @@ async def _handle_streamed_model_response(
         self,
         model_response: models.EitherStreamedResponse,
         run_context: RunContext[AgentDeps],
+        observer: observe.AbstractAgentObserver[AgentDeps, ResultData],
     ) -> (
         _MarkFinalResult[models.EitherStreamedResponse]
         | tuple[_messages.ModelResponse, list[_messages.ModelRequestPart]]
@@ -1045,7 +1051,7 @@ async def _handle_streamed_model_response(
                 else:
                     parts.append(self._unknown_tool(call.tool_name, run_context))
 
-            with _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks]):
+            with observer.run_tools(tasks):
                 task_results: Sequence[_messages.ModelRequestPart] = await asyncio.gather(*tasks)
                 parts.extend(task_results)
             return model_response_msg, parts
diff --git a/pydantic_ai_slim/pydantic_ai/observe.py b/pydantic_ai_slim/pydantic_ai/observe.py
new file mode 100644
index 0000000000..177681f9ce
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/observe.py
@@ -0,0 +1,180 @@
+from __future__ import annotations as _annotations
+
+import asyncio
+from abc import ABC
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Generic, Protocol
+
+from typing_extensions import Self
+
+from .result import ResultData
+from .tools import AgentDeps, RunContext
+
+if TYPE_CHECKING:
+    from logfire import LogfireSpan as _LogfireSpan
+
+    from .agent import Agent
+
+__all__ = 'get_observer', 'AbstractAgentObserver', 'AbstractSpan'
+
+
+# noinspection PyTypeChecker
+def get_observer(
+    agent: Agent[AgentDeps, ResultData], run_context: RunContext[AgentDeps]
+) -> AbstractAgentObserver[AgentDeps, ResultData]:
+    """Get an observer for the agent."""
+    if _LogfireAgentObserver is not None:  # pyright: ignore[reportUnnecessaryComparison]
+        return _LogfireAgentObserver(agent, run_context)
+    else:
+        return _NoOpAgentObserver(agent, run_context)
+
+
+class AbstractSpan(Protocol):
+    """Abstract definition of a span for tracing the progress of an agent run."""
+
+    def __enter__(self) -> AbstractSpan: ...
+
+    def __exit__(
+        self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any
+    ) -> None: ...
+
+    def set_attribute(self, key: str, value: Any) -> None: ...
+
+    @property
+    def message(self) -> str: ...
+
+    @message.setter
+    def message(self, message: str) -> None: ...
+
+
+@dataclass
+class AbstractAgentObserver(ABC, Generic[AgentDeps, ResultData]):
+    """Abstract base class of an observer to monitor the progress of an agent run."""
+
+    agent: Agent[AgentDeps, ResultData]
+    run_context: RunContext[AgentDeps]
+
+    def run(self, *, stream: bool) -> AbstractSpan:
+        return NoOpSpan()
+
+    def prepare_model(self) -> AbstractSpan:
+        return NoOpSpan()
+
+    def model_request(self) -> AbstractSpan:
+        return NoOpSpan()
+
+    def handle_model_response(self) -> AbstractSpan:
+        return NoOpSpan()
+
+    def run_tools(self, tasks: list[asyncio.Task[Any]]) -> AbstractSpan:
+        return NoOpSpan()
+
+
+@dataclass
+class _CompoundSpan(AbstractSpan):
+    _spans: list[AbstractSpan]
+
+    def __enter__(self) -> Self:
+        for span in self._spans:
+            span.__enter__()
+        return self
+
+    def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
+        for span in reversed(self._spans):
+            span.__exit__(exc_type, exc_value, traceback)
+
+    def set_attribute(self, key: str, value: Any) -> None:
+        for span in self._spans:
+            span.set_attribute(key, value)
+
+    # noinspection PyProtocol
+    @property
+    def message(self) -> str:
+        return next((span.message for span in self._spans if span.message), '')
+
+    @message.setter
+    def message(self, message: str) -> None:
+        for span in self._spans:
+            span.message = message
+
+
+@dataclass
+class _CompoundAgentObserver(AbstractAgentObserver[AgentDeps, ResultData]):  # type: ignore
+    _observers: list[AbstractAgentObserver[AgentDeps, ResultData]]
+
+    def run(self, *, stream: bool) -> AbstractSpan:
+        spans = [observer.run(stream=stream) for observer in self._observers]
+        return _CompoundSpan(spans)
+
+    def prepare_model(self) -> AbstractSpan:
+        spans = [observer.prepare_model() for observer in self._observers]
+        return _CompoundSpan(spans)
+
+    def model_request(self) -> AbstractSpan:
+        spans = [observer.model_request() for observer in self._observers]
+        return _CompoundSpan(spans)
+
+    def handle_model_response(self) -> AbstractSpan:
+        spans = [observer.handle_model_response() for observer in self._observers]
+        return _CompoundSpan(spans)
+
+    def run_tools(self, tasks: list[asyncio.Task[Any]]) -> AbstractSpan:
+        spans = [observer.run_tools(tasks) for observer in self._observers]
+        return _CompoundSpan(spans)
+
+
+class _NoOpAgentObserver(AbstractAgentObserver[AgentDeps, ResultData]):
+    pass
+
+
+try:
+    # noinspection PyUnresolvedReferences
+    from logfire import Logfire as _Logfire
+except ImportError:
+    _LogfireAgentObserver = None  # pyright: ignore[reportAssignmentType]
+else:
+    _logfire = _Logfire(otel_scope='pydantic-ai')
+
+    class _LogfireAgentObserver(AbstractAgentObserver[AgentDeps, ResultData]):
+        def run(self, *, stream: bool) -> _LogfireSpan:
+            return _logfire.span(
+                '{agent_name} run stream {prompt=}' if stream else '{agent_name} run {prompt=}',
+                prompt=self.run_context.prompt,
+                agent=self.agent,
+                model_selection_mode=self.run_context.model_selection_mode,
+                model_name=self.run_context.model.name(),
+                agent_name=self.agent.name or 'agent',
+                stream=stream,
+            )
+
+        def prepare_model(self) -> _LogfireSpan:
+            return _logfire.span('preparing model and tools {run_step=}', run_step=self.run_context.run_step)
+
+        def model_request(self) -> _LogfireSpan:
+            return _logfire.span('model request')
+
+        def handle_model_response(self) -> _LogfireSpan:
+            return _logfire.span('handle model response')
+
+        def run_tools(self, tasks: list[asyncio.Task[Any]]) -> _LogfireSpan:
+            return _logfire.span('running {tools=}', tools=[t.get_name() for t in tasks])
+
+
+class NoOpSpan(AbstractSpan):
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: Any) -> None:
+        pass
+
+    def set_attribute(self, key: str, value: Any) -> None:
+        pass
+
+    # noinspection PyProtocol
+    @property
+    def message(self) -> str:
+        return ''
+
+    @message.setter
+    def message(self, message: str) -> None:
+        pass
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index 72a66f5b49..035efac570 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -4,7 +4,7 @@
 import inspect
 from collections.abc import Awaitable
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, TypeVar, Union, cast
+from typing import Any, Callable, Generic, Literal, TypeVar, Union, cast
 
 from pydantic import ValidationError
 from pydantic_core import SchemaValidator
@@ -15,6 +15,7 @@
 
 __all__ = (
     'AgentDeps',
+    'ModelSelectionMode',
     'RunContext',
     'SystemPromptFunc',
     'ToolFuncContext',
@@ -29,6 +30,15 @@
 
 AgentDeps = TypeVar('AgentDeps')
 """Type variable for agent dependencies."""
 
+ModelSelectionMode = Literal['from-agent', 'custom', 'override']
+"""How the [model][pydantic_ai.models.Model] was selected for the run.
+
+Meanings:
+
+- `'from-agent'`: The model set on the agent was used
+- `'custom'`: The model was set via the `model` kwarg, e.g. on [`run`][pydantic_ai.Agent.run]
+- `'override'`: The model was set by the [`override`][pydantic_ai.Agent.override] function.
+"""
 
 @dataclasses.dataclass
@@ -37,14 +47,20 @@ class RunContext(Generic[AgentDeps]):
 
     deps: AgentDeps
     """Dependencies for the agent."""
-    retry: int
-    """Number of retries so far."""
     messages: list[_messages.ModelMessage]
     """Messages exchanged in the conversation so far."""
     tool_name: str | None
     """Name of the tool being called."""
     model: models.Model
     """The model used in this run."""
+    prompt: str
+    """The original user prompt passed to the run."""
+    model_selection_mode: ModelSelectionMode
+    """How the model was selected for the run."""
+    retry: int = 0
+    """Number of retries so far."""
+    run_step: int = 0
+    """The current step in the run."""
 
     def replace_with(
         self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
diff --git a/tests/test_logfire.py b/tests/test_logfire.py
index 85ec8f1e4f..3037a5d6b4 100644
--- a/tests/test_logfire.py
+++ b/tests/test_logfire.py
@@ -1,6 +1,7 @@
 from __future__ import annotations as _annotations
 
 from dataclasses import dataclass
+from pathlib import Path
 from typing import Any, Callable
 
 import pytest
@@ -8,15 +9,18 @@
 from inline_snapshot import snapshot
 from typing_extensions import NotRequired, TypedDict
 
+import pydantic_ai
 from pydantic_ai import Agent
 from pydantic_ai.models.test import TestModel
 
 try:
+    import logfire._internal.stack_info
     from logfire.testing import CaptureLogfire
 except ImportError:
     logfire_installed = False
 else:
     logfire_installed = True
+    logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(pydantic_ai.__file__).parent.absolute()),)
 
 
 class SpanSummary(TypedDict):
@@ -91,8 +95,8 @@ async def my_ret(x: int) -> str:
     )
     assert summary.attributes[0] == snapshot(
         {
-            'code.filepath': 'agent.py',
-            'code.function': 'run',
+            'code.filepath': 'test_logfire.py',
+            'code.function': 'test_logfire',
             'code.lineno': 123,
             'prompt': 'Hello',
             'agent': IsJson(
@@ -112,8 +116,9 @@ async def my_ret(x: int) -> str:
                     'model_settings': None,
                 }
             ),
-            'mode_selection': 'from-agent',
+            'model_selection_mode': 'from-agent',
             'model_name': 'test-model',
+            'stream': False,
             'agent_name': 'my_agent',
             'logfire.msg_template': '{agent_name} run {prompt=}',
             'logfire.msg': 'my_agent run prompt=Hello',
@@ -177,7 +182,8 @@ async def my_ret(x: int) -> str:
                         'model': {'type': 'object', 'title': 'TestModel', 'x-python-datatype': 'dataclass'}
                     },
                 },
-                'mode_selection': {},
+                'model_selection_mode': {},
+                'stream': {},
                 'model_name': {},
                 'agent_name': {},
                 'all_messages': {
@@ -264,8 +270,8 @@ async def my_ret(x: int) -> str:
     )
    assert summary.attributes[1] == snapshot(
         {
-            'code.filepath': 'agent.py',
-            'code.function': 'run',
+            'code.filepath': 'test_logfire.py',
+            'code.function': 'test_logfire',
             'code.lineno': IsInt(),
             'run_step': 1,
             'logfire.msg_template': 'preparing model and tools {run_step=}',
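
For reviewers, a usage sketch of the new `observe` API: subclass `AbstractAgentObserver`, override only the steps you care about (everything else falls back to the no-op spans), and return any object satisfying the `AbstractSpan` protocol. `TimingSpan` and `TimingObserver` are hypothetical names, and the patch does not yet expose a hook for wiring a third-party observer into `get_observer`, so this only illustrates the shape of the interface, not a supported extension point.

from __future__ import annotations

import time
from typing import Any

from pydantic_ai import observe


class TimingSpan:
    """Structurally satisfies `observe.AbstractSpan`; prints the wall-clock duration of a step."""

    def __init__(self, label: str):
        self.label = label
        self.message = ''  # `AbstractSpan` requires a readable/writable `message`
        self._start = 0.0

    def __enter__(self) -> TimingSpan:
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        print(f'{self.label}: {time.perf_counter() - self._start:.3f}s')

    def set_attribute(self, key: str, value: Any) -> None:
        pass  # attributes are dropped in this sketch


class TimingObserver(observe.AbstractAgentObserver[Any, Any]):
    """Times the model request step; all other steps inherit the base class's no-op spans."""

    def model_request(self) -> observe.AbstractSpan:
        return TimingSpan(f'model request (step {self.run_context.run_step})')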