|
2 | 2 |
|
3 | 3 | import sys
|
4 | 4 | from datetime import datetime
|
| 5 | +from typing import cast |
5 | 6 |
|
| 7 | +import pytest |
| 8 | + |
| 9 | +from shiny import Session |
| 10 | +from shiny._namespaces import ResolvedId, Root |
| 11 | +from shiny.session import session_context |
| 12 | +from shiny.ui import Chat |
| 13 | +from shiny.ui._chat import as_transformed_message |
6 | 14 | from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
|
| 15 | +from shiny.ui._chat_types import ChatMessage, StoredMessage |
| 16 | + |
| 17 | +# ---------------------------------------------------------------------- |
| 18 | +# Helpers |
| 19 | +# ---------------------------------------------------------------------- |
| 20 | + |
| 21 | + |
| 22 | +class _MockSession: |
| 23 | + ns: ResolvedId = Root |
| 24 | + app: object = None |
| 25 | + id: str = "mock-session" |
| 26 | + |
| 27 | + def on_ended(self, callback: object) -> None: |
| 28 | + pass |
| 29 | + |
| 30 | + def _increment_busy_count(self) -> None: |
| 31 | + pass |
| 32 | + |
| 33 | + |
| 34 | +test_session = cast(Session, _MockSession()) |
| 35 | + |
| 36 | + |
| 37 | +def as_stored_message(message: ChatMessage, token_count: int) -> StoredMessage: |
| 38 | + msg = as_transformed_message(message) |
| 39 | + return StoredMessage( |
| 40 | + **msg, |
| 41 | + token_count=token_count, |
| 42 | + ) |
| 43 | + |
7 | 44 |
|
8 |
| -# TODO: Feed these messages into an actual Chat() instance? |
| 45 | +# ---------------------------------------------------------------------- |
| 46 | +# Unit tests for Chat._get_trimmed_messages() |
| 47 | +# ---------------------------------------------------------------------- |
| 48 | + |
| 49 | + |
| 50 | +def test_chat_message_trimming(): |
| 51 | + with session_context(test_session): |
| 52 | + chat = Chat(id="chat") |
| 53 | + |
| 54 | + msgs = ( |
| 55 | + as_stored_message( |
| 56 | + {"content": "System message", "role": "system"}, token_count=101 |
| 57 | + ), |
| 58 | + ) |
| 59 | + |
| 60 | + # Throws since system message is too long |
| 61 | + with pytest.raises(ValueError): |
| 62 | + chat._trim_messages(msgs, token_limits=(100, 0)) |
| 63 | + |
| 64 | + msgs = ( |
| 65 | + as_stored_message( |
| 66 | + {"content": "System message", "role": "system"}, token_count=100 |
| 67 | + ), |
| 68 | + as_stored_message( |
| 69 | + {"content": "User message", "role": "user"}, token_count=1 |
| 70 | + ), |
| 71 | + ) |
| 72 | + |
| 73 | + # Throws since only the system message fits |
| 74 | + with pytest.raises(ValueError): |
| 75 | + chat._trim_messages(msgs, token_limits=(100, 0)) |
| 76 | + |
| 77 | + # Raising the limit should allow both messages to fit |
| 78 | + trimmed = chat._trim_messages(msgs, token_limits=(102, 0)) |
| 79 | + assert len(trimmed) == 2 |
| 80 | + contents = [msg["content_server"] for msg in trimmed] |
| 81 | + assert contents == ["System message", "User message"] |
| 82 | + |
| 83 | + msgs = ( |
| 84 | + as_stored_message( |
| 85 | + {"content": "System message", "role": "system"}, token_count=100 |
| 86 | + ), |
| 87 | + as_stored_message( |
| 88 | + {"content": "User message", "role": "user"}, token_count=10 |
| 89 | + ), |
| 90 | + as_stored_message( |
| 91 | + {"content": "User message 2", "role": "user"}, token_count=1 |
| 92 | + ), |
| 93 | + ) |
| 94 | + |
| 95 | + # Should discard the 1st user message |
| 96 | + trimmed = chat._trim_messages(msgs, token_limits=(102, 0)) |
| 97 | + assert len(trimmed) == 2 |
| 98 | + contents = [msg["content_server"] for msg in trimmed] |
| 99 | + assert contents == ["System message", "User message 2"] |
| 100 | + |
| 101 | + msgs = ( |
| 102 | + as_stored_message( |
| 103 | + {"content": "System message", "role": "system"}, token_count=50 |
| 104 | + ), |
| 105 | + as_stored_message( |
| 106 | + {"content": "User message", "role": "user"}, token_count=10 |
| 107 | + ), |
| 108 | + as_stored_message( |
| 109 | + {"content": "System message 2", "role": "system"}, token_count=50 |
| 110 | + ), |
| 111 | + as_stored_message( |
| 112 | + {"content": "User message 2", "role": "user"}, token_count=1 |
| 113 | + ), |
| 114 | + ) |
| 115 | + |
| 116 | + # Should discard the 1st user message |
| 117 | + trimmed = chat._trim_messages(msgs, token_limits=(102, 0)) |
| 118 | + assert len(trimmed) == 3 |
| 119 | + contents = [msg["content_server"] for msg in trimmed] |
| 120 | + assert contents == ["System message", "System message 2", "User message 2"] |
| 121 | + |
| 122 | + |
| 123 | +# ---------------------------------------------------------------------- |
| 124 | +# Unit tests for normalize_message() and normalize_message_chunk() |
| 125 | +# ---------------------------------------------------------------------- |
9 | 126 |
|
10 | 127 |
|
11 | 128 | def test_string_normalization():
|
|
0 commit comments