Merged
Changes from all commits (25 commits):
96278d0 Separate Actor/Critic, remove ActorCritics (andrewcoh, Feb 9, 2021)
5f8cbc5 update policy to not use critic (andrewcoh, Feb 9, 2021)
293ec08 add critic to optimizer, ppo runs (andrewcoh, Feb 9, 2021)
7d20bd9 fix precommit errors (andrewcoh, Feb 9, 2021)
c669226 fix test_networks (andrewcoh, Feb 9, 2021)
b22d0ae Update SAC to use separate policy (Feb 10, 2021)
d7e2ca6 make critic a property (andrewcoh, Feb 10, 2021)
9f6eca7 remove commented code (andrewcoh, Feb 11, 2021)
944997a fix saver test (andrewcoh, Feb 11, 2021)
527ca06 Move value network for SAC to device (Feb 11, 2021)
eb15030 Merge remote-tracking branch 'origin/develop-critic-optimizer' into d… (Feb 11, 2021)
4d215cf add SharedActorCritic (andrewcoh, Feb 11, 2021)
9fac4b1 test for SharedActorCritic (andrewcoh, Feb 11, 2021)
d5a30f1 fix agent processor test (andrewcoh, Feb 11, 2021)
65b5992 fix sac shared (andrewcoh, Feb 12, 2021)
31da276 fix test policy (andrewcoh, Feb 12, 2021)
c41c9a7 adjust step size gail visual ppo (andrewcoh, Feb 12, 2021)
817b248 Merge branch 'master' into develop-critic-optimizer (andrewcoh, Feb 16, 2021)
4eb6cb3 Store and evaluate critic LSTM memories in Optimizer (#4948) (Feb 24, 2021)
cbb8b64 address comments (andrewcoh, Feb 24, 2021)
beae793 raise if SAC using SharedActorCritic (andrewcoh, Feb 24, 2021)
5911879 Merge branch 'master' into develop-critic-optimizer (andrewcoh, Feb 24, 2021)
7ce234d add critic to default device (andrewcoh, Feb 24, 2021)
6c300c8 Address comments (Feb 25, 2021)
36d9532 docstring for action info (andrewcoh, Feb 25, 2021)
12 changes: 10 additions & 2 deletions ml-agents/mlagents/trainers/action_info.py
@@ -6,12 +6,20 @@


class ActionInfo(NamedTuple):
"""
A NamedTuple containing actions and quantities related to the policy forward
pass. Additionally contains the agent ids in the corresponding DecisionStep.
:param action: The action output of the policy
:param env_action: The possibly clipped action to be executed in the environment
:param outputs: Dict of all quantities associated with the policy forward pass
:param agent_ids: List of int agent ids in DecisionStep
"""

action: Any
env_action: Any
value: Any
outputs: ActionInfoOutputs
agent_ids: List[AgentId]

@staticmethod
def empty() -> "ActionInfo":
return ActionInfo([], [], [], {}, [])
return ActionInfo([], [], {}, [])
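Since this PR drops the value field from ActionInfo, call sites that built the tuple with five positional fields need updating. Below is a minimal, standalone sketch of the post-PR shape (AgentId is simplified to str here; only the fields shown in the diff above are assumed):

from typing import Any, Dict, List, NamedTuple

AgentId = str  # simplified stand-in for the mlagents AgentId alias


class ActionInfoSketch(NamedTuple):
    """Mirror of ActionInfo after this PR: no separate value field."""

    action: Any
    env_action: Any
    outputs: Dict[str, Any]
    agent_ids: List[AgentId]

    @staticmethod
    def empty() -> "ActionInfoSketch":
        return ActionInfoSketch([], [], {}, [])


# Value estimates are no longer carried here; the optimizer's critic computes them.
info = ActionInfoSketch(action=[0.1], env_action=[0.1], outputs={}, agent_ids=["agent-0"])
print(info.agent_ids, ActionInfoSketch.empty())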
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/agent_processor.py
@@ -126,7 +126,7 @@ def _process_step(
if stored_decision_step is not None and stored_take_action_outputs is not None:
obs = stored_decision_step.obs
if self.policy.use_recurrent:
memory = self.policy.retrieve_memories([global_id])[0, :]
memory = self.policy.retrieve_previous_memories([global_id])[0, :]
else:
memory = None
done = terminated # Since this is an ongoing step
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/buffer.py
@@ -28,6 +28,7 @@ class BufferKey(enum.Enum):
ENVIRONMENT_REWARDS = "environment_rewards"
MASKS = "masks"
MEMORY = "memory"
CRITIC_MEMORY = "critic_memory"
PREV_ACTION = "prev_action"

ADVANTAGES = "advantages"
140 changes: 127 additions & 13 deletions ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
@@ -1,8 +1,9 @@
from typing import Dict, Optional, Tuple, List
from mlagents.torch_utils import torch
import numpy as np
import math

from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
from mlagents.trainers.trajectory import ObsUtil
from mlagents.trainers.torch.components.bc.module import BCModule
from mlagents.trainers.torch.components.reward_providers import create_reward_provider
@@ -26,6 +27,7 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
self.global_step = torch.tensor(0)
self.bc_module: Optional[BCModule] = None
self.create_reward_signals(trainer_settings.reward_signals)
self.critic_memory_dict: Dict[str, torch.Tensor] = {}
if trainer_settings.behavioral_cloning is not None:
self.bc_module = BCModule(
self.policy,
Expand All @@ -35,6 +37,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
default_num_epoch=3,
)

@property
def critic(self):
raise NotImplementedError

def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
pass

@@ -49,25 +55,132 @@ def create_reward_signals(self, reward_signal_configs):
reward_signal, self.policy.behavior_spec, settings
)

def _evaluate_by_sequence(
self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
"""
Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
intermediate memories for the critic.
:param tensor_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's
observations for this trajectory.
:param initial_memory: The memory that precedes this trajectory. Of shape (1, 1, <mem_size>), i.e.
what is returned as the output of a MemoryModule.
:return: A Tuple of the value estimates as a Dict of [name, tensor], an AgentBufferField of the initial
memories to be used during value function update, and the final memory at the end of the trajectory.
"""
[Review comment] describe the inputs and their dimensions. it is not clear what initial_memory's shape is
[Reply] Added a comment
num_experiences = tensor_obs[0].shape[0]
all_next_memories = AgentBufferField()
# In the buffer, the first sequence is the one that gets padded. So if seq_len = 3 and
# the trajectory is of length 10, the first sequence is [pad, pad, obs].
# Compute the number of elements in this padded seq.
leftover = num_experiences % self.policy.sequence_length

# Compute values for the potentially truncated initial sequence
seq_obs = []

first_seq_len = self.policy.sequence_length
[Review comment] Any reason for padding the end of the first sequence and not the last sequence?
[Reply] Neither is padded in this function, but the buffer pads the front of the first sequence, so this function accounts for that. I left the padding in the buffer as-is and still need to experiment with that change; it will likely be a future PR.
[Review comment] Then why is the first sequence handled differently?
for _obs in tensor_obs:
if leftover > 0:
first_seq_len = leftover
first_seq_obs = _obs[0:first_seq_len]
seq_obs.append(first_seq_obs)

# For the first sequence, the initial memory should be the one at the
# beginning of this trajectory.
for _ in range(first_seq_len):
all_next_memories.append(initial_memory.squeeze().detach().numpy())
[Review suggestion] Replace `all_next_memories.append(initial_memory.squeeze().detach().numpy())` with `all_next_memories.append(_mem.squeeze().detach().numpy())`
[Review comment] Hard to understand why the all_next_memories are a concatenation of initial memories...
[Reply] Refactored this to be a bit clearer

init_values, _mem = self.critic.critic_pass(
seq_obs, initial_memory, sequence_length=first_seq_len
)
all_values = {
signal_name: [init_values[signal_name]]
for signal_name in init_values.keys()
}

# Evaluate the remaining sequences, carrying over _mem from one
# sequence to the next
for seq_num in range(
1, math.ceil((num_experiences) / (self.policy.sequence_length))
):
seq_obs = []
for _ in range(self.policy.sequence_length):
all_next_memories.append(_mem.squeeze().detach().numpy())
for _obs in tensor_obs:
start = seq_num * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
end = (seq_num + 1) * self.policy.sequence_length - (
self.policy.sequence_length - leftover
)
seq_obs.append(_obs[start:end])
values, _mem = self.critic.critic_pass(
seq_obs, _mem, sequence_length=self.policy.sequence_length
)
for signal_name, _val in values.items():
all_values[signal_name].append(_val)
# Create one tensor per reward signal
all_value_tensors = {
signal_name: torch.cat(value_list, dim=0)
for signal_name, value_list in all_values.items()
}
next_mem = _mem
return all_value_tensors, all_next_memories, next_mem

def get_trajectory_value_estimates(
self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
self,
batch: AgentBuffer,
next_obs: List[np.ndarray],
done: bool,
agent_id: str = "",
) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
[Review comment] docstring
[Reply] Added

"""
Get value estimates and memories for a trajectory, in batch form.
:param batch: An AgentBuffer that consists of a trajectory.
:param next_obs: the next observation (after the trajectory). Used for bootstrapping
if this is not a terminal trajectory.
:param done: Set true if this is a terminal trajectory.
:param agent_id: Agent ID of the agent that this trajectory belongs to.
:returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
the final value estimate as a Dict of [name, float], and optionally (if using memories)
an AgentBufferField of initial critic memories to be used during update.
"""
n_obs = len(self.policy.behavior_spec.observation_specs)
current_obs = ObsUtil.from_buffer(batch, n_obs)

if agent_id in self.critic_memory_dict:
memory = self.critic_memory_dict[agent_id]
else:
memory = (
torch.zeros((1, 1, self.critic.memory_size))
if self.policy.use_recurrent
else None
)

# Convert to tensors
current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
current_obs = [
ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
]
next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

memory = torch.zeros([1, 1, self.policy.m_size])

next_obs = [obs.unsqueeze(0) for obs in next_obs]

value_estimates, next_memory = self.policy.actor_critic.critic_pass(
current_obs, memory, sequence_length=batch.num_experiences
)
# If we're using LSTM, we want to get all the intermediate memories.
all_next_memories: Optional[AgentBufferField] = None
if self.policy.use_recurrent:
(
value_estimates,
all_next_memories,
next_memory,
) = self._evaluate_by_sequence(current_obs, memory)
else:
value_estimates, next_memory = self.critic.critic_pass(
current_obs, memory, sequence_length=batch.num_experiences
)

next_value_estimate, _ = self.policy.actor_critic.critic_pass(
# Store the memory for the next trajectory
self.critic_memory_dict[agent_id] = next_memory

next_value_estimate, _ = self.critic.critic_pass(
next_obs, next_memory, sequence_length=1
)

Expand All @@ -79,5 +192,6 @@ def get_trajectory_value_estimates(
for k in next_value_estimate:
if not self.reward_signals[k].ignore_done:
next_value_estimate[k] = 0.0

return value_estimates, next_value_estimate
if agent_id in self.critic_memory_dict:
self.critic_memory_dict.pop(agent_id)
return value_estimates, next_value_estimate, all_next_memories
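To make the indexing in _evaluate_by_sequence easier to follow, here is a small standalone sketch that reproduces only the slice arithmetic from the diff above (plain ints stand in for observation tensors, and no critic is called); it shows which trajectory steps land in each LSTM sequence when, as discussed in the review thread, the buffer front-pads the first sequence:

import math
from typing import List, Tuple


def split_into_sequences(num_experiences: int, sequence_length: int) -> List[Tuple[int, int]]:
    """Reproduce the (start, end) slices used by _evaluate_by_sequence."""
    leftover = num_experiences % sequence_length
    # The first sequence holds only the leftover steps when the trajectory does
    # not divide evenly; the buffer pads its front during the actual update.
    first_seq_len = leftover if leftover > 0 else sequence_length
    bounds = [(0, first_seq_len)]
    for seq_num in range(1, math.ceil(num_experiences / sequence_length)):
        start = seq_num * sequence_length - (sequence_length - leftover)
        end = (seq_num + 1) * sequence_length - (sequence_length - leftover)
        bounds.append((start, end))
    return bounds


# A 10-step trajectory with sequence_length=3 (leftover=1) splits into
# [0:1], [1:4], [4:7], [7:10]; the critic memory is carried across these slices.
print(split_into_sequences(10, 3))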
15 changes: 15 additions & 0 deletions ml-agents/mlagents/trainers/policy/policy.py
@@ -33,6 +33,7 @@ def __init__(
self.network_settings: NetworkSettings = trainer_settings.network_settings
self.seed = seed
self.previous_action_dict: Dict[str, np.ndarray] = {}
self.previous_memory_dict: Dict[str, np.ndarray] = {}
self.memory_dict: Dict[str, np.ndarray] = {}
self.normalize = trainer_settings.network_settings.normalize
self.use_recurrent = self.network_settings.memory is not None
@@ -72,6 +73,11 @@ def save_memories(
if memory_matrix is None:
return

# Pass old memories into previous_memory_dict
for agent_id in agent_ids:
if agent_id in self.memory_dict:
self.previous_memory_dict[agent_id] = self.memory_dict[agent_id]

for index, agent_id in enumerate(agent_ids):
self.memory_dict[agent_id] = memory_matrix[index, :]

@@ -82,10 +88,19 @@ def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
memory_matrix[index, :] = self.memory_dict[agent_id]
return memory_matrix

def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
for index, agent_id in enumerate(agent_ids):
if agent_id in self.previous_memory_dict:
memory_matrix[index, :] = self.previous_memory_dict[agent_id]
return memory_matrix

def remove_memories(self, agent_ids):
for agent_id in agent_ids:
if agent_id in self.memory_dict:
self.memory_dict.pop(agent_id)
if agent_id in self.previous_memory_dict:
self.previous_memory_dict.pop(agent_id)

def make_empty_previous_action(self, num_agents: int) -> np.ndarray:
"""
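The agent_processor change above switches to retrieve_previous_memories because, by the time a step is processed, memory_dict already holds the memory produced after that step. A minimal standalone sketch of the bookkeeping (numpy only; the class name and m_size are illustrative, the method bodies follow the diff):

from typing import Dict, List

import numpy as np


class MemoryStoreSketch:
    """Illustrative subset of the Policy memory bookkeeping added in this PR."""

    def __init__(self, m_size: int = 4):
        self.m_size = m_size
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.previous_memory_dict: Dict[str, np.ndarray] = {}

    def save_memories(self, agent_ids: List[str], memory_matrix: np.ndarray) -> None:
        # Shift current memories into previous_memory_dict before overwriting them.
        for agent_id in agent_ids:
            if agent_id in self.memory_dict:
                self.previous_memory_dict[agent_id] = self.memory_dict[agent_id]
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.previous_memory_dict:
                memory_matrix[index, :] = self.previous_memory_dict[agent_id]
        return memory_matrix


store = MemoryStoreSketch()
store.save_memories(["agent-0"], np.ones((1, 4), dtype=np.float32))       # after step 1
store.save_memories(["agent-0"], np.full((1, 4), 2.0, dtype=np.float32))  # after step 2
# While processing step 2, the trainer wants the memory that went into it (step 1's output):
print(store.retrieve_previous_memories(["agent-0"]))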
62 changes: 33 additions & 29 deletions ml-agents/mlagents/trainers/policy/torch_policy.py
@@ -10,11 +10,7 @@
from mlagents_envs.timers import timed

from mlagents.trainers.settings import TrainerSettings
from mlagents.trainers.torch.networks import (
SharedActorCritic,
SeparateActorCritic,
GlobalSteps,
)
from mlagents.trainers.torch.networks import SimpleActor, SharedActorCritic, GlobalSteps

from mlagents.trainers.torch.utils import ModelUtils
from mlagents.trainers.buffer import AgentBuffer
@@ -61,31 +57,40 @@ def __init__(
) # could be much simpler if TorchPolicy is nn.Module
self.grads = None

reward_signal_configs = trainer_settings.reward_signals
reward_signal_names = [key.value for key, _ in reward_signal_configs.items()]

self.stats_name_to_update_name = {
"Losses/Value Loss": "value_loss",
"Losses/Policy Loss": "policy_loss",
}
if separate_critic:
ac_class = SeparateActorCritic
self.actor = SimpleActor(
observation_specs=self.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
action_spec=behavior_spec.action_spec,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
self.shared_critic = False
else:
ac_class = SharedActorCritic
self.actor_critic = ac_class(
observation_specs=self.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
action_spec=behavior_spec.action_spec,
stream_names=reward_signal_names,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
reward_signal_configs = trainer_settings.reward_signals
reward_signal_names = [
key.value for key, _ in reward_signal_configs.items()
]
self.actor = SharedActorCritic(
observation_specs=self.behavior_spec.observation_specs,
network_settings=trainer_settings.network_settings,
action_spec=behavior_spec.action_spec,
stream_names=reward_signal_names,
conditional_sigma=self.condition_sigma_on_obs,
tanh_squash=tanh_squash,
)
self.shared_critic = True

# Save the m_size needed for export
self._export_m_size = self.m_size
# m_size needed for training is determined by network, not trainer settings
self.m_size = self.actor_critic.memory_size
self.m_size = self.actor.memory_size

self.actor_critic.to(default_device())
self.actor.to(default_device())
self._clip_action = not tanh_squash

@property
@@ -115,7 +120,7 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
"""

if self.normalize:
self.actor_critic.update_normalization(buffer)
self.actor.update_normalization(buffer)

@timed
def sample_actions(
@@ -132,7 +137,7 @@ def sample_actions(
:param seq_len: Sequence length when using RNN.
:return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories.
"""
actions, log_probs, entropies, memories = self.actor_critic.get_action_stats(
actions, log_probs, entropies, memories = self.actor.get_action_and_stats(
obs, masks, memories, seq_len
)
return (actions, log_probs, entropies, memories)
@@ -144,11 +149,11 @@ def evaluate_actions(
masks: Optional[torch.Tensor] = None,
memories: Optional[torch.Tensor] = None,
seq_len: int = 1,
) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]:
log_probs, entropies, value_heads = self.actor_critic.get_stats_and_value(
) -> Tuple[ActionLogProbs, torch.Tensor]:
log_probs, entropies = self.actor.get_stats(
obs, actions, masks, memories, seq_len
)
return log_probs, entropies, value_heads
return log_probs, entropies

@timed
def evaluate(
@@ -210,7 +215,6 @@ def get_action(
return ActionInfo(
action=run_out.get("action"),
env_action=run_out.get("env_action"),
value=run_out.get("value"),
outputs=run_out,
agent_ids=list(decision_requests.agent_id),
)
@@ -239,13 +243,13 @@ def increment_step(self, n_steps):
return self.get_current_step()

def load_weights(self, values: List[np.ndarray]) -> None:
self.actor_critic.load_state_dict(values)
self.actor.load_state_dict(values)

def init_load_weights(self) -> None:
pass

def get_weights(self) -> List[np.ndarray]:
return copy.deepcopy(self.actor_critic.state_dict())
return copy.deepcopy(self.actor.state_dict())

def get_modules(self):
return {"Policy": self.actor_critic, "global_step": self.global_step}
return {"Policy": self.actor, "global_step": self.global_step}
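Since the policy now exposes only self.actor plus a shared_critic flag, the critic lives with the optimizer unless the actor is a SharedActorCritic. A hedged sketch of that ownership split (the class names below are placeholders, not the ml-agents API; only the shared_critic flag and the actor attribute come from this diff):

from types import SimpleNamespace


class BaseOptimizerSketch:
    @property
    def critic(self):
        # Mirrors the base TorchOptimizer in this PR: subclasses provide the critic.
        raise NotImplementedError


class TrainerOptimizerSketch(BaseOptimizerSketch):
    def __init__(self, policy):
        # Shared case: reuse the actor's value heads. Separate case: the optimizer
        # owns its own value network, moves it to the device, and trains it.
        self._critic = policy.actor if policy.shared_critic else self._make_value_net()

    def _make_value_net(self):
        return "separate-value-network"  # stand-in for a real value network

    @property
    def critic(self):
        return self._critic


separate_policy = SimpleNamespace(actor="simple-actor", shared_critic=False)
shared_policy = SimpleNamespace(actor="shared-actor-critic", shared_critic=True)
print(TrainerOptimizerSketch(separate_policy).critic)  # separate-value-network
print(TrainerOptimizerSketch(shared_policy).critic)    # shared-actor-critic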