Skip to content

Commit 05c315c

Browse files
committed
Tool support, closes #20
1 parent 353781d commit 05c315c

File tree

4 files changed

+536
-10
lines changed

4 files changed

+536
-10
lines changed

llm_openai.py

Lines changed: 204 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import json
12
from enum import Enum
3+
import llm
24
from llm import (
35
AsyncKeyModel,
46
KeyModel,
@@ -174,6 +176,7 @@ def __init__(
174176
if reasoning:
175177
options.append(ReasoningOptions)
176178
self.Options = combine_options(*options)
179+
self.supports_tools = True
177180

178181
def __str__(self):
179182
return f"OpenAI: {self.model_id}"
@@ -219,9 +222,30 @@ def _build_messages(self, prompt, conversation):
219222
messages.append(
220223
{"role": "user", "content": prev_response.prompt.prompt}
221224
)
222-
messages.append(
223-
{"role": "assistant", "content": prev_response.text_or_raise()}
224-
)
225+
for tool_result in getattr(prev_response.prompt, "tool_results", []):
226+
if not tool_result.tool_call_id:
227+
continue
228+
messages.append(
229+
{
230+
"type": "function_call_output",
231+
"call_id": tool_result.tool_call_id,
232+
"output": tool_result.output,
233+
}
234+
)
235+
prev_text = prev_response.text_or_raise()
236+
if prev_text:
237+
messages.append({"role": "assistant", "content": prev_text})
238+
tool_calls = prev_response.tool_calls_or_raise()
239+
if tool_calls:
240+
for tool_call in tool_calls:
241+
messages.append(
242+
{
243+
"type": "function_call",
244+
"call_id": tool_call.tool_call_id,
245+
"name": tool_call.name,
246+
"arguments": json.dumps(tool_call.arguments),
247+
}
248+
)
225249
if prompt.system and prompt.system != current_system:
226250
messages.append({"role": "system", "content": prompt.system})
227251
if not prompt.attachments:
@@ -233,6 +257,16 @@ def _build_messages(self, prompt, conversation):
233257
for attachment in prompt.attachments:
234258
attachment_message.append(_attachment(attachment, image_detail))
235259
messages.append({"role": "user", "content": attachment_message})
260+
for tool_result in getattr(prompt, "tool_results", []):
261+
if not tool_result.tool_call_id:
262+
continue
263+
messages.append(
264+
{
265+
"type": "function_call_output",
266+
"call_id": tool_result.tool_call_id,
267+
"output": tool_result.output,
268+
}
269+
)
236270
return messages
237271

238272
def _build_kwargs(self, prompt, conversation):
@@ -248,6 +282,27 @@ def _build_kwargs(self, prompt, conversation):
248282
value = getattr(prompt.options, option, None)
249283
if value is not None:
250284
kwargs[option] = value
285+
286+
if prompt.tools:
287+
tool_defs = []
288+
for tool in prompt.tools:
289+
if not getattr(tool, "name", None):
290+
continue
291+
parameters = tool.input_schema or {
292+
"type": "object",
293+
"properties": {},
294+
}
295+
tool_defs.append(
296+
{
297+
"type": "function",
298+
"name": tool.name,
299+
"description": tool.description or None,
300+
"parameters": parameters,
301+
"strict": False,
302+
}
303+
)
304+
if tool_defs:
305+
kwargs["tools"] = tool_defs
251306
if self.supports_schema and prompt.schema:
252307
kwargs["text"] = {
253308
"format": {
@@ -258,17 +313,149 @@ def _build_kwargs(self, prompt, conversation):
258313
}
259314
return kwargs
260315

261-
def _handle_event(self, event, response):
262-
if event.type == "response.output_text.delta":
316+
def _update_tool_call_from_event(self, event, _tc_buf):
317+
"""
318+
Accumulate streaming tool-call args by tool_call_id.
319+
_tc_buf is a dict[id] -> {"name": str, "arguments": str}
320+
"""
321+
et = getattr(event, "type", None)
322+
# Python SDK surfaces rich objects; also support dict fallbacks:
323+
obj = getattr(event, "to_dict", None)
324+
if callable(obj):
325+
payload = event.to_dict()
326+
else:
327+
payload = getattr(event, "__dict__", {}) or {}
328+
329+
# The SDK emits specific typed events for tool-calls; normalize:
330+
# Expected shapes (SDK may differ slightly by version):
331+
# - response.tool_call.delta => { "id", "type":"function", "name"?, "arguments_delta" }
332+
# - response.tool_call.completed => { "id", "type":"function", "name", "arguments" }
333+
# Keep this resilient by checking common fields.
334+
item = payload.get("response", payload.get("data", payload))
335+
tool = item.get("tool_call") if isinstance(item, dict) else None
336+
if not tool and "tool_call" in payload:
337+
tool = payload["tool_call"]
338+
339+
# Some SDKs put fields at top-level for tool events:
340+
if (
341+
tool is None
342+
and ("tool_call_id" in payload or "id" in payload)
343+
and (
344+
"arguments_delta" in payload
345+
or "arguments" in payload
346+
or "name" in payload
347+
)
348+
):
349+
tool = payload
350+
351+
if not tool:
352+
return None
353+
354+
tool_id = tool.get("id") or tool.get("tool_call_id")
355+
if not tool_id:
356+
return None
357+
358+
entry = _tc_buf.setdefault(tool_id, {"name": None, "arguments": ""})
359+
360+
# Name may arrive early or only at completion:
361+
if tool.get("name"):
362+
entry["name"] = tool["name"]
363+
364+
# Streaming deltas:
365+
if "arguments_delta" in tool and tool["arguments_delta"]:
366+
entry["arguments"] += tool["arguments_delta"]
367+
368+
# Completion:
369+
if (
370+
"arguments" in tool
371+
and tool["arguments"]
372+
and not tool.get("arguments_delta")
373+
):
374+
entry["arguments"] = tool["arguments"]
375+
return tool_id # signal completion for this id
376+
377+
return None
378+
379+
def _finalize_streaming_tool_calls(self, response, _tc_buf):
380+
# Called when we know streaming has finished or when a tool_call completed event fires.
381+
for tool_id, data in list(_tc_buf.items()):
382+
if data.get("name") and data.get("arguments") is not None:
383+
self._add_tool_call(
384+
response,
385+
tool_id,
386+
data.get("name"),
387+
data.get("arguments"),
388+
)
389+
del _tc_buf[tool_id]
390+
391+
def _add_tool_call(self, response, tool_id, name, arguments):
    """Record one tool call on *response*.

    *arguments* is the raw JSON string from the API; when it is not
    valid JSON the raw string is kept so nothing is silently dropped.
    """
    raw = arguments or "{}"
    try:
        parsed = json.loads(raw)
    except Exception:
        parsed = arguments or ""
    tool_call = llm.ToolCall(
        tool_call_id=tool_id,
        name=name or "unknown_tool",
        arguments=parsed,
    )
    response.add_tool_call(tool_call)
403+
404+
def _add_tool_calls_from_output(self, response, output):
405+
if not output:
406+
return
407+
for item in output:
408+
if hasattr(item, "model_dump"):
409+
data = item.model_dump()
410+
elif isinstance(item, dict):
411+
data = item
412+
else:
413+
data = getattr(item, "__dict__", {}) or {}
414+
415+
itype = data.get("type")
416+
if itype not in {"tool_call", "function_call"}:
417+
continue
418+
419+
tool_id = (
420+
data.get("call_id")
421+
or data.get("id")
422+
or data.get("tool_call_id")
423+
or f"call_{len(output)}"
424+
)
425+
name = data.get("name") or "unknown_tool"
426+
arguments = data.get("arguments") or "{}"
427+
self._add_tool_call(response, tool_id, name, arguments)
428+
429+
def _handle_event(self, event, response, _tc_buf=None):
430+
et = getattr(event, "type", None)
431+
if et == "response.output_text.delta":
263432
return event.delta
264-
elif event.type == "response.completed":
433+
434+
# Accumulate tool-call pieces if provided
435+
if _tc_buf is not None and et and "tool_call" in et:
436+
completed_id = self._update_tool_call_from_event(event, _tc_buf)
437+
if completed_id:
438+
# finalize this single tool call immediately
439+
entry = _tc_buf.pop(completed_id)
440+
self._finalize_streaming_tool_calls(response, {completed_id: entry})
441+
442+
if et == "response.completed":
265443
response.response_json = event.response.model_dump()
266444
self.set_usage(response, event.response.usage)
445+
self._add_tool_calls_from_output(
446+
response, getattr(event.response, "output", None)
447+
)
448+
# finalize any remaining buffered tool-calls
449+
if _tc_buf:
450+
self._finalize_streaming_tool_calls(response, _tc_buf)
267451
return None
268452

269453
def _finish_non_streaming_response(self, response, client_response):
270454
response.response_json = client_response.model_dump()
271455
self.set_usage(response, client_response.usage)
456+
self._add_tool_calls_from_output(
457+
response, getattr(client_response, "output", None)
458+
)
272459

273460

274461
class ResponsesModel(_SharedResponses, KeyModel):
@@ -284,13 +471,17 @@ def execute(
284471
kwargs = self._build_kwargs(prompt, conversation)
285472
kwargs["stream"] = stream
286473
if stream:
474+
# Buffer for assembling tool-call deltas across events
475+
_tc_buf = {}
287476
for event in client.responses.create(**kwargs):
288-
delta = self._handle_event(event, response)
477+
delta = self._handle_event(event, response, _tc_buf)
289478
if delta is not None:
290479
yield delta
291480
else:
292481
client_response = client.responses.create(**kwargs)
293-
yield client_response.output_text
482+
text = getattr(client_response, "output_text", None)
483+
if text:
484+
yield text
294485
self._finish_non_streaming_response(response, client_response)
295486

296487

@@ -307,13 +498,16 @@ async def execute(
307498
kwargs = self._build_kwargs(prompt, conversation)
308499
kwargs["stream"] = stream
309500
if stream:
501+
_tc_buf = {}
310502
async for event in await client.responses.create(**kwargs):
311-
delta = self._handle_event(event, response)
503+
delta = self._handle_event(event, response, _tc_buf)
312504
if delta is not None:
313505
yield delta
314506
else:
315507
client_response = await client.responses.create(**kwargs)
316-
yield client_response.output_text
508+
text = getattr(client_response, "output_text", None)
509+
if text:
510+
yield text
317511
self._finish_non_streaming_response(response, client_response)
318512

319513

tests/__snapshots__/test_openai.ambr

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,11 @@
2424
# name: test_options[options4]
2525
'Hi there! How can I assist you today?'
2626
# ---
27+
# name: test_tools
28+
'''
29+
I called simple_tool with 5; it returned:
30+
"This is a simple tool, 5"
31+
32+
Anything else you’d like me to run or change?
33+
'''
34+
# ---

0 commit comments

Comments
 (0)