From 4906a82e3621ed053d9849d42d6049927285b73c Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 21 Jan 2021 17:52:01 -0800 Subject: [PATCH 1/6] Warn if null graphics, dont render in GridWorld if null graphics --- .../Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs | 3 ++- com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs | 5 +++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs index 1d742d2fe6..edb70fee9e 100644 --- a/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs +++ b/Project/Assets/ML-Agents/Examples/GridWorld/Scripts/GridAgent.cs @@ -3,6 +3,7 @@ using System.Linq; using Unity.MLAgents; using Unity.MLAgents.Actuators; +using UnityEngine.Rendering; using UnityEngine.Serialization; public class GridAgent : Agent @@ -150,7 +151,7 @@ public void FixedUpdate() void WaitTimeInference() { - if (renderCamera != null) + if (renderCamera != null && SystemInfo.graphicsDeviceType != GraphicsDeviceType.Null) { renderCamera.Render(); } diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index c135796cc4..e43ad0a58e 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -1,4 +1,5 @@ using UnityEngine; +using UnityEngine.Rendering; namespace Unity.MLAgents.Sensors { @@ -145,6 +146,10 @@ public static Texture2D ObservationToTexture(Camera obsCamera, int width, int he RenderTexture.active = tempRt; obsCamera.targetTexture = tempRt; + if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null) + { + Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render."); + } obsCamera.Render(); texture2D.ReadPixels(new Rect(0, 0, texture2D.width, texture2D.height), 0, 0); From 152c09e7553f4983aa69d0deba4459f23936bd40 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Thu, 21 Jan 2021 17:52:53 -0800 Subject: [PATCH 2/6] log cmdline args, don't overwrite -logfile arg --- ml-agents-envs/mlagents_envs/env_utils.py | 8 ++++++-- ml-agents-envs/mlagents_envs/environment.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/ml-agents-envs/mlagents_envs/env_utils.py b/ml-agents-envs/mlagents_envs/env_utils.py index 17f0c822c7..f2fe48be8b 100644 --- a/ml-agents-envs/mlagents_envs/env_utils.py +++ b/ml-agents-envs/mlagents_envs/env_utils.py @@ -7,6 +7,9 @@ from mlagents_envs.exception import UnityEnvironmentException +logger = get_logger(__name__) + + def get_platform(): """ returns the platform of the operating system : linux, darwin or win32 @@ -27,7 +30,7 @@ def validate_environment_path(env_path: str) -> Optional[str]: .replace(".x86", "") ) true_filename = os.path.basename(os.path.normpath(env_path)) - get_logger(__name__).debug(f"The true file name is {true_filename}") + logger.debug(f"The true file name is {true_filename}") if not (glob.glob(env_path) or glob.glob(env_path + ".*")): return None @@ -99,7 +102,8 @@ def launch_executable(file_name: str, args: List[str]) -> subprocess.Popen: f"Couldn't launch the {file_name} environment. Provided filename does not match any environments." ) else: - get_logger(__name__).debug(f"This is the launch string {launch_string}") + logger.debug(f"The launch string is {launch_string}") + logger.debug(f"Running with args {args}") # Launch Unity environment subprocess_args = [launch_string] + args try: diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 9a0f1bd059..3f3c2e10cf 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -177,7 +177,7 @@ def __init__( # If true, this means the environment was successfully loaded self._loaded = False # The process that is started. If None, no process was started - self._proc1 = None + self._proc1: Optional[subprocess.Popen] = None self._timeout_wait: int = timeout_wait self._communicator = self._get_communicator(worker_id, base_port, timeout_wait) self._worker_id = worker_id @@ -249,7 +249,11 @@ def _executable_args(self) -> List[str]: if self._no_graphics: args += ["-nographics", "-batchmode"] args += [UnityEnvironment._PORT_COMMAND_LINE_ARG, str(self._port)] - if self._log_folder: + + # If the logfile arg isn't already set in the env args, + # try to set it to an output directory + logfile_set = "-logfile" in (arg.lower() for arg in self._additional_args) + if self._log_folder and not logfile_set: log_file_path = os.path.join( self._log_folder, f"Player-{self._worker_id}.log" ) From d0c548ca42d4206d7095add73668a01e83e9d397 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 22 Jan 2021 10:49:45 -0800 Subject: [PATCH 3/6] make sure process is still alive while waiting for connection timeout --- .../UnityConnectSettings.asset | 2 +- .../Runtime/Sensors/CameraSensor.cs | 9 ++-- ml-agents-envs/mlagents_envs/communicator.py | 17 +++++-- ml-agents-envs/mlagents_envs/environment.py | 31 +++++++++--- .../mlagents_envs/mock_communicator.py | 12 +++-- .../mlagents_envs/rpc_communicator.py | 47 +++++++++++++------ .../trainers/subprocess_env_manager.py | 2 +- 7 files changed, 87 insertions(+), 33 deletions(-) diff --git a/Project/ProjectSettings/UnityConnectSettings.asset b/Project/ProjectSettings/UnityConnectSettings.asset index c3ae9a0208..fa0b146579 100644 --- a/Project/ProjectSettings/UnityConnectSettings.asset +++ b/Project/ProjectSettings/UnityConnectSettings.asset @@ -4,7 +4,7 @@ UnityConnectSettings: m_ObjectHideFlags: 0 serializedVersion: 1 - m_Enabled: 1 + m_Enabled: 0 m_TestMode: 0 m_EventOldUrl: https://api.uca.cloud.unity3d.com/v1/events m_EventUrl: https://cdp.cloud.unity3d.com/v1/events diff --git a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs index e43ad0a58e..ab7fb5e786 100644 --- a/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs +++ b/com.unity.ml-agents/Runtime/Sensors/CameraSensor.cs @@ -129,6 +129,11 @@ public SensorCompressionType GetCompressionType() /// Texture2D to render to. public static Texture2D ObservationToTexture(Camera obsCamera, int width, int height) { + if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null) + { + Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render."); + } + var texture2D = new Texture2D(width, height, TextureFormat.RGB24, false); var oldRec = obsCamera.rect; obsCamera.rect = new Rect(0f, 0f, 1f, 1f); @@ -146,10 +151,6 @@ public static Texture2D ObservationToTexture(Camera obsCamera, int width, int he RenderTexture.active = tempRt; obsCamera.targetTexture = tempRt; - if (SystemInfo.graphicsDeviceType == GraphicsDeviceType.Null) - { - Debug.LogError("GraphicsDeviceType is Null. This will likely crash when trying to render."); - } obsCamera.Render(); texture2D.ReadPixels(new Rect(0, 0, texture2D.width, texture2D.height), 0, 0); diff --git a/ml-agents-envs/mlagents_envs/communicator.py b/ml-agents-envs/mlagents_envs/communicator.py index af4a29a781..2223f34d3a 100644 --- a/ml-agents-envs/mlagents_envs/communicator.py +++ b/ml-agents-envs/mlagents_envs/communicator.py @@ -1,8 +1,13 @@ -from typing import Optional +from typing import Callable, Optional from mlagents_envs.communicator_objects.unity_output_pb2 import UnityOutputProto from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto +# Function to call while waiting for a connection timeout. +# This should raise an exception if it needs to break from waiting for the timeout. +PollCallback = Callable[[], None] + + class Communicator: def __init__(self, worker_id=0, base_port=5005): """ @@ -12,17 +17,23 @@ def __init__(self, worker_id=0, base_port=5005): :int base_port: Baseline port number to connect to Unity environment over. worker_id increments over this. """ - def initialize(self, inputs: UnityInputProto) -> UnityOutputProto: + def initialize( + self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None + ) -> UnityOutputProto: """ Used to exchange initialization parameters between Python and the Environment :param inputs: The initialization input that will be sent to the environment. + :param poll_callback: Optional callback to be used while polling the connection. :return: UnityOutput: The initialization output sent by Unity """ - def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]: + def exchange( + self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None + ) -> Optional[UnityOutputProto]: """ Used to send an input and receive an output from the Environment :param inputs: The UnityInput that needs to be sent the Environment + :param poll_callback: Optional callback to be used while polling the connection. :return: The UnityOutputs generated by the Environment """ diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 3f3c2e10cf..785c8988fc 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -293,7 +293,9 @@ def _update_state(self, output: UnityRLOutputProto) -> None: def reset(self) -> None: if self._loaded: - outputs = self._communicator.exchange(self._generate_reset_input()) + outputs = self._communicator.exchange( + self._generate_reset_input(), self._poll_process + ) if outputs is None: raise UnityCommunicatorStoppedException("Communicator has exited.") self._update_behavior_specs(outputs) @@ -321,7 +323,7 @@ def step(self) -> None: ].action_spec.empty_action(n_agents) step_input = self._generate_step_input(self._env_actions) with hierarchical_timer("communicator.exchange"): - outputs = self._communicator.exchange(step_input) + outputs = self._communicator.exchange(step_input, self._poll_process) if outputs is None: raise UnityCommunicatorStoppedException("Communicator has exited.") self._update_behavior_specs(outputs) @@ -381,6 +383,18 @@ def get_steps( self._assert_behavior_exists(behavior_name) return self._env_state[behavior_name] + def _poll_process(self) -> None: + """ + Check the status of the subprocess. If it has exited, raise a UnityEnvironmentException + :return: None + """ + if not self._proc1: + return + poll_res = self._proc1.poll() + if poll_res is not None: + exc_msg = self._returncode_to_env_message(self._proc1.returncode) + raise UnityEnvironmentException(exc_msg) + def close(self): """ Sends a shutdown signal to the unity environment, and closes the socket connection. @@ -405,10 +419,7 @@ def _close(self, timeout: Optional[int] = None) -> None: # Wait a bit for the process to shutdown, but kill it if it takes too long try: self._proc1.wait(timeout=timeout) - signal_name = self._returncode_to_signal_name(self._proc1.returncode) - signal_name = f" ({signal_name})" if signal_name else "" - return_info = f"Environment shut down with return code {self._proc1.returncode}{signal_name}." - logger.info(return_info) + logger.info(self._returncode_to_env_message(self._proc1.returncode)) except subprocess.TimeoutExpired: logger.info("Environment timed out shutting down. Killing...") self._proc1.kill() @@ -456,7 +467,7 @@ def _send_academy_parameters( ) -> UnityOutputProto: inputs = UnityInputProto() inputs.rl_initialization_input.CopyFrom(init_parameters) - return self._communicator.initialize(inputs) + return self._communicator.initialize(inputs, self._poll_process) @staticmethod def _wrap_unity_input(rl_input: UnityRLInputProto) -> UnityInputProto: @@ -477,3 +488,9 @@ def _returncode_to_signal_name(returncode: int) -> Optional[str]: except Exception: # Should generally be a ValueError, but catch everything just in case. return None + + @staticmethod + def _returncode_to_env_message(returncode: int) -> str: + signal_name = UnityEnvironment._returncode_to_signal_name(returncode) + signal_name = f" ({signal_name})" if signal_name else "" + return f"Environment shut down with return code {returncode}{signal_name}." diff --git a/ml-agents-envs/mlagents_envs/mock_communicator.py b/ml-agents-envs/mlagents_envs/mock_communicator.py index 2a358ec987..0e425e2759 100755 --- a/ml-agents-envs/mlagents_envs/mock_communicator.py +++ b/ml-agents-envs/mlagents_envs/mock_communicator.py @@ -1,4 +1,6 @@ -from .communicator import Communicator +from typing import Optional + +from .communicator import Communicator, PollCallback from .environment import UnityEnvironment from mlagents_envs.communicator_objects.unity_rl_output_pb2 import UnityRLOutputProto from mlagents_envs.communicator_objects.brain_parameters_pb2 import ( @@ -39,7 +41,9 @@ def __init__( self.brain_name = brain_name self.vec_obs_size = vec_obs_size - def initialize(self, inputs: UnityInputProto) -> UnityOutputProto: + def initialize( + self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None + ) -> UnityOutputProto: if self.is_discrete: action_spec = ActionSpecProto( num_discrete_actions=2, discrete_branch_sizes=[3, 2] @@ -94,7 +98,9 @@ def _get_agent_infos(self): ) return dict_agent_info - def exchange(self, inputs: UnityInputProto) -> UnityOutputProto: + def exchange( + self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None + ) -> UnityOutputProto: result = UnityRLOutputProto(agentInfos=self._get_agent_infos()) return UnityOutputProto(rl_output=result) diff --git a/ml-agents-envs/mlagents_envs/rpc_communicator.py b/ml-agents-envs/mlagents_envs/rpc_communicator.py index 20ff3cc9a3..c77f851582 100644 --- a/ml-agents-envs/mlagents_envs/rpc_communicator.py +++ b/ml-agents-envs/mlagents_envs/rpc_communicator.py @@ -6,7 +6,7 @@ from multiprocessing import Pipe from concurrent.futures import ThreadPoolExecutor -from .communicator import Communicator +from .communicator import Communicator, PollCallback from mlagents_envs.communicator_objects.unity_to_external_pb2_grpc import ( UnityToExternalProtoServicer, add_UnityToExternalProtoServicer_to_server, @@ -86,22 +86,39 @@ def check_port(self, port): finally: s.close() - def poll_for_timeout(self): + def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None: """ Polls the GRPC parent connection for data, to be used before calling recv. This prevents us from hanging indefinitely in the case where the environment process has died or was not launched. - """ - if not self.unity_to_external.parent_conn.poll(self.timeout_wait): - raise UnityTimeOutException( - "The Unity environment took too long to respond. Make sure that :\n" - "\t The environment does not need user interaction to launch\n" - '\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n' - "\t The environment and the Python interface have compatible versions." - ) - def initialize(self, inputs: UnityInputProto) -> UnityOutputProto: - self.poll_for_timeout() + Additionally, a callback can be passed to periodically check the state of the environment. + This is used to detect the case when the environment dies without cleaning up the connection, + so that we can stop sooner and raise a more appropriate error. + """ + wait_time_remaining = self.timeout_wait + callback_timeout_wait = self.timeout_wait // 10 + while wait_time_remaining > 0: + if self.unity_to_external.parent_conn.poll(callback_timeout_wait): + # Got an acknowledgment from the connection + return + if poll_callback: + # Fire the callback - if it detects something wrong, it should raise an exception. + poll_callback() + wait_time_remaining -= callback_timeout_wait + + # Got this far without reading any data from the connection, so it must be dead. + raise UnityTimeOutException( + "The Unity environment took too long to respond. Make sure that :\n" + "\t The environment does not need user interaction to launch\n" + '\t The Agents\' Behavior Parameters > Behavior Type is set to "Default"\n' + "\t The environment and the Python interface have compatible versions." + ) + + def initialize( + self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None + ) -> UnityOutputProto: + self.poll_for_timeout(poll_callback) aca_param = self.unity_to_external.parent_conn.recv().unity_output message = UnityMessageProto() message.header.status = 200 @@ -110,12 +127,14 @@ def initialize(self, inputs: UnityInputProto) -> UnityOutputProto: self.unity_to_external.parent_conn.recv() return aca_param - def exchange(self, inputs: UnityInputProto) -> Optional[UnityOutputProto]: + def exchange( + self, inputs: UnityInputProto, poll_callback: Optional[PollCallback] = None + ) -> Optional[UnityOutputProto]: message = UnityMessageProto() message.header.status = 200 message.unity_input.CopyFrom(inputs) self.unity_to_external.parent_conn.send(message) - self.poll_for_timeout() + self.poll_for_timeout(poll_callback) output = self.unity_to_external.parent_conn.recv() if output.header.status != 200: return None diff --git a/ml-agents/mlagents/trainers/subprocess_env_manager.py b/ml-agents/mlagents/trainers/subprocess_env_manager.py index 2c7392a1da..cd599ab37f 100644 --- a/ml-agents/mlagents/trainers/subprocess_env_manager.py +++ b/ml-agents/mlagents/trainers/subprocess_env_manager.py @@ -189,7 +189,7 @@ def _generate_all_results() -> AllStepResult: ) _send_response(EnvironmentCommand.ENV_EXITED, ex) except Exception as ex: - logger.error( + logger.exception( f"UnityEnvironment worker {worker_id}: environment raised an unexpected exception." ) step_queue.put( From 61cc3f6eb7397bad0c735cc87cbdec3c33b45d8e Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 22 Jan 2021 15:49:40 -0800 Subject: [PATCH 4/6] deadline and rename proc1 --- ml-agents-envs/mlagents_envs/environment.py | 20 +++++++++---------- .../mlagents_envs/rpc_communicator.py | 8 ++++---- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ml-agents-envs/mlagents_envs/environment.py b/ml-agents-envs/mlagents_envs/environment.py index 785c8988fc..14b4dddb53 100644 --- a/ml-agents-envs/mlagents_envs/environment.py +++ b/ml-agents-envs/mlagents_envs/environment.py @@ -177,7 +177,7 @@ def __init__( # If true, this means the environment was successfully loaded self._loaded = False # The process that is started. If None, no process was started - self._proc1: Optional[subprocess.Popen] = None + self._process: Optional[subprocess.Popen] = None self._timeout_wait: int = timeout_wait self._communicator = self._get_communicator(worker_id, base_port, timeout_wait) self._worker_id = worker_id @@ -194,7 +194,7 @@ def __init__( ) if file_name is not None: try: - self._proc1 = env_utils.launch_executable( + self._process = env_utils.launch_executable( file_name, self._executable_args() ) except UnityEnvironmentException: @@ -388,11 +388,11 @@ def _poll_process(self) -> None: Check the status of the subprocess. If it has exited, raise a UnityEnvironmentException :return: None """ - if not self._proc1: + if not self._process: return - poll_res = self._proc1.poll() + poll_res = self._process.poll() if poll_res is not None: - exc_msg = self._returncode_to_env_message(self._proc1.returncode) + exc_msg = self._returncode_to_env_message(self._process.returncode) raise UnityEnvironmentException(exc_msg) def close(self): @@ -415,16 +415,16 @@ def _close(self, timeout: Optional[int] = None) -> None: timeout = self._timeout_wait self._loaded = False self._communicator.close() - if self._proc1 is not None: + if self._process is not None: # Wait a bit for the process to shutdown, but kill it if it takes too long try: - self._proc1.wait(timeout=timeout) - logger.info(self._returncode_to_env_message(self._proc1.returncode)) + self._process.wait(timeout=timeout) + logger.info(self._returncode_to_env_message(self._process.returncode)) except subprocess.TimeoutExpired: logger.info("Environment timed out shutting down. Killing...") - self._proc1.kill() + self._process.kill() # Set to None so we don't try to close multiple times. - self._proc1 = None + self._process = None @timed def _generate_step_input( diff --git a/ml-agents-envs/mlagents_envs/rpc_communicator.py b/ml-agents-envs/mlagents_envs/rpc_communicator.py index c77f851582..15c2698b09 100644 --- a/ml-agents-envs/mlagents_envs/rpc_communicator.py +++ b/ml-agents-envs/mlagents_envs/rpc_communicator.py @@ -1,9 +1,10 @@ import grpc from typing import Optional +from multiprocessing import Pipe from sys import platform import socket -from multiprocessing import Pipe +import time from concurrent.futures import ThreadPoolExecutor from .communicator import Communicator, PollCallback @@ -96,16 +97,15 @@ def poll_for_timeout(self, poll_callback: Optional[PollCallback] = None) -> None This is used to detect the case when the environment dies without cleaning up the connection, so that we can stop sooner and raise a more appropriate error. """ - wait_time_remaining = self.timeout_wait + deadline = time.monotonic() + self.timeout_wait callback_timeout_wait = self.timeout_wait // 10 - while wait_time_remaining > 0: + while time.monotonic() < deadline: if self.unity_to_external.parent_conn.poll(callback_timeout_wait): # Got an acknowledgment from the connection return if poll_callback: # Fire the callback - if it detects something wrong, it should raise an exception. poll_callback() - wait_time_remaining -= callback_timeout_wait # Got this far without reading any data from the connection, so it must be dead. raise UnityTimeOutException( From 2cdd68ba798a65fe58a5ed46f9341275de883131 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 22 Jan 2021 15:55:35 -0800 Subject: [PATCH 5/6] changelog --- com.unity.ml-agents/CHANGELOG.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/com.unity.ml-agents/CHANGELOG.md b/com.unity.ml-agents/CHANGELOG.md index a130dc88ae..95a13c6a37 100755 --- a/com.unity.ml-agents/CHANGELOG.md +++ b/com.unity.ml-agents/CHANGELOG.md @@ -24,15 +24,21 @@ removed when training with a player. The Editor still requires it to be clamped Updated the Basic example and the Match3 Example to use Actuators. Changed the namespace and file names of classes in com.unity.ml-agents.extensions. (#4849) - #### ml-agents / ml-agents-envs / gym-unity (Python) ### Bug Fixes #### com.unity.ml-agents (C#) - Fix a compile warning about using an obsolete enum in `GrpcExtensions.cs`. (#4812) +- CameraSensor now logs an error if the GraphicsDevice is null. (#4880) #### ml-agents / ml-agents-envs / gym-unity (Python) - Fixed a bug that would cause an exception when `RunOptions` was deserialized via `pickle`. (#4842) - Fixed the computation of entropy for continuous actions. (#4869) +- Fixed a bug that would cause `UnityEnvironment` to wait the full timeout + period and report a misleading error message if the executable crashed + without closing the connection. It now periodically checks the process status + while waiting for a connection, and raises a better error message if it crashes. (#4880) +- Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is + no longer overwritten. (#4880) ## [1.7.2-preview] - 2020-12-22 From cb001e689feb9d561ef3f990037463e459f6f49e Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Fri, 22 Jan 2021 16:24:14 -0800 Subject: [PATCH 6/6] add unit tests --- .../tests/test_rpc_communicator.py | 54 ++++++++++++++++++- 1 file changed, 53 insertions(+), 1 deletion(-) diff --git a/ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py b/ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py index 86b9b23c63..e912878a26 100644 --- a/ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py +++ b/ml-agents-envs/mlagents_envs/tests/test_rpc_communicator.py @@ -1,7 +1,16 @@ import pytest +from unittest import mock +import grpc + +import mlagents_envs.rpc_communicator from mlagents_envs.rpc_communicator import RpcCommunicator -from mlagents_envs.exception import UnityWorkerInUseException +from mlagents_envs.exception import ( + UnityWorkerInUseException, + UnityTimeOutException, + UnityEnvironmentException, +) +from mlagents_envs.communicator_objects.unity_input_pb2 import UnityInputProto def test_rpc_communicator_checks_port_on_create(): @@ -28,3 +37,46 @@ def test_rpc_communicator_create_multiple_workers(): second_comm = RpcCommunicator(worker_id=1) first_comm.close() second_comm.close() + + +@mock.patch.object(grpc, "server") +@mock.patch.object( + mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation" +) +def test_rpc_communicator_initialize_OK(mock_impl, mock_grpc_server): + comm = RpcCommunicator(timeout_wait=0.25) + comm.unity_to_external.parent_conn.poll.return_value = True + input = UnityInputProto() + comm.initialize(input) + comm.unity_to_external.parent_conn.poll.assert_called() + + +@mock.patch.object(grpc, "server") +@mock.patch.object( + mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation" +) +def test_rpc_communicator_initialize_timeout(mock_impl, mock_grpc_server): + comm = RpcCommunicator(timeout_wait=0.25) + comm.unity_to_external.parent_conn.poll.return_value = None + input = UnityInputProto() + # Expect a timeout + with pytest.raises(UnityTimeOutException): + comm.initialize(input) + comm.unity_to_external.parent_conn.poll.assert_called() + + +@mock.patch.object(grpc, "server") +@mock.patch.object( + mlagents_envs.rpc_communicator, "UnityToExternalServicerImplementation" +) +def test_rpc_communicator_initialize_callback(mock_impl, mock_grpc_server): + def callback(): + raise UnityEnvironmentException + + comm = RpcCommunicator(timeout_wait=0.25) + comm.unity_to_external.parent_conn.poll.return_value = None + input = UnityInputProto() + # Expect a timeout + with pytest.raises(UnityEnvironmentException): + comm.initialize(input, poll_callback=callback) + comm.unity_to_external.parent_conn.poll.assert_called()