pydantic_ai_slim/pydantic_ai/models/openai.py (58 changes: 41 additions & 17 deletions)
@@ -878,7 +878,9 @@ def _process_response(self, response: responses.Response) -> ModelResponse:
                     if isinstance(content, responses.ResponseOutputText):  # pragma: no branch
                         items.append(TextPart(content.text))
             elif isinstance(item, responses.ResponseFunctionToolCall):
-                items.append(ToolCallPart(item.name, item.arguments, tool_call_id=item.call_id))
+                items.append(
+                    ToolCallPart(item.name, item.arguments, tool_call_id=_combine_tool_call_ids(item.call_id, item.id))
+                )
 
         finish_reason: FinishReason | None = None
         provider_details: dict[str, Any] | None = None
@@ -1084,27 +1086,29 @@ async def _map_messages(  # noqa: C901
                 elif isinstance(part, UserPromptPart):
                     openai_messages.append(await self._map_user_prompt(part))
                 elif isinstance(part, ToolReturnPart):
-                    openai_messages.append(
-                        FunctionCallOutput(
-                            type='function_call_output',
-                            call_id=_guard_tool_call_id(t=part),
-                            output=part.model_response_str(),
-                        )
-                    )
+                    call_id = _guard_tool_call_id(t=part)
+                    call_id, _ = _split_combined_tool_call_id(call_id)
+                    item = FunctionCallOutput(
+                        type='function_call_output',
+                        call_id=call_id,
+                        output=part.model_response_str(),
+                    )
+                    openai_messages.append(item)
                 elif isinstance(part, RetryPromptPart):
                     # TODO(Marcelo): How do we test this conditional branch?
                     if part.tool_name is None:  # pragma: no cover
                         openai_messages.append(
                             Message(role='user', content=[{'type': 'input_text', 'text': part.model_response()}])
                         )
                     else:
-                        openai_messages.append(
-                            FunctionCallOutput(
-                                type='function_call_output',
-                                call_id=_guard_tool_call_id(t=part),
-                                output=part.model_response(),
-                            )
-                        )
+                        call_id = _guard_tool_call_id(t=part)
+                        call_id, _ = _split_combined_tool_call_id(call_id)
+                        item = FunctionCallOutput(
+                            type='function_call_output',
+                            call_id=call_id,
+                            output=part.model_response(),
+                        )
+                        openai_messages.append(item)
                 else:
                     assert_never(part)
             elif isinstance(message, ModelResponse):
@@ -1141,12 +1145,18 @@
 
     @staticmethod
     def _map_tool_call(t: ToolCallPart) -> responses.ResponseFunctionToolCallParam:
-        return responses.ResponseFunctionToolCallParam(
-            arguments=t.args_as_json_str(),
-            call_id=_guard_tool_call_id(t=t),
+        call_id = _guard_tool_call_id(t=t)
+        call_id, id = _split_combined_tool_call_id(call_id)
+
+        param = responses.ResponseFunctionToolCallParam(
             name=t.tool_name,
+            arguments=t.args_as_json_str(),
+            call_id=call_id,
             type='function_call',
         )
+        if id:  # pragma: no branch
+            param['id'] = id
+        return param
 
     def _map_json_schema(self, o: OutputObjectDefinition) -> responses.ResponseFormatTextJSONSchemaConfigParam:
         response_format_param: responses.ResponseFormatTextJSONSchemaConfigParam = {
@@ -1365,7 +1375,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                         vendor_part_id=chunk.item.id,
                         tool_name=chunk.item.name,
                         args=chunk.item.arguments,
-                        tool_call_id=chunk.item.call_id,
+                        tool_call_id=_combine_tool_call_ids(chunk.item.call_id, chunk.item.id),
                     )
                 elif isinstance(chunk.item, responses.ResponseReasoningItem):
                     pass
@@ -1506,3 +1516,17 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk | responses.R
         u.input_audio_tokens = response_usage.prompt_tokens_details.audio_tokens or 0
         u.cache_read_tokens = response_usage.prompt_tokens_details.cached_tokens or 0
     return u
+
+
+def _combine_tool_call_ids(call_id: str, id: str | None) -> str:
+    # When reasoning, the Responses API requires the `ResponseFunctionToolCall` to be returned with both the `call_id` and `id` fields.
+    # Our `ToolCallPart` has only the `call_id` field, so we combine the two fields into a single string.
+    return f'{call_id}|{id}' if id else call_id
+
+
+def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
+    if '|' in combined_id:
+        call_id, id = combined_id.split('|', 1)
+        return call_id, id
+    else:
+        return combined_id, None  # pragma: no cover
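
For reference, a minimal self-contained sketch of how the combined id round-trips through the two new helpers. The function bodies mirror the diff above; the example id values ('call_abc123', 'fc_def456') are invented for illustration:

    def _combine_tool_call_ids(call_id: str, id: str | None) -> str:
        # Pack the Responses API call_id and item id into the single id slot
        # that ToolCallPart carries.
        return f'{call_id}|{id}' if id else call_id

    def _split_combined_tool_call_id(combined_id: str) -> tuple[str, str | None]:
        # Invert _combine_tool_call_ids, passing plain ids through unchanged.
        if '|' in combined_id:
            call_id, id = combined_id.split('|', 1)
            return call_id, id
        return combined_id, None

    # Round trip when the model attached an item id (e.g. during reasoning):
    combined = _combine_tool_call_ids('call_abc123', 'fc_def456')
    assert combined == 'call_abc123|fc_def456'
    assert _split_combined_tool_call_id(combined) == ('call_abc123', 'fc_def456')

    # Without an item id, the call_id passes through untouched:
    assert _combine_tool_call_ids('call_abc123', None) == 'call_abc123'
    assert _split_combined_tool_call_id('call_abc123') == ('call_abc123', None)

The '|' delimiter assumes the provider's call_id never contains that character; split('|', 1) keeps any later '|' inside the id half.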