Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 69 additions & 48 deletions src/agentlab/analyze/agent_xray.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,32 @@ def filter_agent_id(self, agent_id: list[tuple]):
}
"""

# Keyboard shortcut JavaScript - based on https://github.com/gradio-app/gradio/issues/6101
shortcut_js = """
<script>
function shortcuts(e) {
var event = document.all ? window.event : e;
switch (e.target.tagName.toLowerCase()) {
case "input":
case "textarea":
case "select":
case "button":
return;
default:
if ((e.key === 'ArrowLeft' || e.key === 'ArrowRight') && (e.metaKey || e.ctrlKey)) {
e.preventDefault();
if (e.key === 'ArrowLeft') {
document.getElementById("prev_btn").click();
} else {
document.getElementById("next_btn").click();
}
}
}
}
document.addEventListener('keydown', shortcuts, false);
</script>
"""


def run_gradio(results_dir: Path):
"""
Expand All @@ -173,14 +199,12 @@ def run_gradio(results_dir: Path):
global info
info.results_dir = results_dir

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
with gr.Blocks(theme=gr.themes.Soft(), css=css, head=shortcut_js) as demo:
agent_id = gr.State(value=None)
episode_id = gr.State(value=EpisodeId())
agent_task_id = gr.State(value=None)
step_id = gr.State(value=None)

hidden_key_input = gr.Textbox(visible=False, elem_id="key_capture")

with gr.Accordion("Help", open=False):
gr.Markdown(
"""\
Expand Down Expand Up @@ -302,6 +326,16 @@ def run_gradio(results_dir: Path):
action_info = gr.Markdown(label="Action Info", elem_classes="my-markdown")
state_error = gr.Markdown(label="Next Step Error", elem_classes="my-markdown")

with gr.Row(variant="panel", elem_classes=["items-center", "justify-center"]):
step_indicator = gr.Markdown("### Step 0/0", elem_classes=["text-center"])
prev_btn = gr.Button(
"◀ Previous", size="md", scale=0, elem_id="prev_btn", elem_classes=["mx-auto"]
)
next_btn = gr.Button(
"Next ▶", size="md", scale=0, elem_id="next_btn", elem_classes=["mx-auto"]
)
gr.Markdown("(Shortcut: Ctrl/Cmd + ← →)", elem_classes=["text-center"])

profiling_gr = gr.Image(
label="Profiling", show_label=False, interactive=False, show_download_button=False
)
Expand Down Expand Up @@ -511,31 +545,37 @@ def run_gradio(results_dir: Path):

demo.load(fn=refresh_exp_dir_choices, inputs=exp_dir_choice, outputs=exp_dir_choice)

demo.load(
None,
None,
None,
js="""
function() {
document.addEventListener('keydown', function(e) {
if ((e.key === 'ArrowLeft' || e.key === 'ArrowRight') && (e.metaKey || e.ctrlKey)) {
e.preventDefault();
const hiddenInput = document.querySelector('#key_capture input, #key_capture textarea');
if (hiddenInput) {
let event = e.key === 'ArrowLeft' ? 'Cmd+Left' : 'Cmd+Right';
hiddenInput.value = event;
hiddenInput.dispatchEvent(new Event('input', {bubbles: true}));
}
}
});
}
""",
)
hidden_key_input.change(
handle_key_event,
inputs=[hidden_key_input, step_id],
outputs=[hidden_key_input, step_id],
)
# Simple navigation button events
def navigate_prev(step_id: StepId):
global info
if step_id and step_id.step is not None and step_id.episode_id:
step = max(0, step_id.step - 1)
info.step = step
return StepId(episode_id=step_id.episode_id, step=step)
return step_id

def navigate_next(step_id: StepId):
global info
if step_id and step_id.step is not None and step_id.episode_id and info.exp_result:
step = min(len(info.exp_result.steps_info) - 1, step_id.step + 1)
info.step = step
return StepId(episode_id=step_id.episode_id, step=step)
return step_id

prev_btn.click(navigate_prev, inputs=[step_id], outputs=[step_id])
next_btn.click(navigate_next, inputs=[step_id], outputs=[step_id])

# Update step indicator display
def format_step_indicator(step_id):
global info
if not step_id or not info.exp_result or not info.exp_result.steps_info:
return "### Step 0/0"
# 1-based for user, total steps is len-1 (last is terminal)
current = (step_id.step + 1) if step_id.step is not None else 0
total = max(len(info.exp_result.steps_info) - 1, 0)
return f"### Step {current}/{total}"

step_id.change(format_step_indicator, inputs=[step_id], outputs=[step_indicator])

demo.queue()

Expand All @@ -546,25 +586,6 @@ def run_gradio(results_dir: Path):
demo.launch(server_port=port, share=do_share)


def handle_key_event(key_event, step_id: StepId):

if key_event:
global info

# print(f"Key event: {key_event}")
step = step_id.step
if key_event.startswith("Cmd+Left"):
step = max(0, step - 1)
elif key_event.startswith("Cmd+Right"):
step = min(len(info.exp_result.steps_info) - 2, step + 1)
else:
return gr.update()
# print(f"Updating step to {step} from key event {key_event}")
info.step = step
step_id = StepId(episode_id=step_id.episode_id, step=step)
return ("", step_id)


def tab_select(evt: gr.SelectData):
global info
info.active_tab = evt.value
Expand Down Expand Up @@ -947,7 +968,7 @@ def get_episode_info(info: Info):

info = f"""\
### {env_args.task_name} (seed: {env_args.task_seed})
### Step {info.step} / {len(steps_info) - 1} (Reward: {cum_reward:.1f})
### (Reward: {cum_reward:.1f})

**Goal:**

Expand Down
Loading