Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions shiny/ui/_chat_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ def can_normalize_chunk(self, chunk: Any) -> bool:

class DictNormalizer(BaseMessageNormalizer):
def normalize(self, message: Any) -> ChatMessage:
x = self._check_dict(message)
x = cast("dict[str, Any]", message)
if "content" not in x:
raise ValueError("Message must have 'content' key")
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

def normalize_chunk(self, chunk: Any) -> ChatMessage:
x = self._check_dict(chunk)
x = cast("dict[str, Any]", chunk)
if "content" not in x:
raise ValueError("Message must have 'content' key")
return ChatMessage(content=x["content"], role=x.get("role", "assistant"))

def can_normalize(self, message: Any) -> bool:
Expand All @@ -70,14 +74,6 @@ def can_normalize(self, message: Any) -> bool:
def can_normalize_chunk(self, chunk: Any) -> bool:
return isinstance(chunk, dict)

@staticmethod
def _check_dict(x: Any) -> "dict[str, Any]":
if "content" not in x:
raise ValueError("Message must have 'content' key")
if "role" in x and x["role"] not in ["assistant", "system"]:
raise ValueError("Role must be 'assistant' or 'system")
return x


class LangChainNormalizer(BaseMessageNormalizer):
def normalize(self, message: Any) -> ChatMessage:
Expand Down
18 changes: 18 additions & 0 deletions tests/playwright/shiny/components/chat/append_user_msg/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from shiny import reactive
from shiny.express import render, ui

chat = ui.Chat(id="chat")
chat.ui()


@reactive.effect
async def _():
await chat.append_message({"content": "A user message", "role": "user"})


"chat.messages():"


@render.code
def message_state():
return str(chat.messages())
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from playwright.sync_api import Page, expect

from shiny.playwright import controller
from shiny.run import ShinyAppProc


def test_validate_chat_append_user_message(page: Page, local_app: ShinyAppProc) -> None:
page.goto(local_app.url)

chat = controller.Chat(page, "chat")

# Verify starting state
expect(chat.loc).to_be_visible()
chat.expect_latest_message("A user message")

# Verify that the message state is as expected
message_state = controller.OutputCode(page, "message_state")
message_state_expected = ({"content": "A user message", "role": "user"},)
message_state.expect_value(str(message_state_expected))