Move the Critic into the Optimizer #4939
Changes from all commits
96278d0
5f8cbc5
293ec08
7d20bd9
c669226
b22d0ae
d7e2ca6
9f6eca7
944997a
527ca06
eb15030
4d215cf
9fac4b1
d5a30f1
65b5992
31da276
c41c9a7
817b248
4eb6cb3
cbb8b64
beae793
5911879
7ce234d
6c300c8
36d9532
@@ -1,8 +1,9 @@
 from typing import Dict, Optional, Tuple, List
 from mlagents.torch_utils import torch
 import numpy as np
+import math

-from mlagents.trainers.buffer import AgentBuffer
+from mlagents.trainers.buffer import AgentBuffer, AgentBufferField
 from mlagents.trainers.trajectory import ObsUtil
 from mlagents.trainers.torch.components.bc.module import BCModule
 from mlagents.trainers.torch.components.reward_providers import create_reward_provider
@@ -26,6 +27,7 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
         self.global_step = torch.tensor(0)
         self.bc_module: Optional[BCModule] = None
         self.create_reward_signals(trainer_settings.reward_signals)
+        self.critic_memory_dict: Dict[str, torch.Tensor] = {}
         if trainer_settings.behavioral_cloning is not None:
             self.bc_module = BCModule(
                 self.policy,
@@ -35,6 +37,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
                 default_num_epoch=3,
             )

+    @property
+    def critic(self):
+        raise NotImplementedError
+
     def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
         pass
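The abstract `critic` property introduced in the hunk above is the hook each concrete optimizer overrides now that the value network belongs to the optimizer rather than the policy. As a rough sketch of that pattern (the class names below are hypothetical, not taken from this PR or the ML-Agents API), a subclass could expose its value network like this:

```python
# Hypothetical sketch of the abstract-critic pattern above; TinyCritic and
# ExampleOptimizer are stand-ins, not ML-Agents classes.
import torch


class BaseOptimizer:
    @property
    def critic(self) -> torch.nn.Module:
        # Mirrors the abstract property in the diff: subclasses must override it.
        raise NotImplementedError


class TinyCritic(torch.nn.Module):
    def __init__(self, obs_size: int):
        super().__init__()
        self.value_head = torch.nn.Linear(obs_size, 1)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.value_head(obs)


class ExampleOptimizer(BaseOptimizer):
    def __init__(self, obs_size: int):
        self._critic = TinyCritic(obs_size)

    @property
    def critic(self) -> torch.nn.Module:
        # A concrete optimizer returns the value network it owns.
        return self._critic


opt = ExampleOptimizer(obs_size=4)
print(opt.critic(torch.zeros(1, 4)).shape)  # torch.Size([1, 1])
```

Keeping the critic behind a property lets the shared base-class code (the value estimation and memory handling below) call `self.critic` without knowing which network a particular trainer builds.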
@@ -49,25 +55,132 @@ def create_reward_signals(self, reward_signal_configs):
                 reward_signal, self.policy.behavior_spec, settings
             )

+    def _evaluate_by_sequence(
+        self, tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
+    ) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
+        """
+        Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
+        intermediate memories for the critic.

[Review comment] Describe the inputs and their dimensions; it is not clear what initial_memory's shape is.
[Author reply] Added a comment.

+        :param tensor_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's
+            observations for this trajectory.
+        :param initial_memory: The memory that precedes this trajectory. Of shape (1, 1, <mem_size>), i.e.
+            what is returned as the output of a MemoryModule.
+        :return: A Tuple of the value estimates as a Dict of [name, tensor], an AgentBufferField of the initial
+            memories to be used during value function update, and the final memory at the end of the trajectory.
+        """
+        num_experiences = tensor_obs[0].shape[0]
+        all_next_memories = AgentBufferField()
+        # In the buffer, the 1st sequence is the one that is padded. So if seq_len = 3 and the
+        # trajectory is of length 10, the 1st sequence is [pad, pad, obs].
+        # Compute the number of elements in this padded sequence.
+        leftover = num_experiences % self.policy.sequence_length
+
+        # Compute values for the potentially truncated initial sequence
+        seq_obs = []
+
+        first_seq_len = self.policy.sequence_length

[Review comment] Any reason for padding the end of the first sequence and not the last sequence?
[Author reply] Neither is padded in this function, but the buffer pads the front of the first sequence, so this function accounts for that. I left the padding in the buffer as-is and still need to experiment with that change; it will likely be a future PR.
[Review comment] Then why is the first sequence handled differently?
A standalone worked example of this sequence split, and of the per-step memories it records, appears after this diff hunk.

+        for _obs in tensor_obs:
+            if leftover > 0:
+                first_seq_len = leftover
+            first_seq_obs = _obs[0:first_seq_len]
+            seq_obs.append(first_seq_obs)
+
+        # For the first sequence, the initial memory should be the one at the
+        # beginning of this trajectory.
+        for _ in range(first_seq_len):
+            all_next_memories.append(initial_memory.squeeze().detach().numpy())

[Suggested change]
[Review comment] Hard to understand why the all_next_memories are a concatenation of initial memories...
[Author reply] Refactored this to be a bit clearer.

+
+        init_values, _mem = self.critic.critic_pass(
+            seq_obs, initial_memory, sequence_length=first_seq_len
+        )
+        all_values = {
+            signal_name: [init_values[signal_name]]
+            for signal_name in init_values.keys()
+        }
+
+        # Evaluate other trajectories, carrying over _mem after each
+        # trajectory
+        for seq_num in range(
+            1, math.ceil((num_experiences) / (self.policy.sequence_length))
+        ):
+            seq_obs = []
+            for _ in range(self.policy.sequence_length):
+                all_next_memories.append(_mem.squeeze().detach().numpy())
+            for _obs in tensor_obs:
+                start = seq_num * self.policy.sequence_length - (
+                    self.policy.sequence_length - leftover
+                )
+                end = (seq_num + 1) * self.policy.sequence_length - (
+                    self.policy.sequence_length - leftover
+                )
+                seq_obs.append(_obs[start:end])
+            values, _mem = self.critic.critic_pass(
+                seq_obs, _mem, sequence_length=self.policy.sequence_length
+            )
+            for signal_name, _val in values.items():
+                all_values[signal_name].append(_val)
+        # Create one tensor per reward signal
+        all_value_tensors = {
+            signal_name: torch.cat(value_list, dim=0)
+            for signal_name, value_list in all_values.items()
+        }
+        next_mem = _mem
+        return all_value_tensors, all_next_memories, next_mem
+
     def get_trajectory_value_estimates(
-        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
-    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
+        self,
+        batch: AgentBuffer,
+        next_obs: List[np.ndarray],
+        done: bool,
+        agent_id: str = "",
+    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:

[Review comment] docstring
[Author reply] Added.

+        """
+        Get value estimates and memories for a trajectory, in batch form.
+        :param batch: An AgentBuffer that consists of a trajectory.
+        :param next_obs: the next observation (after the trajectory). Used for bootstrapping
+            if this is not a terminal trajectory.
+        :param done: Set true if this is a terminal trajectory.
+        :param agent_id: Agent ID of the agent that this trajectory belongs to.
+        :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
+            the final value estimate as a Dict of [name, float], and optionally (if using memories)
+            an AgentBufferField of initial critic memories to be used during update.
+        """
         n_obs = len(self.policy.behavior_spec.observation_specs)
-        current_obs = ObsUtil.from_buffer(batch, n_obs)

+        if agent_id in self.critic_memory_dict:
+            memory = self.critic_memory_dict[agent_id]
+        else:
+            memory = (
+                torch.zeros((1, 1, self.critic.memory_size))
+                if self.policy.use_recurrent
+                else None
+            )
+
         # Convert to tensors
-        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
+        current_obs = [
+            ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
+        ]
         next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

-        memory = torch.zeros([1, 1, self.policy.m_size])
-
         next_obs = [obs.unsqueeze(0) for obs in next_obs]

-        value_estimates, next_memory = self.policy.actor_critic.critic_pass(
-            current_obs, memory, sequence_length=batch.num_experiences
-        )
+        # If we're using LSTM, we want to get all the intermediate memories.
+        all_next_memories: Optional[AgentBufferField] = None
+        if self.policy.use_recurrent:
+            (
+                value_estimates,
+                all_next_memories,
+                next_memory,
+            ) = self._evaluate_by_sequence(current_obs, memory)
+        else:
+            value_estimates, next_memory = self.critic.critic_pass(
+                current_obs, memory, sequence_length=batch.num_experiences
+            )

-        next_value_estimate, _ = self.policy.actor_critic.critic_pass(
+        # Store the memory for the next trajectory
+        self.critic_memory_dict[agent_id] = next_memory
+
+        next_value_estimate, _ = self.critic.critic_pass(
             next_obs, next_memory, sequence_length=1
         )
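To make the bookkeeping in `_evaluate_by_sequence` concrete (the shorter leftover sequence at the front of the trajectory and the one-memory-per-timestep layout of `all_next_memories` discussed in the review thread above), here is a standalone sketch of the same indexing. This is not ML-Agents code: the toy critic, the tensor shapes, and the plain Python list standing in for `AgentBufferField` are assumptions chosen to mirror the diff.

```python
# Standalone sketch of the sequence splitting in _evaluate_by_sequence above,
# using a toy "critic" instead of the ML-Agents networks.
import math
import torch

sequence_length = 3
obs = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # toy trajectory of length 10
num_experiences = obs.shape[0]


def toy_critic(seq, mem):
    # Stand-in for critic.critic_pass: per-step values plus an updated memory.
    return seq.sum(dim=1), mem + 1.0


leftover = num_experiences % sequence_length
first_seq_len = leftover if leftover > 0 else sequence_length

memory = torch.zeros(1)
all_next_memories = []  # one entry per timestep, like the AgentBufferField above
all_values = []

# First (possibly shorter) sequence: every step in it shares the initial memory.
all_next_memories.extend([memory.clone()] * first_seq_len)
values, memory = toy_critic(obs[0:first_seq_len], memory)
all_values.append(values)

# Remaining full-length sequences, carrying the memory forward between them.
for seq_num in range(1, math.ceil(num_experiences / sequence_length)):
    all_next_memories.extend([memory.clone()] * sequence_length)
    start = seq_num * sequence_length - (sequence_length - leftover)
    end = (seq_num + 1) * sequence_length - (sequence_length - leftover)
    values, memory = toy_critic(obs[start:end], memory)
    all_values.append(values)

print(torch.cat(all_values).shape)  # torch.Size([10]): one value per experience
print(len(all_next_memories))       # 10: one stored memory per experience
```

With a trajectory of length 10 and `sequence_length = 3`, the first sequence covers only the single leftover experience, the remaining sequences cover indices [1:4], [4:7], and [7:10], and each experience records the memory at the start of its own sequence. That is why `all_next_memories` grows to exactly the trajectory length and stays aligned with the flattened buffer during the update.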
@@ -79,5 +192,6 @@ def get_trajectory_value_estimates(
             for k in next_value_estimate:
                 if not self.reward_signals[k].ignore_done:
                     next_value_estimate[k] = 0.0
-
-        return value_estimates, next_value_estimate
+            if agent_id in self.critic_memory_dict:
+                self.critic_memory_dict.pop(agent_id)
+        return value_estimates, next_value_estimate, all_next_memories
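The per-agent `critic_memory_dict` logic spread across this diff (seed the initial memory from the cache, store `next_memory` after evaluating the trajectory, and drop the entry once the episode is done) follows the pattern sketched below. This is an illustrative standalone class, not the ML-Agents API; the names `MemoryCache`, `initial_memory`, and `store` are invented for the example.

```python
# Minimal sketch of the per-agent critic-memory carry-over introduced above:
# cache the critic's final memory per agent, reuse it as the initial memory of
# that agent's next trajectory, and clear it when the episode terminates.
from typing import Dict
import torch


class MemoryCache:
    def __init__(self, memory_size: int):
        self.memory_size = memory_size
        self._memories: Dict[str, torch.Tensor] = {}

    def initial_memory(self, agent_id: str) -> torch.Tensor:
        # Resume from the cached memory if this agent was seen before,
        # otherwise start the trajectory from a zero memory.
        return self._memories.get(agent_id, torch.zeros((1, 1, self.memory_size)))

    def store(self, agent_id: str, next_memory: torch.Tensor, done: bool) -> None:
        # Keep the final memory for the next trajectory; discard it when the
        # episode ends so a fresh episode starts from zeros again.
        self._memories[agent_id] = next_memory
        if done:
            self._memories.pop(agent_id, None)


cache = MemoryCache(memory_size=8)
mem = cache.initial_memory("agent-0")          # zeros on first use
cache.store("agent-0", mem + 1.0, done=False)  # carried over to the next trajectory
print(cache.initial_memory("agent-0").mean())  # tensor(1.)
cache.store("agent-0", mem + 2.0, done=True)   # episode ended: cache cleared
print(cache.initial_memory("agent-0").mean())  # tensor(0.)
```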