-
Notifications
You must be signed in to change notification settings - Fork 4.4k
[MLA-1712] Make UnityEnvironment fail fast if the env crashes #4880
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4906a82
152c09e
d0c548c
61cc3f6
2cdd68b
cb001e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -177,7 +177,7 @@ def __init__( | |
# If true, this means the environment was successfully loaded | ||
self._loaded = False | ||
# The process that is started. If None, no process was started | ||
self._proc1 = None | ||
self._process: Optional[subprocess.Popen] = None | ||
self._timeout_wait: int = timeout_wait | ||
self._communicator = self._get_communicator(worker_id, base_port, timeout_wait) | ||
self._worker_id = worker_id | ||
|
@@ -194,7 +194,7 @@ def __init__( | |
) | ||
if file_name is not None: | ||
try: | ||
self._proc1 = env_utils.launch_executable( | ||
self._process = env_utils.launch_executable( | ||
file_name, self._executable_args() | ||
) | ||
except UnityEnvironmentException: | ||
|
@@ -249,7 +249,11 @@ def _executable_args(self) -> List[str]: | |
if self._no_graphics: | ||
args += ["-nographics", "-batchmode"] | ||
args += [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)] | ||
if self._log_folder: | ||
|
||
# If the logfile arg isn't already set in the env args, | ||
# try to set it to an output directory | ||
logfile_set = "-logfile" in (arg.lower() for arg in self._additional_args) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was annoying - if you wanted to pass e.g. |
||
if self._log_folder and not logfile_set: | ||
log_file_path = os.path.join( | ||
self._log_folder, f"Player-{self._worker_id}.log" | ||
) | ||
|
@@ -289,7 +293,9 @@ def _update_state(self, output: UnityRLOutputProto) -> None: | |
|
||
def reset(self) -> None: | ||
if self._loaded: | ||
outputs = self._communicator.exchange(self._generate_reset_input()) | ||
outputs = self._communicator.exchange( | ||
self._generate_reset_input(), self._poll_process | ||
) | ||
if outputs is None: | ||
raise UnityCommunicatorStoppedException("Communicator has exited.") | ||
self._update_behavior_specs(outputs) | ||
|
@@ -317,7 +323,7 @@ def step(self) -> None: | |
].action_spec.empty_action(n_agents) | ||
step_input = self._generate_step_input(self._env_actions) | ||
with hierarchical_timer("communicator.exchange"): | ||
outputs = self._communicator.exchange(step_input) | ||
outputs = self._communicator.exchange(step_input, self._poll_process) | ||
if outputs is None: | ||
raise UnityCommunicatorStoppedException("Communicator has exited.") | ||
self._update_behavior_specs(outputs) | ||
|
@@ -377,6 +383,18 @@ def get_steps( | |
self._assert_behavior_exists(behavior_name) | ||
return self._env_state[behavior_name] | ||
|
||
def _poll_process(self) -> None: | ||
""" | ||
Check the status of the subprocess. If it has exited, raise a UnityEnvironmentException | ||
:return: None | ||
""" | ||
if not self._process: | ||
return | ||
poll_res = self._process.poll() | ||
if poll_res is not None: | ||
exc_msg = self._returncode_to_env_message(self._process.returncode) | ||
raise UnityEnvironmentException(exc_msg) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Not sure if we want to have separate exception messages for returncode 0 and non-0. |
||
|
||
def close(self): | ||
""" | ||
Sends a shutdown signal to the unity environment, and closes the socket connection. | ||
|
@@ -397,19 +415,16 @@ def _close(self, timeout: Optional[int] = None) -> None: | |
timeout = self._timeout_wait | ||
self._loaded = False | ||
self._communicator.close() | ||
if self._proc1 is not None: | ||
if self._process is not None: | ||
# Wait a bit for the process to shutdown, but kill it if it takes too long | ||
try: | ||
self._proc1.wait(timeout=timeout) | ||
signal_name = self._returncode_to_signal_name(self._proc1.returncode) | ||
signal_name = f" ({signal_name})" if signal_name else "" | ||
return_info = f"Environment shut down with return code {self._proc1.returncode}{signal_name}." | ||
logger.info(return_info) | ||
self._process.wait(timeout=timeout) | ||
logger.info(self._returncode_to_env_message(self._process.returncode)) | ||
except subprocess.TimeoutExpired: | ||
logger.info("Environment timed out shutting down. Killing...") | ||
self._proc1.kill() | ||
self._process.kill() | ||
# Set to None so we don't try to close multiple times. | ||
self._proc1 = None | ||
self._process = None | ||
|
||
@timed | ||
def _generate_step_input( | ||
|
@@ -452,7 +467,7 @@ def _send_academy_parameters( | |
) -> UnityOutputProto: | ||
inputs = UnityInputProto() | ||
inputs.rl_initialization_input.CopyFrom(init_parameters) | ||
return self._communicator.initialize(inputs) | ||
return self._communicator.initialize(inputs, self._poll_process) | ||
|
||
@staticmethod | ||
def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto: | ||
|
@@ -473,3 +488,9 @@ def _returncode_to_signal_name(returncode: int) -> Optional[str]: | |
except Exception: | ||
# Should generally be a ValueError, but catch everything just in case. | ||
return None | ||
|
||
@staticmethod
def _returncode_to_env_message(returncode: int) -> str:
    """
    Build the human-readable message describing how the environment process exited.

    :param returncode: Return code of the environment process.
    :return: A sentence containing the return code, followed by the signal name in
        parentheses when _returncode_to_signal_name yields one.
    """
    name = UnityEnvironment._returncode_to_signal_name(returncode)
    suffix = ""
    if name:
        suffix = f" ({name})"
    return f"Environment shut down with return code {returncode}{suffix}."
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
import grpc | ||
from typing import Optional | ||
|
||
from multiprocessing import Pipe | ||
from sys import platform | ||
import socket | ||
from multiprocessing import Pipe | ||
import time | ||
from concurrent.futures import ThreadPoolExecutor | ||
|
||
from .communicator import Communicator | ||
from .communicator import Communicator, PollCallback | ||
from mlagents_envs.communicator_objects.unity_to_external_pb2_grpc import ( | ||
UnityToExternalProtoServicer, | ||
add_UnityToExternalProtoServicer_to_server, | ||
|
@@ -86,22 +87,38 @@ def check_port(self, port): | |
finally: | ||
s.close() | ||
|
||
def poll_for_timeout(self): | ||
def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None: | ||
""" | ||
Polls the GRPC parent connection for data, to be used before calling recv. This prevents | ||
us from hanging indefinitely in the case where the environment process has died or was not | ||
launched. | ||
""" | ||
if not self.unity_to_external.parent_conn.poll(self.timeout_wait): | ||
raise UnityTimeOutException( | ||
"The Unity environment took too long to respond. Make sure that :\n" | ||
"\t The environment does not need user interaction to launch\n" | ||
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n' | ||
"\t The environment and the Python interface have compatible versions." | ||
) | ||
|
||
def initialize(self, inputs: UnityInputProto) -> UnityOutputProto: | ||
self.poll_for_timeout() | ||
Additionally, a callback can be passed to periodically check the state of the environment. | ||
This is used to detect the case when the environment dies without cleaning up the connection, | ||
so that we can stop sooner and raise a more appropriate error. | ||
""" | ||
deadline = time.monotonic() + self.timeout_wait | ||
callback_timeout_wait = self.timeout_wait // 10 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could make the "10" configurable, but seemed like a good rule of thumb. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't really care if the timeout is configurable, but if you do make it configurable please make sure to check it is less than |
||
while time.monotonic() < deadline: | ||
if self.unity_to_external.parent_conn.poll(callback_timeout_wait): | ||
# Got an acknowledgment from the connection | ||
return | ||
if poll_callback: | ||
# Fire the callback - if it detects something wrong, it should raise an exception. | ||
poll_callback() | ||
|
||
# Got this far without reading any data from the connection, so it must be dead. | ||
raise UnityTimeOutException( | ||
"The Unity environment took too long to respond. Make sure that :\n" | ||
"\t The environment does not need user interaction to launch\n" | ||
'\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n' | ||
"\t The environment and the Python interface have compatible versions." | ||
) | ||
|
||
def initialize( | ||
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None | ||
) -> UnityOutputProto: | ||
self.poll_for_timeout(poll_callback) | ||
aca_param = self.unity_to_external.parent_conn.recv().unity_output | ||
message = UnityMessageProto() | ||
message.header.status = 200 | ||
|
@@ -110,12 +127,14 @@ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto: | |
self.unity_to_external.parent_conn.recv() | ||
return aca_param | ||
|
||
def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]: | ||
def exchange( | ||
self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None | ||
) -> Optional[UnityOutputProto]: | ||
message = UnityMessageProto() | ||
message.header.status = 200 | ||
message.unity_input.CopyFrom(inputs) | ||
self.unity_to_external.parent_conn.send(message) | ||
self.poll_for_timeout() | ||
self.poll_for_timeout(poll_callback) | ||
output = self.unity_to_external.parent_conn.recv() | ||
if output.header.status != 200: | ||
return None | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Automatically set during player build. Might as well commit it now so we don't have to keep undoing it.