|
1 | 1 | from collections.abc import AsyncIterator
|
| 2 | +from typing import Any, cast |
2 | 3 |
|
3 | 4 | import pytest
|
4 | 5 | from inline_snapshot import snapshot
|
5 | 6 |
|
6 | 7 | from pydantic_ai import Agent
|
7 | 8 | from pydantic_ai.messages import ModelMessage, ModelRequest, ModelRequestPart, ModelResponse, TextPart, UserPromptPart
|
8 | 9 | from pydantic_ai.models.function import AgentInfo, FunctionModel
|
| 10 | +from pydantic_ai.tools import RunContext |
9 | 11 | from pydantic_ai.usage import Usage
|
10 | 12 |
|
11 | 13 | from .conftest import IsDatetime
|
@@ -201,3 +203,101 @@ async def async_processor(messages: list[ModelMessage]) -> list[ModelMessage]:
|
201 | 203 | ModelRequest(parts=[UserPromptPart(content='Question 2', timestamp=IsDatetime())]),
|
202 | 204 | ]
|
203 | 205 | )
|
| 206 | + |
| 207 | + |
| 208 | +async def test_history_processor_with_context(function_model: FunctionModel, received_messages: list[ModelMessage]): |
| 209 | + """Test history processor that takes RunContext.""" |
| 210 | + |
| 211 | + def context_processor(ctx: RunContext[str], messages: list[ModelMessage]) -> list[ModelMessage]: |
| 212 | + # Access deps from context |
| 213 | + prefix = ctx.deps |
| 214 | + processed: list[ModelMessage] = [] |
| 215 | + for msg in messages: |
| 216 | + if isinstance(msg, ModelRequest): |
| 217 | + new_parts: list[ModelRequestPart] = [] |
| 218 | + for part in msg.parts: |
| 219 | + if isinstance(part, UserPromptPart): |
| 220 | + new_parts.append(UserPromptPart(content=f'{prefix}: {part.content}')) |
| 221 | + else: |
| 222 | + new_parts.append(part) # pragma: no cover |
| 223 | + processed.append(ModelRequest(parts=new_parts)) |
| 224 | + else: |
| 225 | + processed.append(msg) # pragma: no cover |
| 226 | + return processed |
| 227 | + |
| 228 | + agent = Agent(function_model, history_processors=[context_processor], deps_type=str) |
| 229 | + await agent.run('test', deps='PREFIX') |
| 230 | + |
| 231 | + # Verify the prefix was added |
| 232 | + assert len(received_messages) == 1 |
| 233 | + assert isinstance(received_messages[0], ModelRequest) |
| 234 | + user_part = received_messages[0].parts[0] |
| 235 | + assert isinstance(user_part, UserPromptPart) |
| 236 | + assert user_part.content == 'PREFIX: test' |
| 237 | + |
| 238 | + |
| 239 | +async def test_history_processor_with_context_async( |
| 240 | + function_model: FunctionModel, received_messages: list[ModelMessage] |
| 241 | +): |
| 242 | + """Test async history processor that takes RunContext.""" |
| 243 | + |
| 244 | + async def async_context_processor(ctx: RunContext[Any], messages: list[ModelMessage]) -> list[ModelMessage]: |
| 245 | + return messages[-1:] # Keep only the last message |
| 246 | + |
| 247 | + message_history = [ |
| 248 | + ModelRequest(parts=[UserPromptPart(content='Question 1')]), |
| 249 | + ModelResponse(parts=[TextPart(content='Answer 1')]), |
| 250 | + ModelRequest(parts=[UserPromptPart(content='Question 2')]), |
| 251 | + ModelResponse(parts=[TextPart(content='Answer 2')]), |
| 252 | + ] |
| 253 | + |
| 254 | + agent = Agent(function_model, history_processors=[async_context_processor]) |
| 255 | + await agent.run('Question 3', message_history=message_history) |
| 256 | + |
| 257 | + # Should have filtered to only recent messages |
| 258 | + assert len(received_messages) <= 2 # Last message from history + new message |
| 259 | + |
| 260 | + |
| 261 | +async def test_history_processor_mixed_signatures(function_model: FunctionModel, received_messages: list[ModelMessage]): |
| 262 | + """Test mixing processors with and without context.""" |
| 263 | + |
| 264 | + def simple_processor(messages: list[ModelMessage]) -> list[ModelMessage]: |
| 265 | + # Filter out responses |
| 266 | + return [msg for msg in messages if isinstance(msg, ModelRequest)] |
| 267 | + |
| 268 | + def context_processor(ctx: RunContext[Any], messages: list[ModelMessage]) -> list[ModelMessage]: |
| 269 | + # Add prefix based on deps |
| 270 | + prefix = getattr(ctx.deps, 'prefix', 'DEFAULT') |
| 271 | + processed: list[ModelMessage] = [] |
| 272 | + for msg in messages: |
| 273 | + if isinstance(msg, ModelRequest): |
| 274 | + new_parts: list[ModelRequestPart] = [] |
| 275 | + for part in msg.parts: |
| 276 | + if isinstance(part, UserPromptPart): |
| 277 | + new_parts.append(UserPromptPart(content=f'{prefix}: {part.content}')) |
| 278 | + else: |
| 279 | + new_parts.append(part) # pragma: no cover |
| 280 | + processed.append(ModelRequest(parts=new_parts)) |
| 281 | + else: |
| 282 | + processed.append(msg) # pragma: no cover |
| 283 | + return processed |
| 284 | + |
| 285 | + message_history = [ |
| 286 | + ModelRequest(parts=[UserPromptPart(content='Question 1')]), |
| 287 | + ModelResponse(parts=[TextPart(content='Answer 1')]), |
| 288 | + ] |
| 289 | + |
| 290 | + # Create deps with prefix attribute |
| 291 | + class Deps: |
| 292 | + prefix = 'TEST' |
| 293 | + |
| 294 | + agent = Agent(function_model, history_processors=[simple_processor, context_processor], deps_type=Deps) |
| 295 | + await agent.run('Question 2', message_history=message_history, deps=Deps()) |
| 296 | + |
| 297 | + # Should have filtered responses and added prefix |
| 298 | + assert len(received_messages) == 2 |
| 299 | + for msg in received_messages: |
| 300 | + assert isinstance(msg, ModelRequest) |
| 301 | + user_part = msg.parts[0] |
| 302 | + assert isinstance(user_part, UserPromptPart) |
| 303 | + assert cast(str, user_part.content).startswith('TEST: ') |
0 commit comments