Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions ml-agents-envs/mlagents_envs/base_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,28 +410,20 @@ def random_action(self, n_agents: int) -> ActionTuple:
return ActionTuple(continuous=_continuous, discrete=_discrete)

def _validate_action(
self, actions: ActionTuple, n_agents: Optional[int], name: str
self, actions: ActionTuple, n_agents: int, name: str
) -> ActionTuple:
"""
Validates that action has the correct action dim
for the correct number of agents and ensures the type.
"""
_expected_shape = (
(n_agents, self.continuous_size)
if n_agents is not None
else (self.continuous_size,)
)
_expected_shape = (n_agents, self.continuous_size)
if actions.continuous.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a continuous input of dimension "
f"{_expected_shape} for (<number of agents>, <action size>) but "
f"received input of dimension {actions.continuous.shape}"
)
_expected_shape = (
(n_agents, self.discrete_size)
if n_agents is not None
else (self.discrete_size,)
)
_expected_shape = (n_agents, self.discrete_size)
if actions.discrete.shape != _expected_shape:
raise UnityActionException(
f"The behavior {name} needs a discrete input of dimension "
Expand Down
4 changes: 2 additions & 2 deletions ml-agents-envs/mlagents_envs/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,9 +365,9 @@ def set_action_for_agent(
if behavior_name not in self._env_state:
return
action_spec = self._env_specs[behavior_name].action_spec
num_agents = len(self._env_state[behavior_name][0])
action = action_spec._validate_action(action, None, behavior_name)
action = action_spec._validate_action(action, 1, behavior_name)
if behavior_name not in self._env_actions:
num_agents = len(self._env_state[behavior_name][0])
self._env_actions[behavior_name] = action_spec.empty_action(num_agents)
try:
index = np.where(self._env_state[behavior_name][0].agent_id == agent_id)[0][
Expand Down
55 changes: 55 additions & 0 deletions ml-agents-envs/mlagents_envs/tests/test_set_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from mlagents_envs.registry import default_registry
from mlagents_envs.side_channel.engine_configuration_channel import (
EngineConfigurationChannel,
)
from mlagents_envs.base_env import ActionTuple
import numpy as np

BALL_ID = "3DBall"


def test_set_action_single_agent():
engine_config_channel = EngineConfigurationChannel()
env = default_registry[BALL_ID].make(
base_port=6000,
worker_id=0,
no_graphics=True,
side_channels=[engine_config_channel],
)
engine_config_channel.set_configuration_parameters(time_scale=100)
for _ in range(3):
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
d, t = env.get_steps(behavior_name)
for _ in range(50):
for agent_id in d.agent_id:
action = np.ones((1, 2))
action_tuple = ActionTuple()
action_tuple.add_continuous(action)
env.set_action_for_agent(behavior_name, agent_id, action_tuple)
env.step()
d, t = env.get_steps(behavior_name)
env.close()


def test_set_action_multi_agent():
engine_config_channel = EngineConfigurationChannel()
env = default_registry[BALL_ID].make(
base_port=6001,
worker_id=0,
no_graphics=True,
side_channels=[engine_config_channel],
)
engine_config_channel.set_configuration_parameters(time_scale=100)
for _ in range(3):
env.reset()
behavior_name = list(env.behavior_specs.keys())[0]
d, t = env.get_steps(behavior_name)
for _ in range(50):
action = np.ones((len(d), 2))
action_tuple = ActionTuple()
action_tuple.add_continuous(action)
env.set_actions(behavior_name, action_tuple)
env.step()
d, t = env.get_steps(behavior_name)
env.close()