Skip to content

Commit ea6aeea

Browse files
authored
Chat.messages() now keeps system messages (#1509)
1 parent b5d8fa3 commit ea6aeea

File tree

2 files changed

+169
-21
lines changed

2 files changed

+169
-21
lines changed

shiny/ui/_chat.py

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,9 @@ def messages(
386386
A tuple of chat messages.
387387
"""
388388

389-
messages = self._get_trimmed_messages(token_limits=token_limits)
389+
messages = self._messages()
390+
if token_limits is not None:
391+
messages = self._trim_messages(messages, token_limits)
390392

391393
res: list[ChatMessage] = []
392394
for i, m in enumerate(messages):
@@ -741,34 +743,63 @@ def _store_message(
741743

742744
return msg
743745

744-
def _get_trimmed_messages(
745-
self,
746-
*,
747-
token_limits: tuple[int, int] | None = (4096, 1000),
746+
@staticmethod
def _trim_messages(
    messages: tuple[StoredMessage, ...],
    token_limits: tuple[int, int] = (4096, 1000),
) -> tuple[StoredMessage, ...]:
    """
    Drop the oldest non-system messages so the rest fit within a token budget.

    System messages are never dropped. Their token counts are subtracted from
    the budget first, and the newest non-system messages are then kept for as
    long as the remaining budget allows.

    Parameters
    ----------
    messages
        The stored messages, ordered oldest to newest.
    token_limits
        A ``(total, reserve)`` pair. The budget available to ``messages`` is
        ``total - reserve``.

    Returns
    -------
    :
        The retained messages in their original order. If any message has a
        ``token_count`` of ``None``, ``messages`` is returned unchanged.

    Raises
    ------
    ValueError
        If ``total <= reserve``, if the system messages alone leave no room in
        the budget, or if non-system messages exist but none of them fit.
    """
    n_total, n_reserve = token_limits
    if n_total <= n_reserve:
        raise ValueError(
            f"Invalid token limits: {token_limits}. The 1st value must be greater "
            "than the 2nd value."
        )

    # Since we don't trim system messages, 1st obtain their total token count
    # (so we can determine how many non-system messages can fit)
    n_system_tokens: int = 0
    n_system_messages: int = 0
    n_other_messages: int = 0
    for m in messages:
        count = m["token_count"]
        # Count can be None if the tokenizer is None; can't trim without counts
        if count is None:
            return messages
        if m["role"] == "system":
            n_system_tokens += count
            n_system_messages += 1
        else:
            n_other_messages += 1

    remaining_non_system_tokens = n_total - n_reserve - n_system_tokens

    if remaining_non_system_tokens <= 0:
        raise ValueError(
            f"System messages exceed `.messages(token_limits={token_limits})`. "
            "Consider increasing the 1st value of `token_limits` or setting it to "
            "`token_limits=None` to disable token limits."
        )

    # Walk newest -> oldest. The budget keeps being decremented after it goes
    # negative, so once one non-system message doesn't fit, every older
    # non-system message is dropped as well.
    messages2: list[StoredMessage] = []
    for m in reversed(messages):
        if m["role"] == "system":
            messages2.append(m)
            continue
        count = cast(int, m["token_count"])  # Already checked this
        remaining_non_system_tokens -= count
        if remaining_non_system_tokens >= 0:
            messages2.append(m)

    # Restore chronological (oldest -> newest) order
    messages2.reverse()

    if len(messages2) == n_system_messages and n_other_messages > 0:
        raise ValueError(
            f"Only system messages fit within `.messages(token_limits={token_limits})`. "
            "Consider increasing the 1st value of `token_limits` or setting it to "
            "`token_limits=None` to disable token limits."
        )

    return tuple(messages2)
773804

774805
def user_input(self, transform: bool = False) -> str | None:

tests/pytest/test_chat.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,127 @@
22

33
import sys
44
from datetime import datetime
5+
from typing import cast
56

7+
import pytest
8+
9+
from shiny import Session
10+
from shiny._namespaces import ResolvedId, Root
11+
from shiny.session import session_context
12+
from shiny.ui import Chat
13+
from shiny.ui._chat import as_transformed_message
614
from shiny.ui._chat_normalize import normalize_message, normalize_message_chunk
15+
from shiny.ui._chat_types import ChatMessage, StoredMessage
16+
17+
# ----------------------------------------------------------------------
18+
# Helpers
19+
# ----------------------------------------------------------------------
20+
21+
22+
class _MockSession:
    """Minimal stand-in for a session, used to construct a `Chat` in tests."""

    ns: ResolvedId = Root
    app: object = None
    id: str = "mock-session"

    def on_ended(self, callback: object) -> None:
        """Accept a callback and ignore it."""

    def _increment_busy_count(self) -> None:
        """Do nothing."""
32+
33+
34+
test_session = cast(Session, _MockSession())
35+
36+
37+
def as_stored_message(message: ChatMessage, token_count: int) -> StoredMessage:
    """Build a `StoredMessage` from `message` with the given `token_count`."""
    transformed = as_transformed_message(message)
    return StoredMessage(**transformed, token_count=token_count)
43+
744

8-
# TODO: Feed these messages into an actual Chat() instance?
45+
# ----------------------------------------------------------------------
46+
# Unit tests for Chat._trim_messages()
47+
# ----------------------------------------------------------------------
48+
49+
50+
def test_chat_message_trimming():
    """Check that trimming keeps system messages and drops the oldest others."""
    with session_context(test_session):
        chat = Chat(id="chat")

        # Fixtures that appear in more than one scenario
        system_100 = as_stored_message(
            {"content": "System message", "role": "system"}, token_count=100
        )
        user_10 = as_stored_message(
            {"content": "User message", "role": "user"}, token_count=10
        )
        user2_1 = as_stored_message(
            {"content": "User message 2", "role": "user"}, token_count=1
        )

        # Throws since system message is too long
        too_long = (
            as_stored_message(
                {"content": "System message", "role": "system"}, token_count=101
            ),
        )
        with pytest.raises(ValueError):
            chat._trim_messages(too_long, token_limits=(100, 0))

        system_and_user = (
            system_100,
            as_stored_message(
                {"content": "User message", "role": "user"}, token_count=1
            ),
        )

        # Throws since only the system message fits
        with pytest.raises(ValueError):
            chat._trim_messages(system_and_user, token_limits=(100, 0))

        # Raising the limit should allow both messages to fit
        kept = chat._trim_messages(system_and_user, token_limits=(102, 0))
        assert len(kept) == 2
        assert [m["content_server"] for m in kept] == [
            "System message",
            "User message",
        ]

        # Should discard the 1st user message
        kept = chat._trim_messages(
            (system_100, user_10, user2_1), token_limits=(102, 0)
        )
        assert len(kept) == 2
        assert [m["content_server"] for m in kept] == [
            "System message",
            "User message 2",
        ]

        interleaved = (
            as_stored_message(
                {"content": "System message", "role": "system"}, token_count=50
            ),
            user_10,
            as_stored_message(
                {"content": "System message 2", "role": "system"}, token_count=50
            ),
            user2_1,
        )

        # Should discard the 1st user message
        kept = chat._trim_messages(interleaved, token_limits=(102, 0))
        assert len(kept) == 3
        assert [m["content_server"] for m in kept] == [
            "System message",
            "System message 2",
            "User message 2",
        ]
121+
122+
123+
# ----------------------------------------------------------------------
124+
# Unit tests for normalize_message() and normalize_message_chunk()
125+
# ----------------------------------------------------------------------
9126

10127

11128
def test_string_normalization():

0 commit comments

Comments
 (0)