Skip to content

Commit 1e9b868

Browse files
authored
Chat's .append_message_stream() now returns an ExtendedTask (#1846)
1 parent 33bf25d commit 1e9b868

File tree

5 files changed

+149
-3
lines changed

5 files changed

+149
-3
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818

1919
* Added a new `.add_sass_layer_file()` method to `ui.Theme` that supports reading a Sass file with layer boundary comments, e.g. `/*-- scss:defaults --*/`. This format [is supported by Quarto](https://quarto.org/docs/output-formats/html-themes-more.html#bootstrap-bootswatch-layering) and makes it easier to store Sass rules and declarations that need to be woven into Shiny's Sass Bootstrap files. (#1790)
2020

21-
* The `ui.Chat()` component's `.on_user_submit()` decorator method now passes the user input to the decorated function. This makes it a bit more obvious how to access the user input inside the decorated function. See the new templates (mentioned below) for examples. (#1801)
21+
* The `ui.Chat()` component gains the following:
22+
* The `.on_user_submit()` decorator method now passes the user input to the decorated function. This makes it a bit easier to access the user input. See the new templates (mentioned below) for examples. (#1801)
23+
* A new `get_latest_stream_result()` method was added for an easy way to access the final result of the stream when it completes. (#1846)
24+
* The `.append_message_stream()` method now returns the `reactive.extended_task` instance that it launches. (#1846)
2225

2326
* `shiny create` includes new and improved `ui.Chat()` template options. Most of these templates leverage the new [`{chatlas}` package](https://posit-dev.github.io/chatlas/), our opinionated approach to interfacing with various LLM. (#1806)
2427

shiny/ui/_chat.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,10 @@ def __init__(
210210
reactive.Value(None)
211211
)
212212

213+
self._latest_stream: reactive.Value[
214+
reactive.ExtendedTask[[], str] | None
215+
] = reactive.Value(None)
216+
213217
# TODO: deprecate messages once we start promoting managing LLM message
214218
# state through other means
215219
@reactive.effect
@@ -607,17 +611,26 @@ async def append_message_stream(self, message: Iterable[Any] | AsyncIterable[Any
607611
Use this method (over `.append_message()`) when `stream=True` (or similar) is
608612
specified in model's completion method.
609613
```
614+
615+
Returns
616+
-------
617+
:
618+
An extended task that represents the streaming task. The `.result()` method
619+
of the task can be called in a reactive context to get the final state of the
620+
stream.
610621
"""
611622

612623
message = _utils.wrap_async_iterable(message)
613624

614625
# Run the stream in the background to get non-blocking behavior
615626
@reactive.extended_task
616627
async def _stream_task():
617-
await self._append_message_stream(message)
628+
return await self._append_message_stream(message)
618629

619630
_stream_task()
620631

632+
self._latest_stream.set(_stream_task)
633+
621634
# Since the task runs in the background (outside/beyond the current context,
622635
# if any), we need to manually raise any exceptions that occur
623636
@reactive.effect
@@ -627,6 +640,35 @@ async def _handle_error():
627640
await self._raise_exception(e)
628641
_handle_error.destroy() # type: ignore
629642

643+
return _stream_task
644+
645+
def get_latest_stream_result(self) -> str | None:
646+
"""
647+
Reactively read the latest message stream result.
648+
649+
This method reads a reactive value containing the result of the latest
650+
`.append_message_stream()`. Therefore, this method must be called in a reactive
651+
context (e.g., a render function, a :func:`~shiny.reactive.calc`, or a
652+
:func:`~shiny.reactive.effect`).
653+
654+
Returns
655+
-------
656+
:
657+
The result of the latest stream (a string).
658+
659+
Raises
660+
------
661+
:
662+
A silent exception if no stream has completed yet.
663+
"""
664+
stream = self._latest_stream()
665+
if stream is None:
666+
from .. import req
667+
668+
req(False)
669+
else:
670+
return stream.result()
671+
630672
async def _append_message_stream(self, message: AsyncIterable[Any]):
631673
id = _utils.private_random_id()
632674

@@ -636,6 +678,7 @@ async def _append_message_stream(self, message: AsyncIterable[Any]):
636678
try:
637679
async for msg in message:
638680
await self._append_message(msg, chunk=True, stream_id=id)
681+
return self._current_stream_message
639682
finally:
640683
await self._append_message(empty, chunk="end", stream_id=id)
641684
await self._flush_pending_messages()

shiny/ui/_markdown_stream.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .._docstring import add_example
88
from .._namespaces import resolve_id
99
from .._typing_extensions import TypedDict
10-
from ..session import require_active_session
10+
from ..session import require_active_session, session_context
1111
from ..types import NotifyException
1212
from ..ui.css import CssUnit, as_css_unit
1313
from . import Tag
@@ -86,6 +86,11 @@ def __init__(
8686

8787
self.on_error = on_error
8888

89+
with session_context(self._session):
90+
self._latest_stream: reactive.Value[
91+
Union[reactive.ExtendedTask[[], str], None]
92+
] = reactive.Value(None)
93+
8994
async def stream(
9095
self,
9196
content: Union[Iterable[str], AsyncIterable[str]],
@@ -109,6 +114,13 @@ async def stream(
109114
----
110115
If you already have the content available as a string, you can do
111116
`.stream([content])` to set the content.
117+
118+
Returns
119+
-------
120+
:
121+
An extended task that represents the streaming task. The `.result()` method
122+
of the task can be called in a reactive context to get the final state of the
123+
stream.
112124
"""
113125

114126
content = _utils.wrap_async_iterable(content)
@@ -139,6 +151,33 @@ async def _handle_error():
139151

140152
return _task
141153

154+
def get_latest_stream_result(self) -> Union[str, None]:
155+
"""
156+
Reactively read the latest stream result.
157+
158+
This method reads a reactive value containing the result of the latest
159+
`.stream()`. Therefore, this method must be called in a reactive context (e.g.,
160+
a render function, a :func:`~shiny.reactive.calc`, or a
161+
:func:`~shiny.reactive.effect`).
162+
163+
Returns
164+
-------
165+
:
166+
The result of the latest stream (a string).
167+
168+
Raises
169+
------
170+
:
171+
A silent exception if no stream has completed yet.
172+
"""
173+
stream = self._latest_stream()
174+
if stream is None:
175+
from .. import req
176+
177+
req(False)
178+
else:
179+
return stream.result()
180+
142181
async def clear(self):
143182
"""
144183
Empty the UI element of the `MarkdownStream`.
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import asyncio
2+
3+
from shiny.express import render, ui
4+
5+
chat = ui.Chat("chat")
6+
7+
chat.ui()
8+
chat.update_user_input(value="Press Enter to start the stream")
9+
10+
11+
async def stream_generator():
12+
for i in range(10):
13+
await asyncio.sleep(0.25)
14+
yield f"Message {i} \n\n"
15+
16+
17+
@chat.on_user_submit
18+
async def _(message: str):
19+
await chat.append_message_stream(stream_generator())
20+
21+
22+
@render.code
23+
async def stream_result_ui():
24+
return chat.get_latest_stream_result()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import re
2+
3+
from playwright.sync_api import Page, expect
4+
from utils.deploy_utils import skip_on_webkit
5+
6+
from shiny.playwright import controller
7+
from shiny.run import ShinyAppProc
8+
9+
10+
@skip_on_webkit
11+
def test_validate_chat_stream_result(page: Page, local_app: ShinyAppProc) -> None:
12+
page.goto(local_app.url)
13+
14+
chat = controller.Chat(page, "chat")
15+
stream_result_ui = controller.OutputCode(page, "stream_result_ui")
16+
17+
expect(chat.loc).to_be_visible(timeout=10 * 1000)
18+
19+
chat.send_user_input()
20+
21+
messages = [
22+
"Message 0",
23+
"Message 1",
24+
"Message 2",
25+
"Message 3",
26+
"Message 4",
27+
"Message 5",
28+
"Message 6",
29+
"Message 7",
30+
"Message 8",
31+
"Message 9",
32+
]
33+
# Allow for any whitespace between messages
34+
chat.expect_messages(re.compile(r"\s*".join(messages)), timeout=30 * 1000)
35+
36+
# Verify that the stream result is as expected
37+
stream_result_ui.expect.to_contain_text("Message 9")

0 commit comments

Comments
 (0)