diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 2edefc92..f7c400ca 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -164,6 +164,32 @@ def filter_agent_id(self, agent_id: list[tuple]): } """ +# Keyboard shortcut JavaScript - based on https://github.com/gradio-app/gradio/issues/6101 +shortcut_js = """ + +""" + def run_gradio(results_dir: Path): """ @@ -173,14 +199,12 @@ def run_gradio(results_dir: Path): global info info.results_dir = results_dir - with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo: + with gr.Blocks(theme=gr.themes.Soft(), css=css, head=shortcut_js) as demo: agent_id = gr.State(value=None) episode_id = gr.State(value=EpisodeId()) agent_task_id = gr.State(value=None) step_id = gr.State(value=None) - hidden_key_input = gr.Textbox(visible=False, elem_id="key_capture") - with gr.Accordion("Help", open=False): gr.Markdown( """\ @@ -302,6 +326,16 @@ def run_gradio(results_dir: Path): action_info = gr.Markdown(label="Action Info", elem_classes="my-markdown") state_error = gr.Markdown(label="Next Step Error", elem_classes="my-markdown") + with gr.Row(variant="panel", elem_classes=["items-center", "justify-center"]): + step_indicator = gr.Markdown("### Step 0/0", elem_classes=["text-center"]) + prev_btn = gr.Button( + "◀ Previous", size="md", scale=0, elem_id="prev_btn", elem_classes=["mx-auto"] + ) + next_btn = gr.Button( + "Next ▶", size="md", scale=0, elem_id="next_btn", elem_classes=["mx-auto"] + ) + gr.Markdown("(Shortcut: Ctrl/Cmd + ← →)", elem_classes=["text-center"]) + profiling_gr = gr.Image( label="Profiling", show_label=False, interactive=False, show_download_button=False ) @@ -511,31 +545,37 @@ def run_gradio(results_dir: Path): demo.load(fn=refresh_exp_dir_choices, inputs=exp_dir_choice, outputs=exp_dir_choice) - demo.load( - None, - None, - None, - js=""" - function() { - document.addEventListener('keydown', function(e) { - if ((e.key === 'ArrowLeft' || e.key === 'ArrowRight') && (e.metaKey || e.ctrlKey)) { - e.preventDefault(); - const hiddenInput = document.querySelector('#key_capture input, #key_capture textarea'); - if (hiddenInput) { - let event = e.key === 'ArrowLeft' ? 'Cmd+Left' : 'Cmd+Right'; - hiddenInput.value = event; - hiddenInput.dispatchEvent(new Event('input', {bubbles: true})); - } - } - }); - } - """, - ) - hidden_key_input.change( - handle_key_event, - inputs=[hidden_key_input, step_id], - outputs=[hidden_key_input, step_id], - ) + # Simple navigation button events + def navigate_prev(step_id: StepId): + global info + if step_id and step_id.step is not None and step_id.episode_id: + step = max(0, step_id.step - 1) + info.step = step + return StepId(episode_id=step_id.episode_id, step=step) + return step_id + + def navigate_next(step_id: StepId): + global info + if step_id and step_id.step is not None and step_id.episode_id and info.exp_result: + step = min(len(info.exp_result.steps_info) - 1, step_id.step + 1) + info.step = step + return StepId(episode_id=step_id.episode_id, step=step) + return step_id + + prev_btn.click(navigate_prev, inputs=[step_id], outputs=[step_id]) + next_btn.click(navigate_next, inputs=[step_id], outputs=[step_id]) + + # Update step indicator display + def format_step_indicator(step_id): + global info + if not step_id or not info.exp_result or not info.exp_result.steps_info: + return "### Step 0/0" + # 1-based for user, total steps is len-1 (last is terminal) + current = (step_id.step + 1) if step_id.step is not None else 0 + total = max(len(info.exp_result.steps_info) - 1, 0) + return f"### Step {current}/{total}" + + step_id.change(format_step_indicator, inputs=[step_id], outputs=[step_indicator]) demo.queue() @@ -546,25 +586,6 @@ def run_gradio(results_dir: Path): demo.launch(server_port=port, share=do_share) -def handle_key_event(key_event, step_id: StepId): - - if key_event: - global info - - # print(f"Key event: {key_event}") - step = step_id.step - if key_event.startswith("Cmd+Left"): - step = max(0, step - 1) - elif key_event.startswith("Cmd+Right"): - step = min(len(info.exp_result.steps_info) - 2, step + 1) - else: - return gr.update() - # print(f"Updating step to {step} from key event {key_event}") - info.step = step - step_id = StepId(episode_id=step_id.episode_id, step=step) - return ("", step_id) - - def tab_select(evt: gr.SelectData): global info info.active_tab = evt.value @@ -947,7 +968,7 @@ def get_episode_info(info: Info): info = f"""\ ### {env_args.task_name} (seed: {env_args.task_seed}) -### Step {info.step} / {len(steps_info) - 1} (Reward: {cum_reward:.1f}) +### (Reward: {cum_reward:.1f}) **Goal:**