Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@

_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
import logfire._internal.stack_info
except ImportError:
pass
else:
from pathlib import Path

logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)

NoneType = type(None)
EndStrategy = Literal['early', 'exhaustive']
"""The strategy for handling multiple tool calls when a final result is found.
Expand Down Expand Up @@ -215,7 +225,7 @@ async def run(
"""
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
model_used, mode_selection = await self._get_model(model)
model_used = await self._get_model(model)

deps = self._get_deps(deps)
new_message_index = len(message_history) if message_history else 0
Expand All @@ -224,11 +234,10 @@ async def run(
'{agent_name} run {prompt=}',
prompt=user_prompt,
agent=self,
mode_selection=mode_selection,
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
messages = await self._prepare_messages(user_prompt, message_history, run_context)
run_context.messages = messages

Expand All @@ -238,15 +247,14 @@ async def run(
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
usage_limits.check_before_request(run_context.usage)

run_step += 1
with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
run_context.run_step += 1
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request', run_step=run_step) as model_req_span:
with _logfire.span('model request', run_step=run_context.run_step) as model_req_span:
model_response, request_usage = await agent_model.request(messages, model_settings)
model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('usage', request_usage)
Expand All @@ -255,7 +263,7 @@ async def run(
run_context.usage.incr(request_usage, requests=1)
usage_limits.check_tokens(run_context.usage)

with _logfire.span('handle model response', run_step=run_step) as handle_span:
with _logfire.span('handle model response', run_step=run_context.run_step) as handle_span:
final_result, tool_responses = await self._handle_model_response(model_response, run_context)

if tool_responses:
Expand Down Expand Up @@ -377,7 +385,7 @@ async def main():
# f_back because `asynccontextmanager` adds one frame
if frame := inspect.currentframe(): # pragma: no branch
self._infer_name(frame.f_back)
model_used, mode_selection = await self._get_model(model)
model_used = await self._get_model(model)

deps = self._get_deps(deps)
new_message_index = len(message_history) if message_history else 0
Expand All @@ -386,11 +394,10 @@ async def main():
'{agent_name} run stream {prompt=}',
prompt=user_prompt,
agent=self,
mode_selection=mode_selection,
model_name=model_used.name(),
agent_name=self.name or 'agent',
) as run_span:
run_context = RunContext(deps, 0, [], None, model_used, usage or result.Usage())
run_context = RunContext(deps, model_used, usage or result.Usage(), user_prompt)
messages = await self._prepare_messages(user_prompt, message_history, run_context)
run_context.messages = messages

Expand All @@ -400,15 +407,14 @@ async def main():
model_settings = merge_model_settings(self.model_settings, model_settings)
usage_limits = usage_limits or UsageLimits()

run_step = 0
while True:
run_step += 1
run_context.run_step += 1
usage_limits.check_before_request(run_context.usage)

with _logfire.span('preparing model and tools {run_step=}', run_step=run_step):
with _logfire.span('preparing model and tools {run_step=}', run_step=run_context.run_step):
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
with _logfire.span('model request {run_step=}', run_step=run_context.run_step) as model_req_span:
async with agent_model.request_stream(messages, model_settings) as model_response:
run_context.usage.requests += 1
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
Expand Down Expand Up @@ -781,14 +787,14 @@ def _register_tool(self, tool: Tool[AgentDeps]) -> None:

self._function_tools[tool.name] = tool

async def _get_model(self, model: models.Model | models.KnownModelName | None) -> tuple[models.Model, str]:
async def _get_model(self, model: models.Model | models.KnownModelName | None) -> models.Model:
"""Create a model configured for this agent.

Args:
model: model to use for this run, required if `model` was not set when creating the agent.

Returns:
a tuple of `(model used, how the model was selected)`
The model used
"""
model_: models.Model
if some_model := self._override_model:
Expand All @@ -799,18 +805,15 @@ async def _get_model(self, model: models.Model | models.KnownModelName | None) -
'(Even when `override(model=...)` is customizing the model that will actually be called)'
)
model_ = some_model.value
mode_selection = 'override-model'
elif model is not None:
model_ = models.infer_model(model)
mode_selection = 'custom'
elif self.model is not None:
# noinspection PyTypeChecker
model_ = self.model = models.infer_model(self.model)
mode_selection = 'from-agent'
else:
raise exceptions.UserError('`model` must be set either when creating the agent or when calling it.')

return model_, mode_selection
return model_

async def _prepare_model(self, run_context: RunContext[AgentDeps]) -> models.AgentModel:
"""Build tools and create an agent model."""
Expand Down
16 changes: 10 additions & 6 deletions pydantic_ai_slim/pydantic_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,20 @@ class RunContext(Generic[AgentDeps]):

deps: AgentDeps
"""Dependencies for the agent."""
retry: int
"""Number of retries so far."""
messages: list[_messages.ModelMessage]
"""Messages exchanged in the conversation so far."""
tool_name: str | None
"""Name of the tool being called."""
model: models.Model
"""The model used in this run."""
usage: Usage
"""LLM usage associated with the run."""
prompt: str
"""The original user prompt passed to the run."""
messages: list[_messages.ModelMessage] = field(default_factory=list)
"""Messages exchanged in the conversation so far."""
tool_name: str | None = None
"""Name of the tool being called."""
retry: int = 0
"""Number of retries so far."""
run_step: int = 0
"""The current step in the run."""

def replace_with(
self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET
Expand Down
10 changes: 4 additions & 6 deletions tests/test_logfire.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ async def my_ret(x: int) -> str:
)
assert summary.attributes[0] == snapshot(
{
'code.filepath': 'agent.py',
'code.function': 'run',
'code.filepath': 'test_logfire.py',
'code.function': 'test_logfire',
'code.lineno': 123,
'prompt': 'Hello',
'agent': IsJson(
Expand All @@ -111,7 +111,6 @@ async def my_ret(x: int) -> str:
'model_settings': None,
}
),
'mode_selection': 'from-agent',
'model_name': 'test-model',
'agent_name': 'my_agent',
'logfire.msg_template': '{agent_name} run {prompt=}',
Expand Down Expand Up @@ -176,7 +175,6 @@ async def my_ret(x: int) -> str:
'model': {'type': 'object', 'title': 'TestModel', 'x-python-datatype': 'dataclass'}
},
},
'mode_selection': {},
'model_name': {},
'agent_name': {},
'all_messages': {
Expand Down Expand Up @@ -263,8 +261,8 @@ async def my_ret(x: int) -> str:
)
assert summary.attributes[1] == snapshot(
{
'code.filepath': 'agent.py',
'code.function': 'run',
'code.filepath': 'test_logfire.py',
'code.function': 'test_logfire',
'code.lineno': IsInt(),
'run_step': 1,
'logfire.msg_template': 'preparing model and tools {run_step=}',
Expand Down
Loading