6 changes: 3 additions & 3 deletions ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -62,7 +62,7 @@ def get_value_estimates(
"""
vec_vis_obs = SplitObservations.from_observations(decision_requests.obs)

- value_estimates, mean_value = self.policy.actor_critic.critic_pass(
+ value_estimates = self.policy.actor_critic.critic_pass(
np.expand_dims(vec_vis_obs.vector_observations[idx], 0),
np.expand_dims(vec_vis_obs.visual_observations[idx], 0),
)
@@ -97,11 +97,11 @@ def get_trajectory_value_estimates(
next_obs = [ModelUtils.list_to_tensor(next_obs).unsqueeze(0)]
next_memory = torch.zeros([1, 1, self.policy.m_size])

- value_estimates, mean_value = self.policy.actor_critic.critic_pass(
+ value_estimates = self.policy.actor_critic.critic_pass(
vector_obs, visual_obs, memory
)

- next_value_estimate, next_value = self.policy.actor_critic.critic_pass(
+ next_value_estimate = self.policy.actor_critic.critic_pass(
next_obs, next_obs, next_memory
)

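Note on the optimizer change: `critic_pass` now returns only the per-reward-stream dict of value estimates; the averaged value that used to come back as a second element is gone (see the `ValueHeads` change at the bottom of this diff). A minimal sketch of how a caller could still recover that mean; the helper name `mean_over_streams` and the dummy stream values are illustrative only:

```python
from typing import Dict
import torch

def mean_over_streams(value_estimates: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Reproduces the dropped second return value: the mean over all value heads.
    return torch.mean(torch.stack(list(value_estimates.values())), dim=0)

# Dummy estimates for two reward streams:
estimates = {"extrinsic": torch.tensor([1.0]), "curiosity": torch.tensor([3.0])}
print(mean_over_streams(estimates))  # tensor([2.])
```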
31 changes: 13 additions & 18 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -1,4 +1,4 @@
- from typing import Any, Dict, List, Optional
+ from typing import Any, Dict, List
import numpy as np
import torch

@@ -14,7 +14,8 @@

from mlagents.trainers.settings import TrainerSettings, TestingConfiguration
from mlagents.trainers.trajectory import SplitObservations
- from mlagents.trainers.torch.networks import ActorCritic
+ from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic
+ from mlagents.trainers.torch.utils import ModelUtils

EPSILON = 1e-7 # Small value to avoid divide by zero

@@ -29,8 +30,8 @@ def __init__(
load: bool = False,
tanh_squash: bool = False,
reparameterize: bool = False,
+ separate_critic: bool = True,
condition_sigma_on_obs: bool = True,
- separate_critic: Optional[bool] = None,
):
"""
Policy that uses a multilayer perceptron to map the observations to actions. Could
@@ -69,15 +70,16 @@ def __init__(
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
}
- self.actor_critic = ActorCritic(
+ if separate_critic:
+     ac_class = SeparateActorCritic
+ else:
+     ac_class = SharedActorCritic
+ self.actor_critic = ac_class(
observation_shapes=self.behavior_spec.observation_shapes,
network_settings=trainer_settings.network_settings,
act_type=behavior_spec.action_type,
act_size=self.act_size,
stream_names=reward_signal_names,
- separate_critic=separate_critic
- if separate_critic is not None
- else self.use_continuous_act,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
@@ -117,16 +119,11 @@ def sample_actions(
"""
:param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
"""
- (
-     dists,
-     (value_heads, mean_value),
-     memories,
- ) = self.actor_critic.get_dist_and_value(
+ dists, value_heads, memories = self.actor_critic.get_dist_and_value(
vec_obs, vis_obs, masks, memories, seq_len
)

action_list = self.actor_critic.sample_action(dists)
- log_probs, entropies, all_logs = self.actor_critic.get_probs_and_entropy(
+ log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
action_list, dists
)
actions = torch.stack(action_list, dim=-1)
@@ -146,15 +143,13 @@ def evaluate_actions(
def evaluate_actions(
self, vec_obs, vis_obs, actions, masks=None, memories=None, seq_len=1
):
- dists, (value_heads, mean_value), _ = self.actor_critic.get_dist_and_value(
+ dists, value_heads, _ = self.actor_critic.get_dist_and_value(
vec_obs, vis_obs, masks, memories, seq_len
)
if len(actions.shape) <= 2:
actions = actions.unsqueeze(-1)
action_list = [actions[..., i] for i in range(actions.shape[2])]
- log_probs, entropies, _ = self.actor_critic.get_probs_and_entropy(
-     action_list, dists
- )
+ log_probs, entropies, _ = ModelUtils.get_probs_and_entropy(action_list, dists)

return log_probs, entropies, value_heads

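Note on the policy change: the monolithic `ActorCritic` is split into two classes, `get_dist_and_value` now returns a flat `(dists, value_heads, memories)` tuple instead of nesting `(value_heads, mean_value)`, and probability/entropy computation moves to `ModelUtils`. A hedged sketch of the class-selection logic in isolation; the standalone function is illustrative, not part of the codebase:

```python
from mlagents.trainers.torch.networks import SharedActorCritic, SeparateActorCritic

def pick_actor_critic_class(separate_critic: bool):
    # SeparateActorCritic keeps its own critic body; SharedActorCritic reuses
    # the actor's network body and only adds value heads on top.
    return SeparateActorCritic if separate_critic else SharedActorCritic
```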
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/ppo/trainer.py
@@ -233,6 +233,7 @@ def create_torch_policy(
self.artifact_path,
self.load,
condition_sigma_on_obs=False, # Faster training for PPO
+ separate_critic=behavior_spec.is_action_continuous(),
)
return policy

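Note on the trainer change: with the `Optional[bool] = None` fallback removed from `TorchPolicy`, the PPO trainer now states the choice explicitly, and `is_action_continuous()` keeps the previous behaviour (continuous actions get a separate critic, discrete actions share one). A minimal sketch of that equivalence, assuming `use_continuous_act` mirrors `behavior_spec.is_action_continuous()`; both functions are illustrative only:

```python
def old_default(separate_critic, use_continuous_act):
    # Removed TorchPolicy fallback: None meant "decide from the action space".
    return separate_critic if separate_critic is not None else use_continuous_act

def new_default(is_action_continuous: bool) -> bool:
    # The trainer now makes the same decision explicitly.
    return is_action_continuous

assert old_default(None, True) == new_default(True)
assert old_default(None, False) == new_default(False)
```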
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/tests/test_reward_signals.py
@@ -4,7 +4,7 @@
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.sac.optimizer import SACOptimizer
- from mlagents.trainers.ppo.optimizer import PPOOptimizer
+ from mlagents.trainers.ppo.optimizer_tf import TFPPOOptimizer
from mlagents.trainers.tests.test_simple_rl import PPO_CONFIG, SAC_CONFIG
from mlagents.trainers.settings import (
GAILSettings,
@@ -75,7 +75,7 @@ def create_optimizer_mock(
if trainer_settings.trainer_type == TrainerType.SAC:
optimizer = SACOptimizer(policy, trainer_settings)
else:
- optimizer = PPOOptimizer(policy, trainer_settings)
+ optimizer = TFPPOOptimizer(policy, trainer_settings)
return optimizer


208 changes: 208 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/test_networks.py
@@ -0,0 +1,208 @@
import pytest

import torch
from mlagents.trainers.torch.networks import (
NetworkBody,
ValueNetwork,
SimpleActor,
SharedActorCritic,
SeparateActorCritic,
)
from mlagents.trainers.settings import NetworkSettings
from mlagents_envs.base_env import ActionType
from mlagents.trainers.torch.distributions import (
GaussianDistInstance,
CategoricalDistInstance,
)


def test_networkbody_vector():
obs_size = 4
network_settings = NetworkSettings()
obs_shapes = [(obs_size,)]

networkbody = NetworkBody(obs_shapes, network_settings, encoded_act_size=2)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = torch.ones((1, obs_size))
sample_act = torch.ones((1, 2))

for _ in range(100):
encoded, _ = networkbody([sample_obs], [], sample_act)
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten():
assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_lstm():
obs_size = 4
seq_len = 16
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings(sequence_length=seq_len, memory_size=4)
)
obs_shapes = [(obs_size,)]

networkbody = NetworkBody(obs_shapes, network_settings)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = torch.ones((1, seq_len, obs_size))

for _ in range(100):
encoded, _ = networkbody([sample_obs], [], memories=torch.ones(1, seq_len, 4))
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten():
assert _enc == pytest.approx(1.0, abs=0.1)


def test_networkbody_visual():
vec_obs_size = 4
obs_size = (84, 84, 3)
network_settings = NetworkSettings()
obs_shapes = [(vec_obs_size,), obs_size]
torch.random.manual_seed(0)

networkbody = NetworkBody(obs_shapes, network_settings)
optimizer = torch.optim.Adam(networkbody.parameters(), lr=3e-3)
sample_obs = torch.ones((1, 84, 84, 3))
sample_vec_obs = torch.ones((1, vec_obs_size))

for _ in range(100):
encoded, _ = networkbody([sample_vec_obs], [sample_obs])
assert encoded.shape == (1, network_settings.hidden_units)
# Try to force output to 1
loss = torch.nn.functional.mse_loss(encoded, torch.ones(encoded.shape))
optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for _enc in encoded.flatten():
assert _enc == pytest.approx(1.0, abs=0.1)


def test_valuenetwork():
obs_size = 4
num_outputs = 2
network_settings = NetworkSettings()
obs_shapes = [(obs_size,)]

stream_names = [f"stream_name{n}" for n in range(4)]
value_net = ValueNetwork(
stream_names, obs_shapes, network_settings, outputs_per_stream=num_outputs
)
optimizer = torch.optim.Adam(value_net.parameters(), lr=3e-3)

for _ in range(50):
sample_obs = torch.ones((1, obs_size))
values, _ = value_net([sample_obs], [])
loss = 0
for s_name in stream_names:
assert values[s_name].shape == (1, num_outputs)
# Try to force output to 1
loss += torch.nn.functional.mse_loss(
values[s_name], torch.ones((1, num_outputs))
)

optimizer.zero_grad()
loss.backward()
optimizer.step()
# In the last step, values should be close to 1
for value in values.values():
for _out in value:
assert _out[0] == pytest.approx(1.0, abs=0.1)


@pytest.mark.parametrize("action_type", [ActionType.DISCRETE, ActionType.CONTINUOUS])
def test_simple_actor(action_type):
obs_size = 4
network_settings = NetworkSettings()
obs_shapes = [(obs_size,)]
act_size = [2]
masks = None if action_type == ActionType.CONTINUOUS else torch.ones((1, 1))
actor = SimpleActor(obs_shapes, network_settings, action_type, act_size)
# Test get_dist
sample_obs = torch.ones((1, obs_size))
dists, _ = actor.get_dists([sample_obs], [], masks=masks)
for dist in dists:
if action_type == ActionType.CONTINUOUS:
assert isinstance(dist, GaussianDistInstance)
else:
assert isinstance(dist, CategoricalDistInstance)

# Test sample_actions
actions = actor.sample_action(dists)
for act in actions:
if action_type == ActionType.CONTINUOUS:
assert act.shape == (1, act_size[0])
else:
assert act.shape == (1, 1)

# Test forward
actions, probs, ver_num, mem_size, is_cont, act_size_vec = actor.forward(
[sample_obs], [], masks=masks
)
for act in actions:
if action_type == ActionType.CONTINUOUS:
assert act.shape == (
act_size[0],
1,
) # This is different from above for ONNX export
else:
assert act.shape == (1, 1)

# TODO: Once export works properly. fix the shapes here.
assert mem_size == 0
assert is_cont == int(action_type == ActionType.CONTINUOUS)
assert act_size_vec == torch.tensor(act_size)


@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic])
@pytest.mark.parametrize("lstm", [True, False])
def test_actor_critic(ac_type, lstm):
obs_size = 4
network_settings = NetworkSettings(
memory=NetworkSettings.MemorySettings() if lstm else None
)
obs_shapes = [(obs_size,)]
act_size = [2]
stream_names = [f"stream_name{n}" for n in range(4)]
actor = ac_type(
obs_shapes, network_settings, ActionType.CONTINUOUS, act_size, stream_names
)
if lstm:
sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size))
memories = torch.ones(
(
1,
network_settings.memory.sequence_length,
network_settings.memory.memory_size,
)
)
else:
sample_obs = torch.ones((1, obs_size))
memories = None
# Test critic pass
value_out = actor.critic_pass([sample_obs], [], memories=memories)
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)
else:
assert value_out[stream].shape == (1,)

# Test get_dist_and_value
dists, value_out, _ = actor.get_dist_and_value([sample_obs], [], memories=memories)
for dist in dists:
assert isinstance(dist, GaussianDistInstance)
for stream in stream_names:
if lstm:
assert value_out[stream].shape == (network_settings.memory.sequence_length,)
else:
assert value_out[stream].shape == (1,)
11 changes: 5 additions & 6 deletions ml-agents/mlagents/trainers/torch/decoders.py
@@ -1,9 +1,11 @@
+ from typing import List, Dict

import torch
from torch import nn


class ValueHeads(nn.Module):
- def __init__(self, stream_names, input_size, output_size=1):
+ def __init__(self, stream_names: List[str], input_size: int, output_size: int = 1):
super().__init__()
self.stream_names = stream_names
_value_heads = {}
@@ -13,11 +15,8 @@ def __init__(self, stream_names, input_size, output_size=1):
_value_heads[name] = value
self.value_heads = nn.ModuleDict(_value_heads)

- def forward(self, hidden):
+ def forward(self, hidden: torch.Tensor) -> Dict[str, torch.Tensor]:
value_outputs = {}
for stream_name, head in self.value_heads.items():
value_outputs[stream_name] = head(hidden).squeeze(-1)
- return (
-     value_outputs,
-     torch.mean(torch.stack(list(value_outputs.values())), dim=0),
- )
+ return value_outputs
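Note on the decoder change: `ValueHeads.forward` now returns only the `{stream_name: tensor}` dict, and averaging across streams is left to callers (see the optimizer note above). A short usage sketch against this branch; the stream names and sizes are placeholders:

```python
import torch
from mlagents.trainers.torch.decoders import ValueHeads

heads = ValueHeads(["extrinsic", "gail"], input_size=8, output_size=1)
hidden = torch.ones((1, 8))
values = heads(hidden)  # Dict[str, torch.Tensor], one entry per stream
assert set(values.keys()) == {"extrinsic", "gail"}
assert values["extrinsic"].shape == (1,)  # output dimension squeezed away
```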