5 changes: 2 additions & 3 deletions ml-agents/mlagents/trainers/agent_processor.py
@@ -14,7 +14,6 @@
EnvironmentStats,
)
from mlagents.trainers.trajectory import Trajectory, AgentExperience
-from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.policy import Policy
from mlagents.trainers.action_info import ActionInfo, ActionInfoOutputs
from mlagents.trainers.stats import StatsReporter
@@ -32,7 +31,7 @@ class AgentProcessor:

def __init__(
self,
-policy: TFPolicy,
+policy: Policy,
behavior_id: str,
stats_reporter: StatsReporter,
max_trajectory_length: int = sys.maxsize,
@@ -290,7 +289,7 @@ class AgentManager(AgentProcessor):

def __init__(
self,
-policy: TFPolicy,
+policy: Policy,
behavior_id: str,
stats_reporter: StatsReporter,
max_trajectory_length: int = sys.maxsize,
6 changes: 3 additions & 3 deletions ml-agents/mlagents/trainers/env_manager.py
@@ -8,7 +8,7 @@
)
from mlagents_envs.side_channel.stats_side_channel import EnvironmentStats

-from mlagents.trainers.policy.tf_policy import TFPolicy
+from mlagents.trainers.policy import Policy
from mlagents.trainers.agent_processor import AgentManager, AgentManagerQueue
from mlagents.trainers.action_info import ActionInfo
from mlagents_envs.logging_util import get_logger
@@ -36,11 +36,11 @@ def empty(worker_id: int) -> "EnvironmentStep":

class EnvManager(ABC):
def __init__(self):
-self.policies: Dict[BehaviorName, TFPolicy] = {}
+self.policies: Dict[BehaviorName, Policy] = {}
self.agent_managers: Dict[BehaviorName, AgentManager] = {}
self.first_step_infos: List[EnvironmentStep] = []

-def set_policy(self, brain_name: BehaviorName, policy: TFPolicy) -> None:
+def set_policy(self, brain_name: BehaviorName, policy: Policy) -> None:
self.policies[brain_name] = policy
if brain_name in self.agent_managers:
self.agent_managers[brain_name].policy = policy
11 changes: 5 additions & 6 deletions ml-agents/mlagents/trainers/ghost/trainer.py
@@ -2,14 +2,13 @@
# ## ML-Agent Learning (Ghost Trainer)

from collections import defaultdict
-from typing import Deque, Dict, DefaultDict, List, cast
+from typing import Deque, Dict, DefaultDict, List

import numpy as np

from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy import Policy
-from mlagents.trainers.policy.tf_policy import TFPolicy

from mlagents.trainers.trainer import Trainer
from mlagents.trainers.trajectory import Trajectory
@@ -262,7 +261,7 @@ def advance(self) -> None:
for brain_name in self._internal_policy_queues:
internal_policy_queue = self._internal_policy_queues[brain_name]
try:
-policy = cast(TFPolicy, internal_policy_queue.get_nowait())
+policy = internal_policy_queue.get_nowait()
self.current_policy_snapshot[brain_name] = policy.get_weights()
except AgentManagerQueue.Empty:
pass
@@ -306,7 +305,7 @@ def save_model(self) -> None:

def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
-) -> TFPolicy:
+) -> Policy:
"""
Creates policy with the wrapped trainer's create_policy function
The first policy encountered sets the wrapped
@@ -339,7 +338,7 @@ def create_policy(
return policy

def add_policy(
-self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
+self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
) -> None:
"""
Adds policy to GhostTrainer.
@@ -350,7 +349,7 @@
self._name_to_parsed_behavior_id[name_behavior_id] = parsed_behavior_id
self.policies[name_behavior_id] = policy

-def get_policy(self, name_behavior_id: str) -> TFPolicy:
+def get_policy(self, name_behavior_id: str) -> Policy:
"""
Gets policy associated with name_behavior_id
:param name_behavior_id: Fully qualified behavior name
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/optimizer/optimizer.py
@@ -10,6 +10,9 @@ class Optimizer(abc.ABC):
Provides methods to update the Policy.
"""

+def __init__(self):
+    self.reward_signals = {}
+
@abc.abstractmethod
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
"""
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/optimizer/tf_optimizer.py
@@ -15,6 +15,7 @@

class TFOptimizer(Optimizer): # pylint: disable=W0223
def __init__(self, policy: TFPolicy, trainer_params: TrainerSettings):
+super().__init__()
self.sess = policy.sess
self.policy = policy
self.update_dict: Dict[str, tf.Tensor] = {}
@@ -129,7 +130,6 @@ def create_reward_signals(
Create reward signals
:param reward_signal_configs: Reward signal config.
"""
-self.reward_signals = {}
# Create reward signals
for reward_signal, settings in reward_signal_configs.items():
# Name reward signals by string in case we have duplicates later
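Taken together, the optimizer hunks above hoist the `reward_signals` attribute out of `TFOptimizer.create_reward_signals` and into the abstract `Optimizer.__init__`, so every optimizer exposes the attribute as soon as it is constructed. A minimal sketch of that pattern, with a hypothetical `DummyOptimizer` standing in for a concrete backend (illustrative only, not ml-agents code):

```python
import abc
from typing import Any, Dict


class Optimizer(abc.ABC):
    """Base class owns state that every optimizer must expose."""

    def __init__(self):
        # Populated later by concrete subclasses; trainers can always read it.
        self.reward_signals: Dict[str, Any] = {}

    @abc.abstractmethod
    def update(self, batch: Any, num_sequences: int) -> Dict[str, float]:
        ...


class DummyOptimizer(Optimizer):
    def __init__(self):
        super().__init__()  # guarantees reward_signals exists
        self.reward_signals["extrinsic"] = object()

    def update(self, batch: Any, num_sequences: int) -> Dict[str, float]:
        return {"Losses/Dummy": 0.0}


opt = DummyOptimizer()
print(list(opt.reward_signals))  # ['extrinsic']
```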
12 changes: 12 additions & 0 deletions ml-agents/mlagents/trainers/policy/policy.py
@@ -158,3 +158,15 @@ def checkpoint(self, checkpoint_path: str, settings: SerializationSettings) -> N
@abstractmethod
def save(self, output_filepath: str, settings: SerializationSettings) -> None:
pass

+@abstractmethod
+def load_weights(self, values: List[np.ndarray]) -> None:
+    pass
+
+@abstractmethod
+def get_weights(self) -> List[np.ndarray]:
+    return []
+
+@abstractmethod
+def init_load_weights(self) -> None:
+    pass
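With these three abstract methods, every concrete policy must be able to round-trip its weights as a list of numpy arrays, which is what lets `GhostTrainer.advance` above call `policy.get_weights()` without casting to `TFPolicy`. A minimal sketch of a class satisfying just this surface; `ConstantPolicy` is a hypothetical stand-in, not part of the PR:

```python
from typing import List

import numpy as np


class ConstantPolicy:
    """Toy stand-in for a Policy subclass: weights are a single array."""

    def __init__(self):
        self.weights = [np.zeros(4, dtype=np.float32)]

    def init_load_weights(self) -> None:
        # A TF-backed policy would build assignment ops here; nothing to do.
        pass

    def get_weights(self) -> List[np.ndarray]:
        return [w.copy() for w in self.weights]

    def load_weights(self, values: List[np.ndarray]) -> None:
        self.weights = [v.copy() for v in values]


src, dst = ConstantPolicy(), ConstantPolicy()
src.weights[0] += 1.0
dst.load_weights(src.get_weights())  # snapshot transfer, as in self-play
assert np.allclose(dst.get_weights()[0], 1.0)
```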
56 changes: 0 additions & 56 deletions ml-agents/mlagents/trainers/policy/tf_policy.py
@@ -376,62 +376,6 @@ def fill_eval_dict(self, feed_dict, batched_step_result):
feed_dict[self.action_masks] = mask
return feed_dict

-def make_empty_memory(self, num_agents):
-    """
-    Creates empty memory for use with RNNs
-    :param num_agents: Number of agents.
-    :return: Numpy array of zeros.
-    """
-    return np.zeros((num_agents, self.m_size), dtype=np.float32)
-
-def save_memories(
-    self, agent_ids: List[str], memory_matrix: Optional[np.ndarray]
-) -> None:
-    if memory_matrix is None:
-        return
-    for index, agent_id in enumerate(agent_ids):
-        self.memory_dict[agent_id] = memory_matrix[index, :]
-
-def retrieve_memories(self, agent_ids: List[str]) -> np.ndarray:
-    memory_matrix = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
-    for index, agent_id in enumerate(agent_ids):
-        if agent_id in self.memory_dict:
-            memory_matrix[index, :] = self.memory_dict[agent_id]
-    return memory_matrix
-
-def remove_memories(self, agent_ids):
-    for agent_id in agent_ids:
-        if agent_id in self.memory_dict:
-            self.memory_dict.pop(agent_id)
-
-def make_empty_previous_action(self, num_agents):
-    """
-    Creates empty previous action for use with RNNs and discrete control
-    :param num_agents: Number of agents.
-    :return: Numpy array of zeros.
-    """
-    return np.zeros((num_agents, self.num_branches), dtype=np.int)
-
-def save_previous_action(
-    self, agent_ids: List[str], action_matrix: Optional[np.ndarray]
-) -> None:
-    if action_matrix is None:
-        return
-    for index, agent_id in enumerate(agent_ids):
-        self.previous_action_dict[agent_id] = action_matrix[index, :]
-
-def retrieve_previous_action(self, agent_ids: List[str]) -> np.ndarray:
-    action_matrix = np.zeros((len(agent_ids), self.num_branches), dtype=np.int)
-    for index, agent_id in enumerate(agent_ids):
-        if agent_id in self.previous_action_dict:
-            action_matrix[index, :] = self.previous_action_dict[agent_id]
-    return action_matrix
-
-def remove_previous_action(self, agent_ids):
-    for agent_id in agent_ids:
-        if agent_id in self.previous_action_dict:
-            self.previous_action_dict.pop(agent_id)
-
def get_current_step(self):
"""
Gets current model step.
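The deleted helpers are plain numpy bookkeeping keyed by agent id, with no TensorFlow dependency, which is presumably why they can leave the TF-specific class. A compressed, self-contained sketch of the save/retrieve round trip they implement; the `MemoryStore` wrapper is illustrative only:

```python
from typing import Dict, List

import numpy as np


class MemoryStore:
    """Per-agent recurrent memories, mirroring the removed save/retrieve helpers."""

    def __init__(self, m_size: int):
        self.m_size = m_size
        self.memory_dict: Dict[str, np.ndarray] = {}

    def save(self, agent_ids: List[str], memory_matrix: np.ndarray) -> None:
        for index, agent_id in enumerate(agent_ids):
            self.memory_dict[agent_id] = memory_matrix[index, :]

    def retrieve(self, agent_ids: List[str]) -> np.ndarray:
        out = np.zeros((len(agent_ids), self.m_size), dtype=np.float32)
        for index, agent_id in enumerate(agent_ids):
            if agent_id in self.memory_dict:
                out[index, :] = self.memory_dict[agent_id]
        return out


store = MemoryStore(m_size=2)
store.save(["a", "b"], np.ones((2, 2), dtype=np.float32))
print(store.retrieve(["b", "c"]))  # row for "c" falls back to zeros
```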
11 changes: 7 additions & 4 deletions ml-agents/mlagents/trainers/ppo/trainer.py
@@ -10,6 +10,7 @@
from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.trainer.rl_trainer import RLTrainer
+from mlagents.trainers.policy import Policy
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.ppo.optimizer import PPOOptimizer
from mlagents.trainers.trajectory import Trajectory
@@ -51,7 +52,7 @@ def __init__(
)
self.load = load
self.seed = seed
-self.policy: TFPolicy = None # type: ignore
+self.policy: Policy = None # type: ignore

def _process_trajectory(self, trajectory: Trajectory) -> None:
"""
@@ -208,7 +209,7 @@ def create_policy(
return policy

def add_policy(
-self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
+self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
) -> None:
"""
Adds policy to trainer.
@@ -224,13 +225,15 @@
)
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
-self.optimizer = PPOOptimizer(self.policy, self.trainer_settings)
+self.optimizer = PPOOptimizer(
+    cast(TFPolicy, self.policy), self.trainer_settings
+)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Needed to resume loads properly
self.step = policy.get_current_step()

-def get_policy(self, name_behavior_id: str) -> TFPolicy:
+def get_policy(self, name_behavior_id: str) -> Policy:
"""
Gets policy from trainer associated with name_behavior_id
:param name_behavior_id: full identifier of policy
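The trainer now stores `self.policy` under the abstract `Policy` type while `PPOOptimizer` (and `SACOptimizer` below) still expects a concrete `TFPolicy`, hence the explicit `cast`. Note that `typing.cast` only informs the type checker and performs no runtime conversion, as the small stand-alone sketch below illustrates (the classes are stand-ins, not ml-agents code):

```python
from typing import cast


class Policy:
    """Stand-in for the abstract base type the trainer stores."""


class TFPolicy(Policy):
    """Stand-in for the backend-specific subclass the TF optimizer needs."""

    sess = "tf-session-handle"


class PPOOptimizer:
    def __init__(self, policy: TFPolicy):
        # Only valid because the trainer actually constructed a TFPolicy.
        self.sess = policy.sess


stored_policy: Policy = TFPolicy()  # attribute annotated with the base type
optimizer = PPOOptimizer(cast(TFPolicy, stored_policy))  # cast satisfies mypy only
print(optimizer.sess)
```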
11 changes: 7 additions & 4 deletions ml-agents/mlagents/trainers/sac/trainer.py
@@ -13,6 +13,7 @@
from mlagents_envs.timers import timed
from mlagents_envs.base_env import BehaviorSpec
from mlagents.trainers.policy.tf_policy import TFPolicy
+from mlagents.trainers.policy import Policy
from mlagents.trainers.sac.optimizer import SACOptimizer
from mlagents.trainers.trainer.rl_trainer import RLTrainer
from mlagents.trainers.trajectory import Trajectory, SplitObservations
@@ -57,7 +58,7 @@ def __init__(

self.load = load
self.seed = seed
-self.policy: TFPolicy = None # type: ignore
+self.policy: Policy = None # type: ignore
self.optimizer: SACOptimizer = None # type: ignore
self.hyperparameters: SACSettings = cast(
SACSettings, trainer_settings.hyperparameters
@@ -312,7 +313,7 @@ def _update_reward_signals(self) -> None:
self._stats_reporter.add_stat(stat, np.mean(stat_list))

def add_policy(
-self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
+self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
) -> None:
"""
Adds policy to trainer.
@@ -326,7 +327,9 @@
)
self.policy = policy
self.policies[parsed_behavior_id.behavior_id] = policy
-self.optimizer = SACOptimizer(self.policy, self.trainer_settings)
+self.optimizer = SACOptimizer(
+    cast(TFPolicy, self.policy), self.trainer_settings
+)
for _reward_signal in self.optimizer.reward_signals.keys():
self.collected_rewards[_reward_signal] = defaultdict(lambda: 0)
# Needed to resume loads properly
@@ -337,7 +340,7 @@
max(1, self.step / self.reward_signal_steps_per_update)
)

-def get_policy(self, name_behavior_id: str) -> TFPolicy:
+def get_policy(self, name_behavior_id: str) -> Policy:
"""
Gets policy from trainer associated with name_behavior_id
:param name_behavior_id: full identifier of policy
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/trainer/rl_trainer.py
@@ -12,7 +12,7 @@
)
from mlagents_envs.logging_util import get_logger
from mlagents_envs.timers import timed
-from mlagents.trainers.optimizer.tf_optimizer import TFOptimizer
+from mlagents.trainers.optimizer import Optimizer
from mlagents.trainers.buffer import AgentBuffer
from mlagents.trainers.trainer import Trainer
from mlagents.trainers.components.reward_signals import RewardSignalResult
@@ -56,7 +56,7 @@ def end_episode(self) -> None:
for agent_id in rewards:
rewards[agent_id] = 0

-def _update_end_episode_stats(self, agent_id: str, optimizer: TFOptimizer) -> None:
+def _update_end_episode_stats(self, agent_id: str, optimizer: Optimizer) -> None:
for name, rewards in self.collected_rewards.items():
if name == "environment":
self.stats_reporter.add_stat(
9 changes: 4 additions & 5 deletions ml-agents/mlagents/trainers/trainer/trainer.py
@@ -5,7 +5,6 @@

from mlagents_envs.logging_util import get_logger
from mlagents_envs.base_env import BehaviorSpec
-from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.stats import StatsReporter
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.agent_processor import AgentManagerQueue
@@ -47,7 +46,7 @@ def __init__(
self.step: int = 0
self.artifact_path = artifact_path
self.summary_freq = self.trainer_settings.summary_freq
-self.policies: Dict[str, TFPolicy] = {}
+self.policies: Dict[str, Policy] = {}

@property
def stats_reporter(self):
@@ -125,23 +124,23 @@ def end_episode(self):
@abc.abstractmethod
def create_policy(
self, parsed_behavior_id: BehaviorIdentifiers, behavior_spec: BehaviorSpec
-) -> TFPolicy:
+) -> Policy:
"""
Creates policy
"""
pass

@abc.abstractmethod
def add_policy(
-self, parsed_behavior_id: BehaviorIdentifiers, policy: TFPolicy
+self, parsed_behavior_id: BehaviorIdentifiers, policy: Policy
) -> None:
"""
Adds policy to trainer.
"""
pass

@abc.abstractmethod
-def get_policy(self, name_behavior_id: str) -> TFPolicy:
+def get_policy(self, name_behavior_id: str) -> Policy:
"""
Gets policy from trainer.
"""