2 changes: 1 addition & 1 deletion docs/durable_execution/temporal.md
@@ -172,7 +172,7 @@ As workflows and activities run in separate processes, any values passed between

To account for these limitations, tool functions and the [event stream handler](#streaming) running inside activities receive a limited version of the agent's [`RunContext`][pydantic_ai.tools.RunContext], and it's your responsibility to make sure that the [dependencies](../dependencies.md) object provided to [`TemporalAgent.run()`][pydantic_ai.durable_exec.temporal.TemporalAgent.run] can be serialized using Pydantic.

Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, and `run_step` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error.
Specifically, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` fields are available by default, and trying to access `model`, `usage`, `prompt`, `messages`, or `tracer` will raise an error.
If you need one or more of these attributes to be available inside activities, you can create a [`TemporalRunContext`][pydantic_ai.durable_exec.temporal.TemporalRunContext] subclass with custom `serialize_run_context` and `deserialize_run_context` class methods and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent] as `run_context_type`.

### Streaming
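For context, the subclassing pattern the updated docs describe might look like the following minimal sketch (not part of this diff). It assumes `prompt` is Pydantic-serializable for your run; per the docstring further below, only `serialize_run_context` needs to be overridden, and the subclass is passed to `TemporalAgent` as `run_context_type`. The model string and agent name are placeholders.

```python
from typing import Any

from pydantic_ai import Agent
from pydantic_ai.durable_exec.temporal import TemporalAgent, TemporalRunContext
from pydantic_ai.tools import RunContext


class RunContextWithPrompt(TemporalRunContext):
    """Hypothetical subclass that also exposes `prompt` inside Temporal activities."""

    @classmethod
    def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
        # Keep the default fields (deps, retries, tool_call_id, tool_name,
        # tool_call_approved, retry, max_retries, run_step) and add `prompt`.
        return {**super().serialize_run_context(ctx), 'prompt': ctx.prompt}


agent = Agent('openai:gpt-4o', name='my_agent')  # placeholder model and name
temporal_agent = TemporalAgent(agent, run_context_type=RunContextWithPrompt)
```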
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -43,10 +43,17 @@ class RunContext(Generic[AgentDepsT]):
tool_name: str | None = None
"""Name of the tool being called."""
retry: int = 0
"""Number of retries so far."""
"""Number of retries of this tool so far."""
max_retries: int = 0
"""The maximum number of retries of this tool."""
run_step: int = 0
"""The current step in the run."""
tool_call_approved: bool = False
"""Whether a tool call that required approval has now been approved."""

@property
def last_attempt(self) -> bool:
"""Whether this is the last attempt at running this tool before an error is raised."""
return self.retry == self.max_retries

__repr__ = _utils.dataclasses_no_defaults_repr
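The new `last_attempt` property (true when `retry == max_retries`) lets a tool degrade gracefully instead of exhausting its retry budget. A minimal sketch of the pattern, mirroring the `test_retry_tool_until_last_attempt` test added in this PR; the `'test'` model string and prompt are placeholders.

```python
from pydantic_ai import Agent, ModelRetry, RunContext

agent = Agent('test', retries=2)  # built-in test model; 2 retries per tool by default


@agent.tool
def always_fail(ctx: RunContext[None]) -> str:
    if ctx.last_attempt:
        # Final attempt (retry == max_retries): return a fallback instead of
        # raising again, so the run doesn't end with an exceeded-retries error.
        return 'I guess you never learn'
    raise ModelRetry('Please try again.')


result = agent.run_sync('Always fail!')
```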
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -147,6 +147,7 @@ async def _call_tool(
tool_name=name,
tool_call_id=call.tool_call_id,
retry=self.ctx.retries.get(name, 0),
max_retries=tool.max_retries,
)

pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
@@ -9,7 +9,7 @@
class TemporalRunContext(RunContext[AgentDepsT]):
"""The [`RunContext`][pydantic_ai.tools.RunContext] subclass to use to serialize and deserialize the run context for use inside a Temporal activity.

By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry` and `run_step` attributes will be available.
By default, only the `deps`, `retries`, `tool_call_id`, `tool_name`, `tool_call_approved`, `retry`, `max_retries` and `run_step` attributes will be available.
To make another attribute available, create a `TemporalRunContext` subclass with a custom `serialize_run_context` class method that returns a dictionary that includes the attribute and pass it to [`TemporalAgent`][pydantic_ai.durable_exec.temporal.TemporalAgent].
"""

@@ -42,6 +42,7 @@ def serialize_run_context(cls, ctx: RunContext[Any]) -> dict[str, Any]:
'tool_name': ctx.tool_name,
'tool_call_approved': ctx.tool_call_approved,
'retry': ctx.retry,
'max_retries': ctx.max_retries,
'run_step': ctx.run_step,
}

10 changes: 8 additions & 2 deletions pydantic_ai_slim/pydantic_ai/toolsets/function.py
@@ -309,7 +309,13 @@ def add_tool(self, tool: Tool[AgentDepsT]) -> None:
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
tools: dict[str, ToolsetTool[AgentDepsT]] = {}
for original_name, tool in self.tools.items():
run_context = replace(ctx, tool_name=original_name, retry=ctx.retries.get(original_name, 0))
max_retries = tool.max_retries if tool.max_retries is not None else self.max_retries
run_context = replace(
ctx,
tool_name=original_name,
retry=ctx.retries.get(original_name, 0),
max_retries=max_retries,
)
tool_def = await tool.prepare_tool_def(run_context)
if not tool_def:
continue
@@ -324,7 +330,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
tools[new_name] = FunctionToolsetTool(
toolset=self,
tool_def=tool_def,
max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries,
max_retries=max_retries,
args_validator=tool.function_schema.validator,
call_func=tool.function_schema.call,
is_async=tool.function_schema.is_async,
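As the `get_tools` change above shows, a tool's `max_retries` resolves to its own `retries` value when set and otherwise falls back to the toolset default (the agent's `retries` argument), and the resolved value is now exposed to prepare functions and tool bodies as `ctx.max_retries`. A minimal sketch of both cases, using the `'test'` model as a placeholder:

```python
from pydantic_ai import Agent, RunContext

agent = Agent('test', retries=1)  # toolset-level default: 1 retry


@agent.tool(retries=3)
def flaky_tool(ctx: RunContext[None]) -> str:
    # Per-tool override: ctx.max_retries == 3 here.
    return f'attempt {ctx.retry} of {ctx.max_retries}'


@agent.tool
def plain_tool(ctx: RunContext[None]) -> str:
    # No per-tool value: ctx.max_retries falls back to the agent default (1).
    return f'attempt {ctx.retry} of {ctx.max_retries}'
```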
95 changes: 95 additions & 0 deletions tests/test_tools.py
@@ -1263,7 +1263,11 @@ async def function(*args: Any, **kwargs: Any) -> str:
def test_tool_retries():
prepare_tools_retries: list[int] = []
prepare_retries: list[int] = []
prepare_max_retries: list[int] = []
prepare_last_attempt: list[bool] = []
call_retries: list[int] = []
call_max_retries: list[int] = []
call_last_attempt: list[bool] = []

async def prepare_tool_defs(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition] | None:
nonlocal prepare_tools_retries
@@ -1276,20 +1280,30 @@ async def prepare_tool_defs(ctx: RunContext[None], tool_defs: list[ToolDefinitio
async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition | None:
nonlocal prepare_retries
prepare_retries.append(ctx.retry)
prepare_max_retries.append(ctx.max_retries)
prepare_last_attempt.append(ctx.last_attempt)
return tool_def

@agent.tool(retries=5, prepare=prepare_tool_def)
def infinite_retry_tool(ctx: RunContext[None]) -> int:
nonlocal call_retries
call_retries.append(ctx.retry)
call_max_retries.append(ctx.max_retries)
call_last_attempt.append(ctx.last_attempt)
raise ModelRetry('Please try again.')

with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"):
agent.run_sync('Begin infinite retry loop!')

assert prepare_tools_retries == snapshot([0, 1, 2, 3, 4, 5])

assert prepare_retries == snapshot([0, 1, 2, 3, 4, 5])
assert prepare_max_retries == snapshot([5, 5, 5, 5, 5, 5])
assert prepare_last_attempt == snapshot([False, False, False, False, False, True])

assert call_retries == snapshot([0, 1, 2, 3, 4, 5])
assert call_max_retries == snapshot([5, 5, 5, 5, 5, 5])
assert call_last_attempt == snapshot([False, False, False, False, False, True])


def test_tool_raises_call_deferred():
@@ -2093,3 +2107,84 @@ def standalone_func(ctx: RunContext[None], b: float) -> float:
toolset.add_function(standalone_func, metadata={'method': 'add_function'})
standalone_tool_def = toolset.tools['standalone_func']
assert standalone_tool_def.metadata == {'foo': 'bar', 'method': 'add_function'}


def test_retry_tool_until_last_attempt():
model = TestModel()
agent = Agent(model, retries=2)

@agent.tool
def always_fail(ctx: RunContext[None]) -> str:
if ctx.last_attempt:
return 'I guess you never learn'
else:
raise ModelRetry('Please try again.')

result = agent.run_sync('Always fail!')
assert result.output == snapshot('{"always_fail":"I guess you never learn"}')
assert result.all_messages() == snapshot(
[
ModelRequest(
parts=[
UserPromptPart(
content='Always fail!',
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[ToolCallPart(tool_name='always_fail', args={}, tool_call_id=IsStr())],
usage=RequestUsage(input_tokens=52, output_tokens=2),
model_name='test',
timestamp=IsDatetime(),
),
ModelRequest(
parts=[
RetryPromptPart(
content='Please try again.',
tool_name='always_fail',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[ToolCallPart(tool_name='always_fail', args={}, tool_call_id=IsStr())],
usage=RequestUsage(input_tokens=62, output_tokens=4),
model_name='test',
timestamp=IsDatetime(),
),
ModelRequest(
parts=[
RetryPromptPart(
content='Please try again.',
tool_name='always_fail',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[ToolCallPart(tool_name='always_fail', args={}, tool_call_id=IsStr())],
usage=RequestUsage(input_tokens=72, output_tokens=6),
model_name='test',
timestamp=IsDatetime(),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='always_fail',
content='I guess you never learn',
tool_call_id=IsStr(),
timestamp=IsDatetime(),
)
]
),
ModelResponse(
parts=[TextPart(content='{"always_fail":"I guess you never learn"}')],
usage=RequestUsage(input_tokens=77, output_tokens=14),
model_name='test',
timestamp=IsDatetime(),
),
]
)