Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import inspect
import sys
import types
from collections.abc import Awaitable
from collections.abc import Awaitable, Iterable
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, Literal, Union, cast, get_args, get_origin

Expand Down Expand Up @@ -113,14 +113,24 @@ def _build_tool(a: Any, tool_name_: str, multiple: bool) -> ResultTool[ResultDat

return cls(tools=tools, allow_text_result=allow_text_result)

def find_named_tool(
self, parts: Iterable[_messages.ModelResponsePart], tool_name: str
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
"""Find a tool that matches one of the calls, with a specific name."""
for part in parts:
if isinstance(part, _messages.ToolCallPart):
if part.tool_name == tool_name:
return part, self.tools[tool_name]

def find_tool(
self, message: _messages.ModelResponse
self,
parts: Iterable[_messages.ModelResponsePart],
) -> tuple[_messages.ToolCallPart, ResultTool[ResultData]] | None:
"""Find a tool that matches one of the calls."""
for item in message.parts:
if isinstance(item, _messages.ToolCallPart):
if result := self.tools.get(item.tool_name):
return item, result
for part in parts:
if isinstance(part, _messages.ToolCallPart):
if result := self.tools.get(part.tool_name):
return part, result

def tool_names(self) -> list[str]:
"""Return the names of the tools."""
Expand Down
219 changes: 114 additions & 105 deletions pydantic_ai_slim/pydantic_ai/agent.py

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions pydantic_ai_slim/pydantic_ai/result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations as _annotations

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime
from typing import Generic, TypeVar, cast
Expand Down Expand Up @@ -122,7 +122,8 @@ class StreamedRunResult(_BaseRunResult[ResultData], Generic[AgentDeps, ResultDat
_result_schema: _result.ResultSchema[ResultData] | None
_deps: AgentDeps
_result_validators: list[_result.ResultValidator[AgentDeps, ResultData]]
_on_complete: Callable[[list[_messages.ModelMessage]], None]
_result_tool_name: str | None
_on_complete: Callable[[], Awaitable[None]]
is_complete: bool = field(default=False, init=False)
"""Whether the stream has all been received.

Expand Down Expand Up @@ -205,7 +206,7 @@ async def stream_text(self, *, delta: bool = False, debounce_by: float | None =
combined = await self._validate_text_result(''.join(chunks))
yield combined
lf_span.set_attribute('combined_text', combined)
self._marked_completed(_messages.ModelResponse.from_text(combined))
await self._marked_completed(_messages.ModelResponse.from_text(combined))

async def stream_structured(
self, *, debounce_by: float | None = 0.1
Expand Down Expand Up @@ -244,7 +245,7 @@ async def stream_structured(
msg = self._stream_response.get(final=True)
yield msg, True
lf_span.set_attribute('structured_response', msg)
self._marked_completed(msg)
await self._marked_completed(msg)

async def get_data(self) -> ResultData:
"""Stream the whole response, validate and return it."""
Expand All @@ -253,11 +254,11 @@ async def get_data(self) -> ResultData:
if isinstance(self._stream_response, models.StreamTextResponse):
text = ''.join(self._stream_response.get(final=True))
text = await self._validate_text_result(text)
self._marked_completed(_messages.ModelResponse.from_text(text))
await self._marked_completed(_messages.ModelResponse.from_text(text))
return cast(ResultData, text)
else:
message = self._stream_response.get(final=True)
self._marked_completed(message)
await self._marked_completed(message)
return await self.validate_structured_result(message)

@property
Expand All @@ -282,7 +283,8 @@ async def validate_structured_result(
) -> ResultData:
"""Validate a structured result message."""
assert self._result_schema is not None, 'Expected _result_schema to not be None'
match = self._result_schema.find_tool(message)
assert self._result_tool_name is not None, 'Expected _result_tool_name to not be None'
match = self._result_schema.find_named_tool(message.parts, self._result_tool_name)
if match is None:
raise exceptions.UnexpectedModelBehavior(
f'Invalid message, unable to find tool: {self._result_schema.tool_names()}'
Expand All @@ -306,7 +308,7 @@ async def _validate_text_result(self, text: str) -> str:
)
return text

def _marked_completed(self, message: _messages.ModelResponse) -> None:
async def _marked_completed(self, message: _messages.ModelResponse) -> None:
self.is_complete = True
self._all_messages.append(message)
self._on_complete(self._all_messages)
await self._on_complete()
14 changes: 7 additions & 7 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,13 +669,6 @@ async def bar(y: str) -> str:
ToolReturnPart(tool_name='bar', content='b', timestamp=IsNow(tz=timezone.utc)),
]
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
)
]
),
ModelResponse(
parts=[
ToolCallPart(
Expand All @@ -685,6 +678,13 @@ async def bar(y: str) -> str:
],
timestamp=IsNow(tz=timezone.utc),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
)
]
),
]
)
assert tool_calls == snapshot(["foo(x='a')", "bar(y='b')"])
Expand Down
16 changes: 9 additions & 7 deletions tests/models/test_groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,19 +415,21 @@ async def test_stream_structured(allow_model_requests: None):
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='', timestamp=IsNow(tz=timezone.utc))]),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result', content='Final result processed.', timestamp=IsNow(tz=timezone.utc)
)
]
),
ModelResponse(
parts=[
ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"first": "One", "second": "Two"}'))
],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
timestamp=IsNow(tz=timezone.utc),
)
]
),
]
)

Expand Down
12 changes: 6 additions & 6 deletions tests/models/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,12 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[
ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"won": true}'), tool_call_id='1')
],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
ModelRequest(
parts=[
ToolReturnPart(
Expand All @@ -1337,12 +1343,6 @@ async def get_location(loc_name: str) -> str:
)
]
),
ModelResponse(
parts=[
ToolCallPart(tool_name='final_result', args=ArgsJson(args_json='{"won": true}'), tool_call_id='1')
],
timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
),
]
)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,13 +1060,13 @@ def another_tool(y: int) -> int: # pragma: no cover
ModelRequest(
parts=[
ToolReturnPart(
tool_name='final_result',
content='Final result processed.',
tool_name='regular_tool',
content='Tool not executed - a final result was already processed.',
timestamp=IsNow(tz=timezone.utc),
),
ToolReturnPart(
tool_name='regular_tool',
content='Tool not executed - a final result was already processed.',
tool_name='final_result',
content='Final result processed.',
timestamp=IsNow(tz=timezone.utc),
),
ToolReturnPart(
Expand Down
Loading
Loading