diff --git a/ml-agents/mlagents/trainers/action_info.py b/ml-agents/mlagents/trainers/action_info.py index 54889b57c0..c0ec023271 100644 --- a/ml-agents/mlagents/trainers/action_info.py +++ b/ml-agents/mlagents/trainers/action_info.py @@ -6,12 +6,20 @@ class ActionInfo(NamedTuple): + """ + A NamedTuple containing actions and related quantities to the policy forward + pass. Additionally contains the agent ids in the corresponding DecisionStep + :param action: The action output of the policy + :param env_action: The possibly clipped action to be executed in the environment + :param outputs: Dict of all quantities associated with the policy forward pass + :param agent_ids: List of int agent ids in DecisionStep + """ + action: Any env_action: Any - value: Any outputs: ActionInfoOutputs agent_ids: List[AgentId] @staticmethod def empty() -> "ActionInfo": - return ActionInfo([], [], [], {}, []) + return ActionInfo([], [], {}, []) diff --git a/ml-agents/mlagents/trainers/agent_processor.py b/ml-agents/mlagents/trainers/agent_processor.py index e74e4a35a5..fb77df0795 100644 --- a/ml-agents/mlagents/trainers/agent_processor.py +++ b/ml-agents/mlagents/trainers/agent_processor.py @@ -126,7 +126,7 @@ def _process_step( if stored_decision_step is not None and stored_take_action_outputs is not None: obs = stored_decision_step.obs if self.policy.use_recurrent: - memory = self.policy.retrieve_memories([global_id])[0, :] + memory = self.policy.retrieve_previous_memories([global_id])[0, :] else: memory = None done = terminated # Since this is an ongoing step diff --git a/ml-agents/mlagents/trainers/buffer.py b/ml-agents/mlagents/trainers/buffer.py index dd1d54c79d..302fe418a0 100644 --- a/ml-agents/mlagents/trainers/buffer.py +++ b/ml-agents/mlagents/trainers/buffer.py @@ -28,6 +28,7 @@ class BufferKey(enum.Enum): ENVIRONMENT_REWARDS = "environment_rewards" MASKS = "masks" MEMORY = "memory" + CRITIC_MEMORY = "critic_memory" PREV_ACTION = "prev_action" ADVANTAGES = "advantages" diff --git a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py index 086bf9a1a2..56130367ea 100644 --- a/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py +++ b/ml-agents/mlagents/trainers/optimizer/torch_optimizer.py @@ -1,8 +1,9 @@ from typing import Dict, Optional, Tuple, List from mlagents.torch_utils import torch import numpy as np +import math -from mlagents.trainers.buffer import AgentBuffer +from mlagents.trainers.buffer import AgentBuffer, AgentBufferField from mlagents.trainers.trajectory import ObsUtil from mlagents.trainers.torch.components.bc.module import BCModule from mlagents.trainers.torch.components.reward_providers import create_reward_provider @@ -26,6 +27,7 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): self.global_step = torch.tensor(0) self.bc_module: Optional[BCModule] = None self.create_reward_signals(trainer_settings.reward_signals) + self.critic_memory_dict: Dict[str, torch.Tensor] = {} if trainer_settings.behavioral_cloning is not None: self.bc_module = BCModule( self.policy, @@ -35,6 +37,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): default_num_epoch=3, ) + @property + def critic(self): + raise NotImplementedError + def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: pass @@ -49,25 +55,132 @@ def create_reward_signals(self, reward_signal_configs): reward_signal, self.policy.behavior_spec, settings ) + def _evaluate_by_sequence( + self, 
tensor_obs: List[torch.Tensor], initial_memory: np.ndarray
+    ) -> Tuple[Dict[str, torch.Tensor], AgentBufferField, torch.Tensor]:
+        """
+        Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
+        intermediate memories for the critic.
+        :param tensor_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's
+            observations for this trajectory.
+        :param initial_memory: The memory that precedes this trajectory. Of shape (1,1,<mem_size>), i.e.
+            what is returned as the output of a MemoryModule.
+        :return: A Tuple of the value estimates as a Dict of [name, tensor], an AgentBufferField of the initial
+            memories to be used during value function update, and the final memory at the end of the trajectory.
+        """
+        num_experiences = tensor_obs[0].shape[0]
+        all_next_memories = AgentBufferField()
+        # In the buffer, the 1st sequence is the one that is padded. So if seq_len = 3 and
+        # trajectory is of length 10, the 1st sequence is [pad,pad,obs].
+        # Compute the number of elements in this padded seq.
+        leftover = num_experiences % self.policy.sequence_length
+
+        # Compute values for the potentially truncated initial sequence
+        seq_obs = []
+
+        first_seq_len = self.policy.sequence_length
+        for _obs in tensor_obs:
+            if leftover > 0:
+                first_seq_len = leftover
+            first_seq_obs = _obs[0:first_seq_len]
+            seq_obs.append(first_seq_obs)
+
+        # For the first sequence, the initial memory should be the one at the
+        # beginning of this trajectory.
+        for _ in range(first_seq_len):
+            all_next_memories.append(initial_memory.squeeze().detach().numpy())
+
+        init_values, _mem = self.critic.critic_pass(
+            seq_obs, initial_memory, sequence_length=first_seq_len
+        )
+        all_values = {
+            signal_name: [init_values[signal_name]]
+            for signal_name in init_values.keys()
+        }
+
+        # Evaluate the remaining sequences, carrying _mem over from one
+        # sequence to the next
+        for seq_num in range(
+            1, math.ceil((num_experiences) / (self.policy.sequence_length))
+        ):
+            seq_obs = []
+            for _ in range(self.policy.sequence_length):
+                all_next_memories.append(_mem.squeeze().detach().numpy())
+            for _obs in tensor_obs:
+                start = seq_num * self.policy.sequence_length - (
+                    self.policy.sequence_length - leftover
+                )
+                end = (seq_num + 1) * self.policy.sequence_length - (
+                    self.policy.sequence_length - leftover
+                )
+                seq_obs.append(_obs[start:end])
+            values, _mem = self.critic.critic_pass(
+                seq_obs, _mem, sequence_length=self.policy.sequence_length
+            )
+            for signal_name, _val in values.items():
+                all_values[signal_name].append(_val)
+        # Create one tensor per reward signal
+        all_value_tensors = {
+            signal_name: torch.cat(value_list, dim=0)
+            for signal_name, value_list in all_values.items()
+        }
+        next_mem = _mem
+        return all_value_tensors, all_next_memories, next_mem
+
     def get_trajectory_value_estimates(
-        self, batch: AgentBuffer, next_obs: List[np.ndarray], done: bool
-    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float]]:
+        self,
+        batch: AgentBuffer,
+        next_obs: List[np.ndarray],
+        done: bool,
+        agent_id: str = "",
+    ) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
+        """
+        Get value estimates and memories for a trajectory, in batch form.
+        :param batch: An AgentBuffer that consists of a trajectory.
+        :param next_obs: the next observation (after the trajectory). Used for bootstrapping
+            if this is not a terminal trajectory.
+        :param done: Set true if this is a terminal trajectory.
+        :param agent_id: Agent ID of the agent that this trajectory belongs to.
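
The sequence carving performed by _evaluate_by_sequence is easiest to see in isolation. The following is a minimal, self-contained sketch (illustrative names, not the trainer code) of how a trajectory is split so that the short leftover chunk forms the first, padded sequence and every step gets exactly one recorded critic memory:

from typing import List, Tuple


def split_trajectory(num_experiences: int, sequence_length: int) -> List[Tuple[int, int]]:
    # The leftover (num_experiences % sequence_length) forms the short first
    # sequence, matching the comment that the 1st sequence is the padded one.
    leftover = num_experiences % sequence_length
    first_seq_len = leftover if leftover > 0 else sequence_length
    ranges = [(0, first_seq_len)]
    start = first_seq_len
    while start < num_experiences:
        ranges.append((start, start + sequence_length))
        start += sequence_length
    return ranges


# A 10-step trajectory with sequence_length 3: one leftover step first, then three
# full sequences. One critic memory is recorded per step, so the CRITIC_MEMORY field
# written back to the buffer stays aligned 1:1 with the trajectory.
print(split_trajectory(10, 3))  # [(0, 1), (1, 4), (4, 7), (7, 10)]
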
+ :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)], + the final value estimate as a Dict of [name, float], and optionally (if using memories) + an AgentBufferField of initial critic memories to be used during update. + """ n_obs = len(self.policy.behavior_spec.observation_specs) - current_obs = ObsUtil.from_buffer(batch, n_obs) + + if agent_id in self.critic_memory_dict: + memory = self.critic_memory_dict[agent_id] + else: + memory = ( + torch.zeros((1, 1, self.critic.memory_size)) + if self.policy.use_recurrent + else None + ) # Convert to tensors - current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs] + current_obs = [ + ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs) + ] next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs] - memory = torch.zeros([1, 1, self.policy.m_size]) - next_obs = [obs.unsqueeze(0) for obs in next_obs] - value_estimates, next_memory = self.policy.actor_critic.critic_pass( - current_obs, memory, sequence_length=batch.num_experiences - ) + # If we're using LSTM, we want to get all the intermediate memories. + all_next_memories: Optional[AgentBufferField] = None + if self.policy.use_recurrent: + ( + value_estimates, + all_next_memories, + next_memory, + ) = self._evaluate_by_sequence(current_obs, memory) + else: + value_estimates, next_memory = self.critic.critic_pass( + current_obs, memory, sequence_length=batch.num_experiences + ) - next_value_estimate, _ = self.policy.actor_critic.critic_pass( + # Store the memory for the next trajectory + self.critic_memory_dict[agent_id] = next_memory + + next_value_estimate, _ = self.critic.critic_pass( next_obs, next_memory, sequence_length=1 ) @@ -79,5 +192,6 @@ def get_trajectory_value_estimates( for k in next_value_estimate: if not self.reward_signals[k].ignore_done: next_value_estimate[k] = 0.0 - - return value_estimates, next_value_estimate + if agent_id in self.critic_memory_dict: + self.critic_memory_dict.pop(agent_id) + return value_estimates, next_value_estimate, all_next_memories diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py index 79ae4ac9d3..068c989d85 100644 --- a/ml-agents/mlagents/trainers/policy/policy.py +++ b/ml-agents/mlagents/trainers/policy/policy.py @@ -33,6 +33,7 @@ def __init__( self.network_settings: NetworkSettings = trainer_settings.network_settings self.seed = seed self.previous_action_dict: Dict[str, np.ndarray] = {} + self.previous_memory_dict: Dict[str, np.ndarray] = {} self.memory_dict: Dict[str, np.ndarray] = {} self.normalize = trainer_settings.network_settings.normalize self.use_recurrent = self.network_settings.memory is not None @@ -72,6 +73,11 @@ def save_memories( if memory_matrix is None: return + # Pass old memories into previous_memory_dict + for agent_id in agent_ids: + if agent_id in self.memory_dict: + self.previous_memory_dict[agent_id] = self.memory_dict[agent_id] + for index, agent_id in enumerate(agent_ids): self.memory_dict[agent_id] = memory_matrix[index, :] @@ -82,10 +88,19 @@ def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray: memory_matrix[index, :] = self.memory_dict[agent_id] return memory_matrix + def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray: + memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32) + for index, agent_id in enumerate(agent_ids): + if agent_id in self.previous_memory_dict: + memory_matrix[index, :] = self.previous_memory_dict[agent_id] + return 
memory_matrix + def remove_memories(self, agent_ids): for agent_id in agent_ids: if agent_id in self.memory_dict: self.memory_dict.pop(agent_id) + if agent_id in self.previous_memory_dict: + self.previous_memory_dict.pop(agent_id) def make_empty_previous_action(self, num_agents: int) -> np.ndarray: """ diff --git a/ml-agents/mlagents/trainers/policy/torch_policy.py b/ml-agents/mlagents/trainers/policy/torch_policy.py index c77bdf9c89..a4fc1e1dd4 100644 --- a/ml-agents/mlagents/trainers/policy/torch_policy.py +++ b/ml-agents/mlagents/trainers/policy/torch_policy.py @@ -10,11 +10,7 @@ from mlagents_envs.timers import timed from mlagents.trainers.settings import TrainerSettings -from mlagents.trainers.torch.networks import ( - SharedActorCritic, - SeparateActorCritic, - GlobalSteps, -) +from mlagents.trainers.torch.networks import SimpleActor, SharedActorCritic, GlobalSteps from mlagents.trainers.torch.utils import ModelUtils from mlagents.trainers.buffer import AgentBuffer @@ -61,31 +57,40 @@ def __init__( ) # could be much simpler if TorchPolicy is nn.Module self.grads = None - reward_signal_configs = trainer_settings.reward_signals - reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] - self.stats_name_to_update_name = { "Losses/Value Loss": "value_loss", "Losses/Policy Loss": "policy_loss", } if separate_critic: - ac_class = SeparateActorCritic + self.actor = SimpleActor( + observation_specs=self.behavior_spec.observation_specs, + network_settings=trainer_settings.network_settings, + action_spec=behavior_spec.action_spec, + conditional_sigma=self.condition_sigma_on_obs, + tanh_squash=tanh_squash, + ) + self.shared_critic = False else: - ac_class = SharedActorCritic - self.actor_critic = ac_class( - observation_specs=self.behavior_spec.observation_specs, - network_settings=trainer_settings.network_settings, - action_spec=behavior_spec.action_spec, - stream_names=reward_signal_names, - conditional_sigma=self.condition_sigma_on_obs, - tanh_squash=tanh_squash, - ) + reward_signal_configs = trainer_settings.reward_signals + reward_signal_names = [ + key.value for key, _ in reward_signal_configs.items() + ] + self.actor = SharedActorCritic( + observation_specs=self.behavior_spec.observation_specs, + network_settings=trainer_settings.network_settings, + action_spec=behavior_spec.action_spec, + stream_names=reward_signal_names, + conditional_sigma=self.condition_sigma_on_obs, + tanh_squash=tanh_squash, + ) + self.shared_critic = True + # Save the m_size needed for export self._export_m_size = self.m_size # m_size needed for training is determined by network, not trainer settings - self.m_size = self.actor_critic.memory_size + self.m_size = self.actor.memory_size - self.actor_critic.to(default_device()) + self.actor.to(default_device()) self._clip_action = not tanh_squash @property @@ -115,7 +120,7 @@ def update_normalization(self, buffer: AgentBuffer) -> None: """ if self.normalize: - self.actor_critic.update_normalization(buffer) + self.actor.update_normalization(buffer) @timed def sample_actions( @@ -132,7 +137,7 @@ def sample_actions( :param seq_len: Sequence length when using RNN. :return: Tuple of AgentAction, ActionLogProbs, entropies, and output memories. 
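
Since save_memories stores the memory returned after a forward pass, the experience written for a step must instead carry the memory that went into that pass, which is what previous_memory_dict and retrieve_previous_memories provide to the agent processor. A minimal sketch of that bookkeeping, using an illustrative class rather than the mlagents Policy:

from typing import Dict, List

import numpy as np


class MemoryBook:
    # Illustrative stand-in for the memory bookkeeping added to Policy.

    def __init__(self, m_size: int):
        self.m_size = m_size
        self.memory_dict: Dict[str, np.ndarray] = {}
        self.previous_memory_dict: Dict[str, np.ndarray] = {}

    def save_memories(self, agent_ids: List[str], memory_matrix: np.ndarray) -> None:
        for index, agent_id in enumerate(agent_ids):
            # Shift the memory that produced this step's action into the
            # "previous" slot before overwriting it with the new output.
            if agent_id in self.memory_dict:
                self.previous_memory_dict[agent_id] = self.memory_dict[agent_id]
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve_previous_memories(self, agent_ids: List[str]) -> np.ndarray:
        memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.previous_memory_dict:
                memory_matrix[index, :] = self.previous_memory_dict[agent_id]
        return memory_matrix


book = MemoryBook(m_size=2)
book.save_memories(["agent-0"], np.array([[1.0, 1.0]], dtype=np.float32))  # step 1 output
book.save_memories(["agent-0"], np.array([[2.0, 2.0]], dtype=np.float32))  # step 2 output
# The experience stored for step 2 needs the memory that was fed into step 2,
# i.e. the memory produced at step 1.
print(book.retrieve_previous_memories(["agent-0"]))  # [[1. 1.]]
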
""" - actions, log_probs, entropies, memories = self.actor_critic.get_action_stats( + actions, log_probs, entropies, memories = self.actor.get_action_and_stats( obs, masks, memories, seq_len ) return (actions, log_probs, entropies, memories) @@ -144,11 +149,11 @@ def evaluate_actions( masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, seq_len: int = 1, - ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]: - log_probs, entropies, value_heads = self.actor_critic.get_stats_and_value( + ) -> Tuple[ActionLogProbs, torch.Tensor]: + log_probs, entropies = self.actor.get_stats( obs, actions, masks, memories, seq_len ) - return log_probs, entropies, value_heads + return log_probs, entropies @timed def evaluate( @@ -210,7 +215,6 @@ def get_action( return ActionInfo( action=run_out.get("action"), env_action=run_out.get("env_action"), - value=run_out.get("value"), outputs=run_out, agent_ids=list(decision_requests.agent_id), ) @@ -239,13 +243,13 @@ def increment_step(self, n_steps): return self.get_current_step() def load_weights(self, values: List[np.ndarray]) -> None: - self.actor_critic.load_state_dict(values) + self.actor.load_state_dict(values) def init_load_weights(self) -> None: pass def get_weights(self) -> List[np.ndarray]: - return copy.deepcopy(self.actor_critic.state_dict()) + return copy.deepcopy(self.actor.state_dict()) def get_modules(self): - return {"Policy": self.actor_critic, "global_step": self.global_step} + return {"Policy": self.actor, "global_step": self.global_step} diff --git a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py index 2e97c37302..4dcf3db324 100644 --- a/ml-agents/mlagents/trainers/ppo/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/ppo/optimizer_torch.py @@ -1,5 +1,5 @@ from typing import Dict, cast -from mlagents.torch_utils import torch +from mlagents.torch_utils import torch, default_device from mlagents.trainers.buffer import AgentBuffer, BufferKey, RewardSignalUtil @@ -7,6 +7,7 @@ from mlagents.trainers.policy.torch_policy import TorchPolicy from mlagents.trainers.optimizer.torch_optimizer import TorchOptimizer from mlagents.trainers.settings import TrainerSettings, PPOSettings +from mlagents.trainers.torch.networks import ValueNetwork from mlagents.trainers.torch.agent_action import AgentAction from mlagents.trainers.torch.action_log_probs import ActionLogProbs from mlagents.trainers.torch.utils import ModelUtils @@ -25,7 +26,20 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): # Create the graph here to give more granular control of the TF graph to the Optimizer. 
super().__init__(policy, trainer_settings) - params = list(self.policy.actor_critic.parameters()) + reward_signal_configs = trainer_settings.reward_signals + reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] + + if policy.shared_critic: + self._critic = policy.actor + else: + self._critic = ValueNetwork( + reward_signal_names, + policy.behavior_spec.observation_specs, + network_settings=trainer_settings.network_settings, + ) + self._critic.to(default_device()) + + params = list(self.policy.actor.parameters()) + list(self._critic.parameters()) self.hyperparameters: PPOSettings = cast( PPOSettings, trainer_settings.hyperparameters ) @@ -58,6 +72,10 @@ def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings): self.stream_names = list(self.reward_signals.keys()) + @property + def critic(self): + return self._critic + def ppo_value_loss( self, values: Dict[str, torch.Tensor], @@ -152,13 +170,28 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: if len(memories) > 0: memories = torch.stack(memories).unsqueeze(0) - log_probs, entropy, values = self.policy.evaluate_actions( + # Get value memories + value_memories = [ + ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i]) + for i in range( + 0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length + ) + ] + if len(value_memories) > 0: + value_memories = torch.stack(value_memories).unsqueeze(0) + + log_probs, entropy = self.policy.evaluate_actions( current_obs, masks=act_masks, actions=actions, memories=memories, seq_len=self.policy.sequence_length, ) + values, _ = self.critic.critic_pass( + current_obs, + memories=value_memories, + sequence_length=self.policy.sequence_length, + ) old_log_probs = ActionLogProbs.from_buffer(batch).flatten() log_probs = log_probs.flatten() loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool) diff --git a/ml-agents/mlagents/trainers/ppo/trainer.py b/ml-agents/mlagents/trainers/ppo/trainer.py index 03ef41d805..5025adb500 100644 --- a/ml-agents/mlagents/trainers/ppo/trainer.py +++ b/ml-agents/mlagents/trainers/ppo/trainer.py @@ -73,11 +73,13 @@ def _process_trajectory(self, trajectory: Trajectory) -> None: self.policy.update_normalization(agent_buffer_trajectory) # Get all value estimates - value_estimates, value_next = self.optimizer.get_trajectory_value_estimates( + value_estimates, value_next, value_memories = self.optimizer.get_trajectory_value_estimates( agent_buffer_trajectory, trajectory.next_obs, trajectory.done_reached and not trajectory.interrupted, ) + if value_memories is not None: + agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories) for name, v in value_estimates.items(): agent_buffer_trajectory[RewardSignalUtil.value_estimates_key(name)].extend( diff --git a/ml-agents/mlagents/trainers/sac/optimizer_torch.py b/ml-agents/mlagents/trainers/sac/optimizer_torch.py index d900e39c17..fe3a8ddc58 100644 --- a/ml-agents/mlagents/trainers/sac/optimizer_torch.py +++ b/ml-agents/mlagents/trainers/sac/optimizer_torch.py @@ -107,6 +107,16 @@ def __init__(self, discrete, continuous): def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): super().__init__(policy, trainer_params) + reward_signal_configs = trainer_params.reward_signals + reward_signal_names = [key.value for key, _ in reward_signal_configs.items()] + if policy.shared_critic: + raise UnityTrainerException("SAC does not support SharedActorCritic") + self._critic = ValueNetwork( + reward_signal_names, + 
policy.behavior_spec.observation_specs, + policy.network_settings, + ) + hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) self.tau = hyperparameters.tau self.init_entcoef = hyperparameters.init_entcoef @@ -130,7 +140,7 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): } self._action_spec = self.policy.behavior_spec.action_spec - self.value_network = TorchSACOptimizer.PolicyValueNetwork( + self.q_network = TorchSACOptimizer.PolicyValueNetwork( self.stream_names, self.policy.behavior_spec.observation_specs, policy_network_settings, @@ -142,9 +152,7 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): self.policy.behavior_spec.observation_specs, policy_network_settings, ) - ModelUtils.soft_update( - self.policy.actor_critic.critic, self.target_network, 1.0 - ) + ModelUtils.soft_update(self._critic, self.target_network, 1.0) # We create one entropy coefficient per action, whether discrete or continuous. _disc_log_ent_coef = torch.nn.Parameter( @@ -173,11 +181,9 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): self.target_entropy = TorchSACOptimizer.TargetEntropy( continuous=_cont_target, discrete=_disc_target ) - policy_params = list(self.policy.actor_critic.network_body.parameters()) + list( - self.policy.actor_critic.action_model.parameters() - ) - value_params = list(self.value_network.parameters()) + list( - self.policy.actor_critic.critic.parameters() + policy_params = list(self.policy.actor.parameters()) + value_params = list(self.q_network.parameters()) + list( + self._critic.parameters() ) logger.debug("value_vars") @@ -204,10 +210,15 @@ def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): ) self._move_to_device(default_device()) + @property + def critic(self): + return self._critic + def _move_to_device(self, device: torch.device) -> None: self._log_ent_coef.to(device) self.target_network.to(device) - self.value_network.to(device) + self._critic.to(device) + self.q_network.to(device) def sac_q_loss( self, @@ -480,60 +491,69 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length) ] # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true. + value_memories_list = [ + ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i]) + for i in range( + 0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length + ) + ] offset = 1 if self.policy.sequence_length > 1 else 0 - next_memories_list = [ + next_value_memories_list = [ ModelUtils.list_to_tensor( - batch[BufferKey.MEMORY][i][self.policy.m_size // 2 :] + batch[BufferKey.CRITIC_MEMORY][i] ) # only pass value part of memory to target network for i in range( - offset, len(batch[BufferKey.MEMORY]), self.policy.sequence_length + offset, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length ) ] if len(memories_list) > 0: memories = torch.stack(memories_list).unsqueeze(0) - next_memories = torch.stack(next_memories_list).unsqueeze(0) + value_memories = torch.stack(value_memories_list).unsqueeze(0) + next_value_memories = torch.stack(next_value_memories_list).unsqueeze(0) else: memories = None - next_memories = None - # Q network memories are 0'ed out, since we don't have them during inference. + value_memories = None + next_value_memories = None + + # Q and V network memories are 0'ed out, since we don't have them during inference. 
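
The slicing above pulls one stored memory per training sequence out of the flat CRITIC_MEMORY field: stride sequence_length starting at index 0 for the current sequences, and the same stride starting at index 1 for the sequences of next observations, which sit one step ahead (the offset falls back to 0 when sequences are a single step). A small NumPy illustration using step indices as stand-in memories:

import numpy as np

sequence_length = 3
num_experiences = 9
# Pretend each row is the critic memory recorded before that step; rows are
# tagged with their step index so the slicing is easy to read.
critic_memory = np.arange(num_experiences, dtype=np.float32).reshape(-1, 1)

# Initial memory of each training sequence: steps 0, 3, 6.
value_memory_rows = [
    critic_memory[i] for i in range(0, num_experiences, sequence_length)
]

# Initial memory for the matching sequences of next observations, one step ahead.
offset = 1 if sequence_length > 1 else 0
next_value_memory_rows = [
    critic_memory[i] for i in range(offset, num_experiences, sequence_length)
]

print([r.item() for r in value_memory_rows])       # [0.0, 3.0, 6.0]
print([r.item() for r in next_value_memory_rows])  # [1.0, 4.0, 7.0]
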
q_memories = ( - torch.zeros_like(next_memories) if next_memories is not None else None + torch.zeros_like(next_value_memories) + if next_value_memories is not None + else None ) # Copy normalizers from policy - self.value_network.q1_network.network_body.copy_normalization( - self.policy.actor_critic.network_body + self.q_network.q1_network.network_body.copy_normalization( + self.policy.actor.network_body ) - self.value_network.q2_network.network_body.copy_normalization( - self.policy.actor_critic.network_body + self.q_network.q2_network.network_body.copy_normalization( + self.policy.actor.network_body ) self.target_network.network_body.copy_normalization( - self.policy.actor_critic.network_body + self.policy.actor.network_body ) - ( - sampled_actions, - log_probs, - _, - value_estimates, - _, - ) = self.policy.actor_critic.get_action_stats_and_value( + self._critic.network_body.copy_normalization(self.policy.actor.network_body) + sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats( current_obs, masks=act_masks, memories=memories, sequence_length=self.policy.sequence_length, ) + value_estimates, _ = self._critic.critic_pass( + current_obs, value_memories, sequence_length=self.policy.sequence_length + ) cont_sampled_actions = sampled_actions.continuous_tensor cont_actions = actions.continuous_tensor - q1p_out, q2p_out = self.value_network( + q1p_out, q2p_out = self.q_network( current_obs, cont_sampled_actions, memories=q_memories, sequence_length=self.policy.sequence_length, q2_grad=False, ) - q1_out, q2_out = self.value_network( + q1_out, q2_out = self.q_network( current_obs, cont_actions, memories=q_memories, @@ -550,7 +570,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: with torch.no_grad(): target_values, _ = self.target_network( next_obs, - memories=next_memories, + memories=next_value_memories, sequence_length=self.policy.sequence_length, ) masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool) @@ -565,7 +585,11 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks) entropy_loss = self.sac_entropy_loss(log_probs, masks) - total_value_loss = q1_loss + q2_loss + value_loss + total_value_loss = q1_loss + q2_loss + if self.policy.shared_critic: + policy_loss += value_loss + else: + total_value_loss += value_loss decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step()) ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr) @@ -584,9 +608,7 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]: self.entropy_optimizer.step() # Update target network - ModelUtils.soft_update( - self.policy.actor_critic.critic, self.target_network, self.tau - ) + ModelUtils.soft_update(self._critic, self.target_network, self.tau) update_stats = { "Losses/Policy Loss": policy_loss.item(), "Losses/Value Loss": value_loss.item(), @@ -613,7 +635,7 @@ def update_reward_signals( def get_modules(self): modules = { - "Optimizer:value_network": self.value_network, + "Optimizer:value_network": self.q_network, "Optimizer:target_network": self.target_network, "Optimizer:policy_optimizer": self.policy_optimizer, "Optimizer:value_optimizer": self.value_optimizer, diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py index eb2183c4d2..640d9155bf 100644 --- a/ml-agents/mlagents/trainers/sac/trainer.py +++ b/ml-agents/mlagents/trainers/sac/trainer.py @@ -149,9 
+149,16 @@ def _process_trajectory(self, trajectory: Trajectory) -> None: self.collected_rewards[name][agent_id] += np.sum(evaluate_result) # Get all value estimates for reporting purposes - value_estimates, _ = self.optimizer.get_trajectory_value_estimates( + ( + value_estimates, + _, + value_memories, + ) = self.optimizer.get_trajectory_value_estimates( agent_buffer_trajectory, trajectory.next_obs, trajectory.done_reached ) + if value_memories is not None: + agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories) + for name, v in value_estimates.items(): self._stats_reporter.add_stat( f"Policy/{self.optimizer.reward_signals[name].name.capitalize()} Value", diff --git a/ml-agents/mlagents/trainers/tests/test_agent_processor.py b/ml-agents/mlagents/trainers/tests/test_agent_processor.py index c7c5f5ec80..9b8affcc22 100644 --- a/ml-agents/mlagents/trainers/tests/test_agent_processor.py +++ b/ml-agents/mlagents/trainers/tests/test_agent_processor.py @@ -20,7 +20,9 @@ def create_mock_policy(): mock_policy = mock.Mock() mock_policy.reward_signals = {} - mock_policy.retrieve_memories.return_value = np.zeros((1, 1), dtype=np.float32) + mock_policy.retrieve_previous_memories.return_value = np.zeros( + (1, 1), dtype=np.float32 + ) mock_policy.retrieve_previous_action.return_value = np.zeros((1, 1), dtype=np.int32) return mock_policy @@ -38,10 +40,12 @@ def test_agentprocessor(num_vis_obs): ) fake_action_outputs = { - "action": ActionTuple(continuous=np.array([[0.1], [0.1]])), + "action": ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)), "entropy": np.array([1.0], dtype=np.float32), "learning_rate": 1.0, - "log_probs": LogProbsTuple(continuous=np.array([[0.1], [0.1]])), + "log_probs": LogProbsTuple( + continuous=np.array([[0.1], [0.1]], dtype=np.float32) + ), } mock_decision_steps, mock_terminal_steps = mb.create_mock_steps( num_agents=2, @@ -51,9 +55,8 @@ def test_agentprocessor(num_vis_obs): action_spec=ActionSpec.create_continuous(2), ) fake_action_info = ActionInfo( - action=ActionTuple(continuous=np.array([[0.1], [0.1]])), - env_action=ActionTuple(continuous=np.array([[0.1], [0.1]])), - value=[0.1, 0.1], + action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)), + env_action=ActionTuple(continuous=np.array([[0.1], [0.1]], dtype=np.float32)), outputs=fake_action_outputs, agent_ids=mock_decision_steps.agent_id, ) @@ -103,10 +106,10 @@ def test_agent_deletion(): stats_reporter=StatsReporter("testcat"), ) fake_action_outputs = { - "action": ActionTuple(continuous=np.array([[0.1]])), + "action": ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)), "entropy": np.array([1.0], dtype=np.float32), "learning_rate": 1.0, - "log_probs": LogProbsTuple(continuous=np.array([[0.1]])), + "log_probs": LogProbsTuple(continuous=np.array([[0.1]], dtype=np.float32)), } mock_decision_step, mock_terminal_step = mb.create_mock_steps( @@ -121,9 +124,8 @@ def test_agent_deletion(): done=True, ) fake_action_info = ActionInfo( - action=ActionTuple(continuous=np.array([[0.1]])), - env_action=ActionTuple(continuous=np.array([[0.1]])), - value=[0.1], + action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)), + env_action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)), outputs=fake_action_outputs, agent_ids=mock_decision_step.agent_id, ) @@ -182,10 +184,10 @@ def test_end_episode(): stats_reporter=StatsReporter("testcat"), ) fake_action_outputs = { - "action": ActionTuple(continuous=np.array([[0.1]])), + "action": 
ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)), "entropy": np.array([1.0], dtype=np.float32), "learning_rate": 1.0, - "log_probs": LogProbsTuple(continuous=np.array([[0.1]])), + "log_probs": LogProbsTuple(continuous=np.array([[0.1]], dtype=np.float32)), } mock_decision_step, mock_terminal_step = mb.create_mock_steps( @@ -194,9 +196,8 @@ def test_end_episode(): action_spec=ActionSpec.create_continuous(2), ) fake_action_info = ActionInfo( - action=ActionTuple(continuous=np.array([[0.1]])), - env_action=ActionTuple(continuous=np.array([[0.1]])), - value=[0.1], + action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)), + env_action=ActionTuple(continuous=np.array([[0.1]], dtype=np.float32)), outputs=fake_action_outputs, agent_ids=mock_decision_step.agent_id, ) diff --git a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py b/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py index 515afe34c7..8a41467972 100644 --- a/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py +++ b/ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py @@ -69,8 +69,8 @@ def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None: """ Make sure two policies have the same output for the same input. """ - policy1.actor_critic = policy1.actor_critic.to(default_device()) - policy2.actor_critic = policy2.actor_critic.to(default_device()) + policy1.actor = policy1.actor.to(default_device()) + policy2.actor = policy2.actor.to(default_device()) decision_step, _ = mb.create_steps_from_behavior_spec( policy1.behavior_spec, num_agents=1 diff --git a/ml-agents/mlagents/trainers/tests/torch/test_networks.py b/ml-agents/mlagents/trainers/tests/torch/test_networks.py index cadf65216d..1b5064db77 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_networks.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_networks.py @@ -4,8 +4,8 @@ from mlagents.trainers.torch.networks import ( NetworkBody, ValueNetwork, + SimpleActor, SharedActorCritic, - SeparateActorCritic, ) from mlagents.trainers.settings import NetworkSettings from mlagents_envs.base_env import ActionSpec @@ -128,9 +128,9 @@ def test_valuenetwork(): assert _out[0] == pytest.approx(1.0, abs=0.1) -@pytest.mark.parametrize("ac_type", [SharedActorCritic, SeparateActorCritic]) +@pytest.mark.parametrize("shared", [True, False]) @pytest.mark.parametrize("lstm", [True, False]) -def test_actor_critic(ac_type, lstm): +def test_actor_critic(lstm, shared): obs_size = 4 network_settings = NetworkSettings( memory=NetworkSettings.MemorySettings() if lstm else None, normalize=True @@ -141,7 +141,13 @@ def test_actor_critic(ac_type, lstm): stream_names = [f"stream_name{n}" for n in range(4)] # action_spec = ActionSpec.create_continuous(act_size[0]) action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size))) - actor = ac_type(obs_spec, network_settings, action_spec, stream_names) + if shared: + actor = critic = SharedActorCritic( + obs_spec, network_settings, action_spec, stream_names, network_settings + ) + else: + actor = SimpleActor(obs_spec, network_settings, action_spec) + critic = ValueNetwork(stream_names, obs_spec, network_settings) if lstm: sample_obs = torch.ones((1, network_settings.memory.sequence_length, obs_size)) memories = torch.ones( @@ -153,7 +159,7 @@ def test_actor_critic(ac_type, lstm): # memories isn't always set to None, the network should be able to # deal with that. 
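
The dtype=np.float32 arguments added to these fixtures are not cosmetic: NumPy builds float64 arrays from plain Python floats, while the torch modules under test run in float32, so the untyped literals would turn into mismatched tensors downstream. A two-line illustration:

import numpy as np

untyped = np.array([[0.1], [0.1]])
typed = np.array([[0.1], [0.1]], dtype=np.float32)

print(untyped.dtype)  # float64, would become a torch.float64 tensor downstream
print(typed.dtype)    # float32, matches the default dtype of the torch networks
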
# Test critic pass - value_out, memories_out = actor.critic_pass([sample_obs], memories=memories) + value_out, memories_out = critic.critic_pass([sample_obs], memories=memories) for stream in stream_names: if lstm: assert value_out[stream].shape == (network_settings.memory.sequence_length,) @@ -162,7 +168,7 @@ def test_actor_critic(ac_type, lstm): assert value_out[stream].shape == (1,) # Test get action stats and_value - action, log_probs, entropies, value_out, mem_out = actor.get_action_stats_and_value( + action, log_probs, entropies, mem_out = actor.get_action_and_stats( [sample_obs], memories=memories, masks=mask ) if lstm: @@ -179,8 +185,3 @@ def test_actor_critic(ac_type, lstm): if mem_out is not None: assert mem_out.shape == memories.shape - for stream in stream_names: - if lstm: - assert value_out[stream].shape == (network_settings.memory.sequence_length,) - else: - assert value_out[stream].shape == (1,) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_policy.py b/ml-agents/mlagents/trainers/tests/torch/test_policy.py index 159584af54..956abcee21 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_policy.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_policy.py @@ -81,7 +81,7 @@ def test_evaluate_actions(rnn, visual, discrete): if len(memories) > 0: memories = torch.stack(memories).unsqueeze(0) - log_probs, entropy, values = policy.evaluate_actions( + log_probs, entropy = policy.evaluate_actions( tensor_obs, masks=act_masks, actions=agent_action, @@ -95,8 +95,6 @@ def test_evaluate_actions(rnn, visual, discrete): assert log_probs.flatten().shape == (64, _size) assert entropy.shape == (64,) - for val in values.values(): - assert val.shape == (64,) @pytest.mark.parametrize("discrete", [True, False], ids=["discrete", "continuous"]) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py index a95ac491ec..0b4c2c3472 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_ppo.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_ppo.py @@ -79,6 +79,8 @@ def test_ppo_optimizer_update(dummy_config, rnn, visual, discrete): RewardSignalUtil.value_estimates_key("extrinsic"), ], ) + # Copy memories to critic memories + copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY]) return_stats = optimizer.update( update_buffer, @@ -126,6 +128,8 @@ def test_ppo_optimizer_update_curiosity( RewardSignalUtil.value_estimates_key("curiosity"), ], ) + # Copy memories to critic memories + copy_buffer_fields(update_buffer, BufferKey.MEMORY, [BufferKey.CRITIC_MEMORY]) optimizer.update( update_buffer, @@ -200,14 +204,16 @@ def test_ppo_get_value_estimates(dummy_config, rnn, visual, discrete): action_spec=DISCRETE_ACTION_SPEC if discrete else CONTINUOUS_ACTION_SPEC, max_step_complete=True, ) - run_out, final_value_out = optimizer.get_trajectory_value_estimates( + run_out, final_value_out, all_memories = optimizer.get_trajectory_value_estimates( trajectory.to_agentbuffer(), trajectory.next_obs, done=False ) for key, val in run_out.items(): assert type(key) is str assert len(val) == 15 + if all_memories is not None: + assert len(all_memories) == 15 - run_out, final_value_out = optimizer.get_trajectory_value_estimates( + run_out, final_value_out, _ = optimizer.get_trajectory_value_estimates( trajectory.to_agentbuffer(), trajectory.next_obs, done=True ) for key, val in final_value_out.items(): @@ -216,7 +222,7 @@ def test_ppo_get_value_estimates(dummy_config, rnn, visual, discrete): # Check if we ignore 
terminal states properly optimizer.reward_signals["extrinsic"].use_terminal_states = False - run_out, final_value_out = optimizer.get_trajectory_value_estimates( + run_out, final_value_out, _ = optimizer.get_trajectory_value_estimates( trajectory.to_agentbuffer(), trajectory.next_obs, done=False ) for key, val in final_value_out.items(): diff --git a/ml-agents/mlagents/trainers/tests/torch/test_sac.py b/ml-agents/mlagents/trainers/tests/torch/test_sac.py index f4ca4524de..323e47e986 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_sac.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_sac.py @@ -55,12 +55,14 @@ def test_sac_optimizer_update(dummy_config, rnn, visual, discrete): ) # Test update update_buffer = mb.simulate_rollout( - BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec, memory_size=24 + BUFFER_INIT_SAMPLES, optimizer.policy.behavior_spec, memory_size=12 ) # Mock out reward signal eval update_buffer[RewardSignalUtil.rewards_key("extrinsic")] = update_buffer[ BufferKey.ENVIRONMENT_REWARDS ] + # Mock out value memories + update_buffer[BufferKey.CRITIC_MEMORY] = update_buffer[BufferKey.MEMORY] return_stats = optimizer.update( update_buffer, num_sequences=update_buffer.num_experiences // optimizer.policy.sequence_length, diff --git a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py index 910c50a23c..6f1c30f36b 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_simple_rl.py @@ -245,7 +245,7 @@ def test_recurrent_sac(action_sizes): new_hyperparams = attr.evolve( SAC_TORCH_CONFIG.hyperparameters, batch_size=256, - learning_rate=1e-4, + learning_rate=3e-4, buffer_init_steps=1000, steps_per_update=2, ) @@ -255,7 +255,7 @@ def test_recurrent_sac(action_sizes): network_settings=new_networksettings, max_steps=4000, ) - check_environment_trains(env, {BRAIN_NAME: config}, training_seed=1213) + check_environment_trains(env, {BRAIN_NAME: config}, training_seed=1337) @pytest.mark.parametrize("action_sizes", [(0, 1), (1, 0)]) @@ -395,7 +395,7 @@ def test_gail_visual_ppo(simple_record, action_sizes): num_visual=1, num_vector=0, action_sizes=action_sizes, - step_size=0.2, + step_size=0.3, ) bc_settings = BehavioralCloningSettings(demo_path=demo_path, steps=1500) reward_signals = { diff --git a/ml-agents/mlagents/trainers/torch/components/bc/module.py b/ml-agents/mlagents/trainers/torch/components/bc/module.py index 4a71cf6e31..b8879569b4 100644 --- a/ml-agents/mlagents/trainers/torch/components/bc/module.py +++ b/ml-agents/mlagents/trainers/torch/components/bc/module.py @@ -37,7 +37,7 @@ def __init__( self.decay_learning_rate = ModelUtils.DecayedValue( learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps ) - params = self.policy.actor_critic.parameters() + params = self.policy.actor.parameters() self.optimizer = torch.optim.Adam(params, lr=self.current_lr) _, self.demonstration_buffer = demo_to_buffer( settings.demo_path, policy.sequence_length, policy.behavior_spec diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index 5af5401b17..3b44d2e31d 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -131,7 +131,7 @@ def export_policy_model(self, output_filepath: str) -> None: with exporting_to_onnx(): torch.onnx.export( - self.policy.actor_critic, + self.policy.actor, 
self.dummy_input, onnx_output_path, opset_version=SerializationSettings.onnx_opset, diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index fdf02bdc3e..1a2ab70872 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -156,7 +156,31 @@ def forward( return encoding, memories -class ValueNetwork(nn.Module): +class Critic(abc.ABC): + @abc.abstractmethod + def update_normalization(self, buffer: AgentBuffer) -> None: + """ + Updates normalization of Actor based on the provided List of vector obs. + :param vector_obs: A List of vector obs as tensors. + """ + pass + + def critic_pass( + self, + inputs: List[torch.Tensor], + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + """ + Get value outputs for the given obs. + :param inputs: List of inputs as tensors. + :param memories: Tensor of memories, if using memory. Otherwise, None. + :returns: Dict of reward stream to output tensor for values. + """ + pass + + +class ValueNetwork(nn.Module, Critic): def __init__( self, stream_names: List[str], @@ -177,10 +201,24 @@ def __init__( encoding_size = network_settings.hidden_units self.value_heads = ValueHeads(stream_names, encoding_size, outputs_per_stream) + def update_normalization(self, buffer: AgentBuffer) -> None: + self.network_body.update_normalization(buffer) + @property def memory_size(self) -> int: return self.network_body.memory_size + def critic_pass( + self, + inputs: List[torch.Tensor], + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + value_outputs, critic_mem_out = self.forward( + inputs, memories=memories, sequence_length=sequence_length + ) + return value_outputs, critic_mem_out + def forward( self, inputs: List[torch.Tensor], @@ -204,7 +242,7 @@ def update_normalization(self, buffer: AgentBuffer) -> None: """ pass - def get_action_stats( + def get_action_and_stats( self, inputs: List[torch.Tensor], masks: Optional[torch.Tensor] = None, @@ -214,8 +252,7 @@ def get_action_stats( """ Returns sampled actions. If memory is enabled, return the memories as well. - :param vec_inputs: A List of vector inputs as tensors. - :param vis_inputs: A List of visual inputs as tensors. + :param inputs: A List of inputs as tensors. :param masks: If using discrete actions, a Tensor of action masks. :param memories: If using memory, a Tensor of initial memories. :param sequence_length: If using memory, the sequence length. @@ -224,66 +261,41 @@ def get_action_stats( """ pass - @abc.abstractmethod - def forward( - self, - vec_inputs: List[torch.Tensor], - vis_inputs: List[torch.Tensor], - var_len_inputs: List[torch.Tensor], - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - ) -> Tuple[Union[int, torch.Tensor], ...]: - """ - Forward pass of the Actor for inference. This is required for export to ONNX, and - the inputs and outputs of this method should not be changed without a respective change - in the ONNX export code. - """ - pass - - -class ActorCritic(Actor): - @abc.abstractmethod - def critic_pass( - self, - inputs: List[torch.Tensor], - memories: Optional[torch.Tensor] = None, - sequence_length: int = 1, - ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: - """ - Get value outputs for the given obs. - :param inputs: List of inputs as tensors. - :param memories: Tensor of memories, if using memory. 
Otherwise, None. - :returns: Dict of reward stream to output tensor for values. - """ - pass - - @abc.abstractmethod - def get_action_stats_and_value( + def get_stats( self, inputs: List[torch.Tensor], + actions: AgentAction, masks: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, sequence_length: int = 1, - ) -> Tuple[ - AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor - ]: + ) -> Tuple[ActionLogProbs, torch.Tensor]: """ - Returns sampled actions and value estimates. + Returns log_probs for actions and entropies. If memory is enabled, return the memories as well. - :param inputs: A List of vector inputs as tensors. + :param inputs: A List of inputs as tensors. + :param actions: AgentAction of actions. :param masks: If using discrete actions, a Tensor of action masks. :param memories: If using memory, a Tensor of initial memories. :param sequence_length: If using memory, the sequence length. - :return: A Tuple of AgentAction, ActionLogProbs, entropies, Dict of reward signal - name to value estimate, and memories. Memories will be None if not using memory. + :return: A Tuple of AgentAction, ActionLogProbs, entropies, and memories. + Memories will be None if not using memory. """ + pass - @abc.abstractproperty - def memory_size(self): + @abc.abstractmethod + def forward( + self, + vec_inputs: List[torch.Tensor], + vis_inputs: List[torch.Tensor], + var_len_inputs: List[torch.Tensor], + masks: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, + ) -> Tuple[Union[int, torch.Tensor], ...]: """ - Returns the size of the memory (same size used as input and output in the other - methods) used by this Actor. + Forward pass of the Actor for inference. This is required for export to ONNX, and + the inputs and outputs of this method should not be changed without a respective change + in the ONNX export code. 
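
The old ActorCritic base class is replaced here by two independent contracts: Actor (action sampling via get_action_and_stats, re-evaluation via get_stats, ONNX export via forward) and Critic (critic_pass returning one value tensor per reward stream plus memories), with SharedActorCritic simply implementing both. A toy module satisfying the Critic shape, for illustration only (not the mlagents ValueNetwork):

from typing import Dict, List, Optional, Tuple

import torch
from torch import nn


class ToyCritic(nn.Module):
    # Implements the Critic contract: obs (+ optional memories) -> value dict, memories.

    def __init__(self, obs_dim: int, stream_names: List[str]):
        super().__init__()
        self.heads = nn.ModuleDict(
            {name: nn.Linear(obs_dim, 1) for name in stream_names}
        )

    def critic_pass(
        self,
        inputs: List[torch.Tensor],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[Dict[str, torch.Tensor], Optional[torch.Tensor]]:
        obs = inputs[0]  # a single observation tensor is enough for this toy
        values = {name: head(obs).squeeze(-1) for name, head in self.heads.items()}
        return values, memories  # no recurrence here, memories pass through unchanged


critic = ToyCritic(obs_dim=4, stream_names=["extrinsic", "gail"])
values, _ = critic.critic_pass([torch.randn(5, 4)])
print({k: v.shape for k, v in values.items()})  # each stream: torch.Size([5])
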
""" pass @@ -344,7 +356,7 @@ def memory_size(self) -> int: def update_normalization(self, buffer: AgentBuffer) -> None: self.network_body.update_normalization(buffer) - def get_action_stats( + def get_action_and_stats( self, inputs: List[torch.Tensor], masks: Optional[torch.Tensor] = None, @@ -358,6 +370,21 @@ def get_action_stats( action, log_probs, entropies = self.action_model(encoding, masks) return action, log_probs, entropies, memories + def get_stats( + self, + inputs: List[torch.Tensor], + actions: AgentAction, + masks: Optional[torch.Tensor] = None, + memories: Optional[torch.Tensor] = None, + sequence_length: int = 1, + ) -> Tuple[ActionLogProbs, torch.Tensor]: + encoding, actor_mem_outs = self.network_body( + inputs, memories=memories, sequence_length=sequence_length + ) + log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) + + return log_probs, entropies + def forward( self, vec_inputs: List[torch.Tensor], @@ -418,7 +445,7 @@ def forward( return tuple(export_out) -class SharedActorCritic(SimpleActor, ActorCritic): +class SharedActorCritic(SimpleActor, Critic): def __init__( self, observation_specs: List[ObservationSpec], @@ -450,157 +477,6 @@ def critic_pass( ) return self.value_heads(encoding), memories_out - def get_stats_and_value( - self, - inputs: List[torch.Tensor], - actions: AgentAction, - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - sequence_length: int = 1, - ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]: - encoding, memories = self.network_body( - inputs, memories=memories, sequence_length=sequence_length - ) - log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) - value_outputs = self.value_heads(encoding) - return log_probs, entropies, value_outputs - - def get_action_stats_and_value( - self, - inputs: List[torch.Tensor], - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - sequence_length: int = 1, - ) -> Tuple[ - AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor - ]: - - encoding, memories = self.network_body( - inputs, memories=memories, sequence_length=sequence_length - ) - action, log_probs, entropies = self.action_model(encoding, masks) - value_outputs = self.value_heads(encoding) - return action, log_probs, entropies, value_outputs, memories - - -class SeparateActorCritic(SimpleActor, ActorCritic): - def __init__( - self, - observation_specs: List[ObservationSpec], - network_settings: NetworkSettings, - action_spec: ActionSpec, - stream_names: List[str], - conditional_sigma: bool = False, - tanh_squash: bool = False, - ): - self.use_lstm = network_settings.memory is not None - super().__init__( - observation_specs, - network_settings, - action_spec, - conditional_sigma, - tanh_squash, - ) - self.stream_names = stream_names - self.critic = ValueNetwork(stream_names, observation_specs, network_settings) - - @property - def memory_size(self) -> int: - return self.network_body.memory_size + self.critic.memory_size - - def _get_actor_critic_mem( - self, memories: Optional[torch.Tensor] = None - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - if self.use_lstm and memories is not None: - # Use only the back half of memories for critic and actor - actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1) - actor_mem, critic_mem = actor_mem.contiguous(), critic_mem.contiguous() - else: - critic_mem = None - actor_mem = None - return actor_mem, critic_mem - - def 
critic_pass( - self, - inputs: List[torch.Tensor], - memories: Optional[torch.Tensor] = None, - sequence_length: int = 1, - ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: - actor_mem, critic_mem = self._get_actor_critic_mem(memories) - value_outputs, critic_mem_out = self.critic( - inputs, memories=critic_mem, sequence_length=sequence_length - ) - if actor_mem is not None: - # Make memories with the actor mem unchanged - memories_out = torch.cat([actor_mem, critic_mem_out], dim=-1) - else: - memories_out = None - return value_outputs, memories_out - - def get_stats_and_value( - self, - inputs: List[torch.Tensor], - actions: AgentAction, - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - sequence_length: int = 1, - ) -> Tuple[ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor]]: - actor_mem, critic_mem = self._get_actor_critic_mem(memories) - encoding, actor_mem_outs = self.network_body( - inputs, memories=actor_mem, sequence_length=sequence_length - ) - log_probs, entropies = self.action_model.evaluate(encoding, masks, actions) - value_outputs, critic_mem_outs = self.critic( - inputs, memories=critic_mem, sequence_length=sequence_length - ) - - return log_probs, entropies, value_outputs - - def get_action_stats( - self, - inputs: List[torch.Tensor], - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - sequence_length: int = 1, - ) -> Tuple[AgentAction, ActionLogProbs, torch.Tensor, torch.Tensor]: - actor_mem, critic_mem = self._get_actor_critic_mem(memories) - action, log_probs, entropies, actor_mem_out = super().get_action_stats( - inputs, masks=masks, memories=actor_mem, sequence_length=sequence_length - ) - if critic_mem is not None: - # Make memories with the actor mem unchanged - memories_out = torch.cat([actor_mem_out, critic_mem], dim=-1) - else: - memories_out = None - return action, log_probs, entropies, memories_out - - def get_action_stats_and_value( - self, - inputs: List[torch.Tensor], - masks: Optional[torch.Tensor] = None, - memories: Optional[torch.Tensor] = None, - sequence_length: int = 1, - ) -> Tuple[ - AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor - ]: - actor_mem, critic_mem = self._get_actor_critic_mem(memories) - encoding, actor_mem_outs = self.network_body( - inputs, memories=actor_mem, sequence_length=sequence_length - ) - action, log_probs, entropies = self.action_model(encoding, masks) - value_outputs, critic_mem_outs = self.critic( - inputs, memories=critic_mem, sequence_length=sequence_length - ) - if self.use_lstm: - mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1) - else: - mem_out = None - return action, log_probs, entropies, value_outputs, mem_out - - def update_normalization(self, buffer: AgentBuffer) -> None: - super().update_normalization(buffer) - self.critic.network_body.update_normalization(buffer) - class GlobalSteps(nn.Module): def __init__(self):
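
Tying the trainer-side changes together: both PPO and SAC now unpack three values from get_trajectory_value_estimates, and when the critic is recurrent they write the returned per-step memories into the trajectory buffer under BufferKey.CRITIC_MEMORY before the trajectory reaches the update buffer. A schematic, self-contained sketch with placeholder data (the helper below only mimics the shape of the real optimizer call):

from typing import Dict, List, Optional, Tuple

import numpy as np


def fake_trajectory_value_estimates(
    traj_len: int, done: bool, use_recurrent: bool
) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[List[np.ndarray]]]:
    # Stand-in for TorchOptimizer.get_trajectory_value_estimates (illustrative only).
    value_estimates = {"extrinsic": np.zeros(traj_len, dtype=np.float32)}
    value_next = {"extrinsic": 0.0 if done else 1.0}
    value_memories = (
        [np.zeros(8, dtype=np.float32)] * traj_len if use_recurrent else None
    )
    return value_estimates, value_next, value_memories


# Both trainers unpack a 3-tuple; when the critic carries LSTM state, the per-step
# memories are stored so the optimizer can later slice them at sequence boundaries.
trajectory_buffer: Dict[str, List[np.ndarray]] = {}
value_estimates, value_next, value_memories = fake_trajectory_value_estimates(
    traj_len=6, done=False, use_recurrent=True
)
if value_memories is not None:
    # Mirrors agent_buffer_trajectory[BufferKey.CRITIC_MEMORY].set(value_memories)
    trajectory_buffer["critic_memory"] = list(value_memories)
print(len(trajectory_buffer["critic_memory"]))  # 6
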