2
2
3
3
import argparse
4
4
import asyncio
5
+ import importlib
5
6
import sys
6
7
from asyncio import CancelledError
7
8
from collections .abc import Sequence
12
13
13
14
from typing_inspection .introspection import get_literal_values
14
15
16
+ from pydantic_ai .result import OutputDataT
17
+ from pydantic_ai .tools import AgentDepsT
18
+
15
19
from . import __version__
16
20
from .agent import Agent
17
21
from .exceptions import UserError
@@ -123,6 +127,11 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
123
127
# e.g. we want to show `openai:gpt-4o` but not `gpt-4o`
124
128
qualified_model_names = [n for n in get_literal_values (KnownModelName .__value__ ) if ':' in n ]
125
129
arg .completer = argcomplete .ChoicesCompleter (qualified_model_names ) # type: ignore[reportPrivateUsage]
130
+ parser .add_argument (
131
+ '-a' ,
132
+ '--agent' ,
133
+ help = 'Custom Agent to use, in format "module:variable", e.g. "mymodule.submodule:my_agent"' ,
134
+ )
126
135
parser .add_argument (
127
136
'-l' ,
128
137
'--list-models' ,
@@ -155,8 +164,22 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
155
164
console .print (f' { model } ' , highlight = False )
156
165
return 0
157
166
167
+ agent : Agent [None , str ] = cli_agent
168
+ if args .agent :
169
+ try :
170
+ module_path , variable_name = args .agent .split (':' )
171
+ module = importlib .import_module (module_path )
172
+ agent = getattr (module , variable_name )
173
+ if not isinstance (agent , Agent ):
174
+ console .print (f'[red]Error: { args .agent } is not an Agent instance[/red]' )
175
+ return 1
176
+ console .print (f'[green]Using custom agent:[/green] [magenta]{ args .agent } [/magenta]' , highlight = False )
177
+ except ValueError :
178
+ console .print ('[red]Error: Agent must be specified in "module:variable" format[/red]' )
179
+ return 1
180
+
158
181
try :
159
- cli_agent .model = infer_model (args .model )
182
+ agent .model = infer_model (args .model )
160
183
except UserError as e :
161
184
console .print (f'Error initializing [magenta]{ args .model } [/magenta]:\n [red]{ e } [/red]' )
162
185
return 1
@@ -171,21 +194,27 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in
171
194
172
195
if prompt := cast (str , args .prompt ):
173
196
try :
174
- asyncio .run (ask_agent (cli_agent , prompt , stream , console , code_theme ))
197
+ asyncio .run (ask_agent (agent , prompt , stream , console , code_theme ))
175
198
except KeyboardInterrupt :
176
199
pass
177
200
return 0
178
201
179
202
# doing this instead of `PromptSession[Any](history=` allows mocking of PromptSession in tests
180
203
session : PromptSession [Any ] = PromptSession (history = FileHistory (str (PROMPT_HISTORY_PATH )))
181
204
try :
182
- return asyncio .run (run_chat (session , stream , cli_agent , console , code_theme , prog_name ))
205
+ return asyncio .run (run_chat (session , stream , agent , console , code_theme , prog_name ))
183
206
except KeyboardInterrupt : # pragma: no cover
184
207
return 0
185
208
186
209
187
210
async def run_chat (
188
- session : PromptSession [Any ], stream : bool , agent : Agent , console : Console , code_theme : str , prog_name : str
211
+ session : PromptSession [Any ],
212
+ stream : bool ,
213
+ agent : Agent [AgentDepsT , OutputDataT ],
214
+ console : Console ,
215
+ code_theme : str ,
216
+ prog_name : str ,
217
+ deps : AgentDepsT = None ,
189
218
) -> int :
190
219
multiline = False
191
220
messages : list [ModelMessage ] = []
@@ -207,30 +236,31 @@ async def run_chat(
207
236
return exit_value
208
237
else :
209
238
try :
210
- messages = await ask_agent (agent , text , stream , console , code_theme , messages )
239
+ messages = await ask_agent (agent , text , stream , console , code_theme , deps , messages )
211
240
except CancelledError : # pragma: no cover
212
241
console .print ('[dim]Interrupted[/dim]' )
213
242
214
243
215
244
async def ask_agent (
216
- agent : Agent ,
245
+ agent : Agent [ AgentDepsT , OutputDataT ] ,
217
246
prompt : str ,
218
247
stream : bool ,
219
248
console : Console ,
220
249
code_theme : str ,
250
+ deps : AgentDepsT = None ,
221
251
messages : list [ModelMessage ] | None = None ,
222
252
) -> list [ModelMessage ]:
223
253
status = Status ('[dim]Working on it…[/dim]' , console = console )
224
254
225
255
if not stream :
226
256
with status :
227
- result = await agent .run (prompt , message_history = messages )
228
- content = result .output
257
+ result = await agent .run (prompt , message_history = messages , deps = deps )
258
+ content = str ( result .output )
229
259
console .print (Markdown (content , code_theme = code_theme ))
230
260
return result .all_messages ()
231
261
232
262
with status , ExitStack () as stack :
233
- async with agent .iter (prompt , message_history = messages ) as agent_run :
263
+ async with agent .iter (prompt , message_history = messages , deps = deps ) as agent_run :
234
264
live = Live ('' , refresh_per_second = 15 , console = console , vertical_overflow = 'ellipsis' )
235
265
async for node in agent_run :
236
266
if Agent .is_model_request_node (node ):
0 commit comments