diff --git a/UnitySDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs b/UnitySDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs index 2d73184043..3428b32bdd 100644 --- a/UnitySDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs +++ b/UnitySDK/Assets/ML-Agents/Scripts/CoreBrainInternal.cs @@ -41,9 +41,6 @@ public enum TensorType /// Modify only in inspector : Reference to the Graph asset public TextAsset graphModel; - /// Modify only in inspector : If a scope was used when training the model, specify it here - public string graphScope; - [SerializeField] [Tooltip( "If your graph takes additional inputs that are fixed (example: noise level) you can specify them here.")] @@ -136,40 +133,35 @@ public void InitializeCoreBrain(MLAgents.Batcher brainBatcher) // TODO: Make this a loop over a dynamic set of graph inputs - if ((graphScope.Length > 1) && (graphScope[graphScope.Length - 1] != '/')) - { - graphScope = graphScope + '/'; - } - - if (graph[graphScope + BatchSizePlaceholderName] != null) + if (graph[BatchSizePlaceholderName] != null) { hasBatchSize = true; } - if ((graph[graphScope + RecurrentInPlaceholderName] != null) && - (graph[graphScope + RecurrentOutPlaceholderName] != null)) + if ((graph[RecurrentInPlaceholderName] != null) && + (graph[RecurrentOutPlaceholderName] != null)) { hasRecurrent = true; var runner = session.GetRunner(); - runner.Fetch(graph[graphScope + "memory_size"][0]); + runner.Fetch(graph["memory_size"][0]); var networkOutput = runner.Run()[0].GetValue(); memorySize = (int) networkOutput; } - if (graph[graphScope + VectorObservationPlacholderName] != null) + if (graph[VectorObservationPlacholderName] != null) { hasState = true; } - if (graph[graphScope + PreviousActionPlaceholderName] != null) + if (graph[PreviousActionPlaceholderName] != null) { hasPrevAction = true; } - if (graph[graphScope + "value_estimate"] != null) + if (graph["value_estimate"] != null) { hasValueEstimate = true; } - if (graph[graphScope + ActionMaskPlaceholderName] != null) + if (graph[ActionMaskPlaceholderName] != null) { hasMaskedActions = true; } @@ -304,18 +296,18 @@ public void DecideAction(Dictionary agentInfo) var runner = session.GetRunner(); try { - runner.Fetch(graph[graphScope + ActionPlaceholderName][0]); + runner.Fetch(graph[ActionPlaceholderName][0]); } catch { throw new UnityAgentsException(string.Format( - @"The node {0} could not be found. Please make sure the graphScope {1} is correct", - graphScope + ActionPlaceholderName, graphScope)); + @"The node {0} could not be found. 
Please make sure the node name is correct", + ActionPlaceholderName)); } if (hasBatchSize) { - runner.AddInput(graph[graphScope + BatchSizePlaceholderName][0], new int[] {currentBatchSize}); + runner.AddInput(graph[BatchSizePlaceholderName][0], new int[] {currentBatchSize}); } foreach (TensorFlowAgentPlaceholder placeholder in graphPlaceholders) @@ -324,12 +316,12 @@ public void DecideAction(Dictionary agentInfo) { if (placeholder.valueType == TensorFlowAgentPlaceholder.TensorType.FloatingPoint) { - runner.AddInput(graph[graphScope + placeholder.name][0], + runner.AddInput(graph[placeholder.name][0], new float[] {Random.Range(placeholder.minValue, placeholder.maxValue)}); } else if (placeholder.valueType == TensorFlowAgentPlaceholder.TensorType.Integer) { - runner.AddInput(graph[graphScope + placeholder.name][0], + runner.AddInput(graph[placeholder.name][0], new int[] {Random.Range((int) placeholder.minValue, (int) placeholder.maxValue + 1)}); } } @@ -338,26 +330,26 @@ public void DecideAction(Dictionary agentInfo) throw new UnityAgentsException(string.Format( @"One of the Tensorflow placeholder cound nout be found. In brain {0}, there are no {1} placeholder named {2}.", - brain.gameObject.name, placeholder.valueType.ToString(), graphScope + placeholder.name)); + brain.gameObject.name, placeholder.valueType.ToString(), placeholder.name)); } } // Create the state tensor if (hasState) { - runner.AddInput(graph[graphScope + VectorObservationPlacholderName][0], inputState); + runner.AddInput(graph[VectorObservationPlacholderName][0], inputState); } // Create the previous action tensor if (hasPrevAction) { - runner.AddInput(graph[graphScope + PreviousActionPlaceholderName][0], inputPrevAction); + runner.AddInput(graph[PreviousActionPlaceholderName][0], inputPrevAction); } // Create the mask action tensor if (hasMaskedActions) { - runner.AddInput(graph[graphScope + ActionMaskPlaceholderName][0], maskedActions); + runner.AddInput(graph[ActionMaskPlaceholderName][0], maskedActions); } // Create the observation tensors @@ -366,20 +358,20 @@ public void DecideAction(Dictionary agentInfo) obsNumber < brain.brainParameters.cameraResolutions.Length; obsNumber++) { - runner.AddInput(graph[graphScope + VisualObservationPlaceholderName[obsNumber]][0], + runner.AddInput(graph[VisualObservationPlaceholderName[obsNumber]][0], observationMatrixList[obsNumber]); } if (hasRecurrent) { - runner.AddInput(graph[graphScope + "sequence_length"][0], 1); - runner.AddInput(graph[graphScope + RecurrentInPlaceholderName][0], inputOldMemories); - runner.Fetch(graph[graphScope + RecurrentOutPlaceholderName][0]); + runner.AddInput(graph["sequence_length"][0], 1); + runner.AddInput(graph[RecurrentInPlaceholderName][0], inputOldMemories); + runner.Fetch(graph[RecurrentOutPlaceholderName][0]); } if (hasValueEstimate) { - runner.Fetch(graph[graphScope + "value_estimate"][0]); + runner.Fetch(graph["value_estimate"][0]); } TFTensor[] networkOutput; @@ -504,13 +496,6 @@ public void OnInspector() { EditorGUILayout.HelpBox("Please provide a tensorflow graph as a bytes file.", MessageType.Error); } - - - graphScope = - EditorGUILayout.TextField(new GUIContent("Graph Scope", - "If you set a scope while training your tensorflow model, " + - "all your placeholder name will have a prefix. 
You must specify that prefix here."), graphScope); - if (BatchSizePlaceholderName == "") { BatchSizePlaceholderName = "batch_size"; diff --git a/docs/Basic-Guide.md b/docs/Basic-Guide.md index b620fd14a0..f110993ae9 100644 --- a/docs/Basic-Guide.md +++ b/docs/Basic-Guide.md @@ -192,12 +192,12 @@ INFO:mlagents.envs:Hyperparameters for the PPO Trainer of brain Ball3DBrain: sequence_length: 64 summary_freq: 1000 use_recurrent: False - graph_scope: summary_path: ./summaries/first-run-0 memory_size: 256 use_curiosity: False curiosity_strength: 0.01 curiosity_enc_size: 128 + model_path: ./models/first-run-0/Ball3DBrain INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 1000. Mean Reward: 1.242. Std of Reward: 0.746. Training. INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 2000. Mean Reward: 1.319. Std of Reward: 0.693. Training. INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 3000. Mean Reward: 1.804. Std of Reward: 1.056. Training. diff --git a/docs/FAQ.md b/docs/FAQ.md index 5edbf29b1a..cf9e8ba2e6 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -44,28 +44,6 @@ C# scripts, then adding the `CoreBrain` back. Make sure your brain is in Internal mode, your TensorFlowSharp plugin is imported and the ENABLE_TENSORFLOW flag is set. This fix is only valid locally and unstable. -## Tensorflow epsilon placeholder error - -If you have a graph placeholder set in the Internal Brain inspector that is not -present in the TensorFlow graph, you will see some error like this: - -```console -UnityAgentsException: One of the TensorFlow placeholder could not be found. In brain , there are no FloatingPoint placeholder named . -``` - -Solution: Go to all of your Brain object, find `Graph placeholders` and change -its `size` to 0 to remove the `epsilon` placeholder. - -Similarly, if you have a graph scope set in the Internal Brain inspector that is -not correctly set, you will see some error like this: - -```console -UnityAgentsException: The node /action could not be found. Please make sure the graphScope / is correct -``` - -Solution: Make sure your Graph Scope field matches the corresponding Brain -object name in your Hierarchy Inspector when there are multiple Brains. - ## Environment Permission Error If you directly import your Unity environment without building it in the diff --git a/docs/Learning-Environment-Design-External-Internal-Brains.md b/docs/Learning-Environment-Design-External-Internal-Brains.md index 677c034901..f9b60d3093 100644 --- a/docs/Learning-Environment-Design-External-Internal-Brains.md +++ b/docs/Learning-Environment-Design-External-Internal-Brains.md @@ -81,11 +81,6 @@ which must be set to the .bytes file containing the trained model itself. Only change the following Internal Brain properties if you have created your own TensorFlow model and are not using an ML-Agents model: -* `Graph Scope` : If you set a scope while training your TensorFlow model, all - your placeholder name will have a prefix. You must specify that prefix here. - Note that if more than one Brain were set to external during training, you - must give a `Graph Scope` to the Internal Brain corresponding to the name of - the Brain GameObject. * `Batch Size Node Name` : If the batch size is one of the inputs of your graph, you must specify the name if the placeholder here. 
The Brain will make the batch size equal to the number of Agents connected to the Brain diff --git a/docs/Learning-Environment-Executable.md b/docs/Learning-Environment-Executable.md index 829fea10b0..df0d8f7dde 100644 --- a/docs/Learning-Environment-Executable.md +++ b/docs/Learning-Environment-Executable.md @@ -185,12 +185,12 @@ INFO:mlagents.envs:Hyperparameters for the PPO Trainer of brain Ball3DBrain: sequence_length: 64 summary_freq: 1000 use_recurrent: False - graph_scope: summary_path: ./summaries/first-run-0 memory_size: 256 use_curiosity: False curiosity_strength: 0.01 curiosity_enc_size: 128 + model_path: ./models/first-run-0/Ball3DBrain INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 1000. Mean Reward: 1.242. Std of Reward: 0.746. Training. INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 2000. Mean Reward: 1.319. Std of Reward: 0.693. Training. INFO:mlagents.trainers: first-run-0: Ball3DBrain: Step: 3000. Mean Reward: 1.804. Std of Reward: 1.056. Training. diff --git a/docs/Using-TensorFlow-Sharp-in-Unity.md b/docs/Using-TensorFlow-Sharp-in-Unity.md index 2e3f2246cf..e2ed0f8fc9 100644 --- a/docs/Using-TensorFlow-Sharp-in-Unity.md +++ b/docs/Using-TensorFlow-Sharp-in-Unity.md @@ -93,10 +93,6 @@ Your model will be saved with the name `your_name_graph.bytes` and will contain both the graph and associated weights. Note that you must save your graph as a .bytes file so Unity can load it. -In the Unity Editor, you must specify the names of the nodes used by your graph -in the **Internal** Brain Inspector window. If you used a scope when defining -your graph, specify it in the `Graph Scope` field. - ![Internal Brain Inspector](images/internal_brain.png) See diff --git a/ml-agents/mlagents/trainers/bc/models.py b/ml-agents/mlagents/trainers/bc/models.py index 25e7c3ccf6..d06f33fb18 100644 --- a/ml-agents/mlagents/trainers/bc/models.py +++ b/ml-agents/mlagents/trainers/bc/models.py @@ -5,52 +5,51 @@ class BehavioralCloningModel(LearningModel): def __init__(self, brain, h_size=128, lr=1e-4, n_layers=2, m_size=128, - normalize=False, use_recurrent=False, scope='PPO', seed=0): - with tf.variable_scope(scope): - LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed) - num_streams = 1 - hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers) - hidden = hidden_streams[0] - self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate") - hidden_reg = tf.layers.dropout(hidden, self.dropout_rate) - if self.use_recurrent: - tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32) - self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in') - hidden_reg, self.memory_out = self.create_recurrent_encoder(hidden_reg, self.memory_in, - self.sequence_length) - self.memory_out = tf.identity(self.memory_out, name='recurrent_out') + normalize=False, use_recurrent=False, seed=0): + LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed) + num_streams = 1 + hidden_streams = self.create_observation_streams(num_streams, h_size, n_layers) + hidden = hidden_streams[0] + self.dropout_rate = tf.placeholder(dtype=tf.float32, shape=[], name="dropout_rate") + hidden_reg = tf.layers.dropout(hidden, self.dropout_rate) + if self.use_recurrent: + tf.Variable(self.m_size, name="memory_size", trainable=False, dtype=tf.int32) + self.memory_in = tf.placeholder(shape=[None, self.m_size], dtype=tf.float32, name='recurrent_in') + hidden_reg, self.memory_out = 
self.create_recurrent_encoder(hidden_reg, self.memory_in, + self.sequence_length) + self.memory_out = tf.identity(self.memory_out, name='recurrent_out') - if brain.vector_action_space_type == "discrete": - policy_branches = [] - for size in self.act_size: - policy_branches.append( - tf.layers.dense( - hidden, - size, - activation=None, - use_bias=False, - kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))) - self.action_probs = tf.concat( - [tf.nn.softmax(branch) for branch in policy_branches], axis=1, name="action_probs") - self.action_masks = tf.placeholder(shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks") - self.sample_action_float, _ = self.create_discrete_action_masking_layer( - tf.concat(policy_branches, axis = 1), self.action_masks, self.act_size) - self.sample_action_float = tf.identity(self.sample_action_float, name="action") - self.sample_action = tf.cast(self.sample_action_float, tf.int32) - self.true_action = tf.placeholder(shape=[None, len(policy_branches)], dtype=tf.int32, name="teacher_action") - self.action_oh = tf.concat([ - tf.one_hot(self.true_action[:, i], self.act_size[i]) for i in range(len(self.act_size))], axis=1) - self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh) - self.action_percent = tf.reduce_mean(tf.cast( - tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32)) - else: - self.policy = tf.layers.dense(hidden_reg, self.act_size[0], activation=None, use_bias=False, name='pre_action', - kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) - self.clipped_sample_action = tf.clip_by_value(self.policy, -1, 1) - self.sample_action = tf.identity(self.clipped_sample_action, name="action") - self.true_action = tf.placeholder(shape=[None, self.act_size[0]], dtype=tf.float32, name="teacher_action") - self.clipped_true_action = tf.clip_by_value(self.true_action, -1, 1) - self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action)) + if brain.vector_action_space_type == "discrete": + policy_branches = [] + for size in self.act_size: + policy_branches.append( + tf.layers.dense( + hidden, + size, + activation=None, + use_bias=False, + kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01))) + self.action_probs = tf.concat( + [tf.nn.softmax(branch) for branch in policy_branches], axis=1, name="action_probs") + self.action_masks = tf.placeholder(shape=[None, sum(self.act_size)], dtype=tf.float32, name="action_masks") + self.sample_action_float, _ = self.create_discrete_action_masking_layer( + tf.concat(policy_branches, axis = 1), self.action_masks, self.act_size) + self.sample_action_float = tf.identity(self.sample_action_float, name="action") + self.sample_action = tf.cast(self.sample_action_float, tf.int32) + self.true_action = tf.placeholder(shape=[None, len(policy_branches)], dtype=tf.int32, name="teacher_action") + self.action_oh = tf.concat([ + tf.one_hot(self.true_action[:, i], self.act_size[i]) for i in range(len(self.act_size))], axis=1) + self.loss = tf.reduce_sum(-tf.log(self.action_probs + 1e-10) * self.action_oh) + self.action_percent = tf.reduce_mean(tf.cast( + tf.equal(tf.cast(tf.argmax(self.action_probs, axis=1), tf.int32), self.sample_action), tf.float32)) + else: + self.policy = tf.layers.dense(hidden_reg, self.act_size[0], activation=None, use_bias=False, name='pre_action', + kernel_initializer=c_layers.variance_scaling_initializer(factor=0.01)) + self.clipped_sample_action = 
tf.clip_by_value(self.policy, -1, 1) + self.sample_action = tf.identity(self.clipped_sample_action, name="action") + self.true_action = tf.placeholder(shape=[None, self.act_size[0]], dtype=tf.float32, name="teacher_action") + self.clipped_true_action = tf.clip_by_value(self.true_action, -1, 1) + self.loss = tf.reduce_sum(tf.squared_difference(self.clipped_true_action, self.sample_action)) - optimizer = tf.train.AdamOptimizer(learning_rate=lr) - self.update = optimizer.minimize(self.loss) + optimizer = tf.train.AdamOptimizer(learning_rate=lr) + self.update = optimizer.minimize(self.loss) diff --git a/ml-agents/mlagents/trainers/bc/policy.py b/ml-agents/mlagents/trainers/bc/policy.py index 585d8eb6fe..1eac4a87a9 100644 --- a/ml-agents/mlagents/trainers/bc/policy.py +++ b/ml-agents/mlagents/trainers/bc/policy.py @@ -8,25 +8,31 @@ class BCPolicy(Policy): - def __init__(self, seed, brain, trainer_parameters, sess): + def __init__(self, seed, brain, trainer_parameters, load): """ :param seed: Random seed. :param brain: Assigned Brain object. :param trainer_parameters: Defined training parameters. - :param sess: TensorFlow session. + :param load: Whether a pre-trained model will be loaded or a new one created. """ - super().__init__(seed, brain, trainer_parameters, sess) + super().__init__(seed, brain, trainer_parameters) - self.model = BehavioralCloningModel( - h_size=int(trainer_parameters['hidden_units']), - lr=float(trainer_parameters['learning_rate']), - n_layers=int(trainer_parameters['num_layers']), - m_size=self.m_size, - normalize=False, - use_recurrent=trainer_parameters['use_recurrent'], - brain=brain, - scope=self.variable_scope, - seed=seed) + with self.graph.as_default(): + with self.graph.as_default(): + self.model = BehavioralCloningModel( + h_size=int(trainer_parameters['hidden_units']), + lr=float(trainer_parameters['learning_rate']), + n_layers=int(trainer_parameters['num_layers']), + m_size=self.m_size, + normalize=False, + use_recurrent=trainer_parameters['use_recurrent'], + brain=brain, + seed=seed) + + if load: + self._load_graph() + else: + self._initialize_graph() self.inference_dict = {'action': self.model.sample_action} self.update_dict = {'policy_loss': self.model.loss, diff --git a/ml-agents/mlagents/trainers/bc/trainer.py b/ml-agents/mlagents/trainers/bc/trainer.py index 5c677259cf..5e2f045bec 100644 --- a/ml-agents/mlagents/trainers/bc/trainer.py +++ b/ml-agents/mlagents/trainers/bc/trainer.py @@ -19,27 +19,31 @@ class BehavioralCloningTrainer(Trainer): """The ImitationTrainer is an implementation of the imitation learning.""" - def __init__(self, sess, brain, trainer_parameters, training, seed, run_id): + def __init__(self, brain, trainer_parameters, training, load, seed, run_id): """ Responsible for collecting experiences and training PPO model. - :param sess: Tensorflow session. :param trainer_parameters: The parameters for the trainer (dictionary). :param training: Whether the trainer is set for training. + :param load: Whether the model should be loaded. 
+ :param seed: The seed the model will be initialized with + :param run_id: The The identifier of the current run """ - super(BehavioralCloningTrainer, self).__init__(sess, brain, trainer_parameters, training, run_id) - self.param_keys = ['brain_to_imitate', 'batch_size', 'time_horizon', - 'graph_scope', 'summary_freq', 'max_steps', + 'summary_freq', 'max_steps', 'batches_per_epoch', 'use_recurrent', - 'hidden_units','learning_rate', 'num_layers', - 'sequence_length', 'memory_size'] + 'hidden_units', 'learning_rate', 'num_layers', + 'sequence_length', 'memory_size', 'model_path'] for k in self.param_keys: + print(k) + print(k not in trainer_parameters) if k not in trainer_parameters: raise UnityTrainerException("The hyperparameter {0} could not be found for the Imitation trainer of " "brain {1}.".format(k, brain.brain_name)) - self.policy = BCPolicy(seed, brain, trainer_parameters, sess) + super(BehavioralCloningTrainer, self).__init__(brain, trainer_parameters, training, run_id) + + self.policy = BCPolicy(seed, brain, trainer_parameters, load) self.brain_name = brain.brain_name self.brain_to_imitate = trainer_parameters['brain_to_imitate'] self.batches_per_epoch = trainer_parameters['batches_per_epoch'] diff --git a/ml-agents/mlagents/trainers/policy.py b/ml-agents/mlagents/trainers/policy.py index 3e8095e8b6..d545652143 100644 --- a/ml-agents/mlagents/trainers/policy.py +++ b/ml-agents/mlagents/trainers/policy.py @@ -1,9 +1,12 @@ import logging import numpy as np +import tensorflow as tf from mlagents.trainers import UnityException from mlagents.trainers.models import LearningModel +from tensorflow.python.tools import freeze_graph + logger = logging.getLogger("mlagents.trainers") @@ -19,14 +22,15 @@ class Policy(object): Contains a learning model, and the necessary functions to interact with it to perform evaluate and updating. """ + possible_output_nodes = ['action', 'value_estimate', + 'action_probs', 'recurrent_out', 'memory_size'] - def __init__(self, seed, brain, trainer_parameters, sess): + def __init__(self, seed, brain, trainer_parameters): """ Initialized the policy. :param seed: Random seed to use for TensorFlow. :param brain: The corresponding Brain for this policy. :param trainer_parameters: The trainer parameters. - :param sess: The current TensorFlow session. """ self.m_size = None self.model = None @@ -35,10 +39,15 @@ def __init__(self, seed, brain, trainer_parameters, sess): self.sequence_length = 1 self.seed = seed self.brain = brain - self.variable_scope = trainer_parameters['graph_scope'] self.use_recurrent = trainer_parameters["use_recurrent"] self.use_continuous_act = (brain.vector_action_space_type == "continuous") - self.sess = sess + self.model_path = trainer_parameters["model_path"] + self.keep_checkpoints = trainer_parameters.get("keep_checkpoints", 5) + self.graph = tf.Graph() + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + self.sess = tf.Session(config=config, graph=self.graph) + self.saver = None if self.use_recurrent: self.m_size = trainer_parameters["memory_size"] self.sequence_length = trainer_parameters["sequence_length"] @@ -51,6 +60,24 @@ def __init__(self, seed, brain, trainer_parameters, sess): "but it must be divisible by 4." 
.format(brain.brain_name, self.m_size)) + def _initialize_graph(self): + with self.graph.as_default(): + self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints) + init = tf.global_variables_initializer() + self.sess.run(init) + + def _load_graph(self): + with self.graph.as_default(): + self.saver = tf.train.Saver(max_to_keep=self.keep_checkpoints) + logger.info('Loading Model for brain {}'.format(self.brain.brain_name)) + ckpt = tf.train.get_checkpoint_state(self.model_path) + if ckpt is None: + logger.info('The model {0} could not be found. Make ' + 'sure you specified the right ' + '--run-id' + .format(self.model_path)) + self.saver.restore(self.sess, ckpt.model_checkpoint_path) + def evaluate(self, brain_info): """ Evaluates policy for the agent experiences provided. @@ -96,13 +123,6 @@ def make_empty_memory(self, num_agents): """ return np.zeros((num_agents, self.m_size)) - @property - def graph_scope(self): - """ - Returns the graph scope of the trainer. - """ - return self.variable_scope - def get_current_step(self): """ Gets current model step. @@ -129,6 +149,47 @@ def get_update_vars(self): """ return list(self.update_dict.keys()) + def save_model(self, steps): + """ + Saves the model + :param steps: The number of steps the model was trained for + :return: + """ + with self.graph.as_default(): + last_checkpoint = self.model_path + '/model-' + str(steps) + '.cptk' + self.saver.save(self.sess, last_checkpoint) + tf.train.write_graph(self.graph, self.model_path, + 'raw_graph_def.pb', as_text=False) + + def export_model(self): + """ + Exports latest saved model to .bytes format for Unity embedding. + """ + with self.graph.as_default(): + target_nodes = ','.join(self._process_graph()) + ckpt = tf.train.get_checkpoint_state(self.model_path) + freeze_graph.freeze_graph( + input_graph=self.model_path + '/raw_graph_def.pb', + input_binary=True, + input_checkpoint=ckpt.model_checkpoint_path, + output_node_names=target_nodes, + output_graph=(self.model_path + '.bytes'), + clear_devices=True, initializer_nodes='', input_saver='', + restore_op_name='save/restore_all', + filename_tensor_name='save/Const:0') + + def _process_graph(self): + """ + Gets the list of the output nodes present in the graph for inference + :return: list of node names + """ + all_nodes = [x.name for x in self.graph.as_graph_def().node] + nodes = [x for x in all_nodes if x in self.possible_output_nodes] + logger.info('List of nodes to export for brain :' + self.brain.brain_name) + for n in nodes: + logger.info('\t' + n) + return nodes + @property def vis_obs_size(self): return self.model.vis_obs_size diff --git a/ml-agents/mlagents/trainers/ppo/models.py b/ml-agents/mlagents/trainers/ppo/models.py index 86cf0d03dd..0c7b13329e 100644 --- a/ml-agents/mlagents/trainers/ppo/models.py +++ b/ml-agents/mlagents/trainers/ppo/models.py @@ -10,7 +10,7 @@ class PPOModel(LearningModel): def __init__(self, brain, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step=5e6, normalize=False, use_recurrent=False, num_layers=2, m_size=None, use_curiosity=False, - curiosity_strength=0.01, curiosity_enc_size=128, scope='Model', seed=0): + curiosity_strength=0.01, curiosity_enc_size=128, seed=0): """ Takes a Unity environment and model-specific hyper-parameters and returns the appropriate PPO agent model for the environment. 
@@ -26,25 +26,24 @@ def __init__(self, brain, lr=1e-4, h_size=128, epsilon=0.2, beta=1e-3, max_step= :param num_layers Number of hidden layers between encoded input and policy & value layers :param m_size: Size of brain memory. """ - with tf.variable_scope(scope): - LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed) - self.use_curiosity = use_curiosity - if num_layers < 1: - num_layers = 1 - self.last_reward, self.new_reward, self.update_reward = self.create_reward_encoder() - if brain.vector_action_space_type == "continuous": - self.create_cc_actor_critic(h_size, num_layers) - self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy - else: - self.create_dc_actor_critic(h_size, num_layers) - if self.use_curiosity: - self.curiosity_enc_size = curiosity_enc_size - self.curiosity_strength = curiosity_strength - encoded_state, encoded_next_state = self.create_curiosity_encoders() - self.create_inverse_model(encoded_state, encoded_next_state) - self.create_forward_model(encoded_state, encoded_next_state) - self.create_ppo_optimizer(self.log_probs, self.old_log_probs, self.value, - self.entropy, beta, epsilon, lr, max_step) + LearningModel.__init__(self, m_size, normalize, use_recurrent, brain, seed) + self.use_curiosity = use_curiosity + if num_layers < 1: + num_layers = 1 + self.last_reward, self.new_reward, self.update_reward = self.create_reward_encoder() + if brain.vector_action_space_type == "continuous": + self.create_cc_actor_critic(h_size, num_layers) + self.entropy = tf.ones_like(tf.reshape(self.value, [-1])) * self.entropy + else: + self.create_dc_actor_critic(h_size, num_layers) + if self.use_curiosity: + self.curiosity_enc_size = curiosity_enc_size + self.curiosity_strength = curiosity_strength + encoded_state, encoded_next_state = self.create_curiosity_encoders() + self.create_inverse_model(encoded_state, encoded_next_state) + self.create_forward_model(encoded_state, encoded_next_state) + self.create_ppo_optimizer(self.log_probs, self.old_log_probs, self.value, + self.entropy, beta, epsilon, lr, max_step) @staticmethod def create_reward_encoder(): diff --git a/ml-agents/mlagents/trainers/ppo/policy.py b/ml-agents/mlagents/trainers/ppo/policy.py index d810e89eff..f49f14ccf9 100644 --- a/ml-agents/mlagents/trainers/ppo/policy.py +++ b/ml-agents/mlagents/trainers/ppo/policy.py @@ -1,6 +1,5 @@ import logging -import numpy as np from mlagents.trainers.ppo.models import PPOModel from mlagents.trainers.policy import Policy @@ -8,32 +7,39 @@ class PPOPolicy(Policy): - def __init__(self, seed, brain, trainer_params, sess, is_training): + def __init__(self, seed, brain, trainer_params, is_training, load): """ Policy for Proximal Policy Optimization Networks. :param seed: Random seed. :param brain: Assigned Brain object. :param trainer_params: Defined training parameters. - :param sess: TensorFlow session. :param is_training: Whether the model should be trained. + :param load: Whether a pre-trained model will be loaded or a new one created. 
""" - super().__init__(seed, brain, trainer_params, sess) + super().__init__(seed, brain, trainer_params) self.has_updated = False self.use_curiosity = bool(trainer_params['use_curiosity']) - self.model = PPOModel(brain, - lr=float(trainer_params['learning_rate']), - h_size=int(trainer_params['hidden_units']), - epsilon=float(trainer_params['epsilon']), - beta=float(trainer_params['beta']), - max_step=float(trainer_params['max_steps']), - normalize=trainer_params['normalize'], - use_recurrent=trainer_params['use_recurrent'], - num_layers=int(trainer_params['num_layers']), - m_size=self.m_size, - use_curiosity=bool(trainer_params['use_curiosity']), - curiosity_strength=float(trainer_params['curiosity_strength']), - curiosity_enc_size=float(trainer_params['curiosity_enc_size']), - scope=self.variable_scope, seed=seed) + + with self.graph.as_default(): + self.model = PPOModel(brain, + lr=float(trainer_params['learning_rate']), + h_size=int(trainer_params['hidden_units']), + epsilon=float(trainer_params['epsilon']), + beta=float(trainer_params['beta']), + max_step=float(trainer_params['max_steps']), + normalize=trainer_params['normalize'], + use_recurrent=trainer_params['use_recurrent'], + num_layers=int(trainer_params['num_layers']), + m_size=self.m_size, + use_curiosity=bool(trainer_params['use_curiosity']), + curiosity_strength=float(trainer_params['curiosity_strength']), + curiosity_enc_size=float(trainer_params['curiosity_enc_size']), + seed=seed) + + if load: + self._load_graph() + else: + self._initialize_graph() self.inference_dict = {'action': self.model.output, 'log_probs': self.model.all_log_probs, 'value': self.model.value, 'entropy': self.model.entropy, diff --git a/ml-agents/mlagents/trainers/ppo/trainer.py b/ml-agents/mlagents/trainers/ppo/trainer.py index 80040f3944..2d83c7dd9e 100644 --- a/ml-agents/mlagents/trainers/ppo/trainer.py +++ b/ml-agents/mlagents/trainers/ppo/trainer.py @@ -20,32 +20,33 @@ class PPOTrainer(Trainer): """The PPOTrainer is an implementation of the PPO algorithm.""" - def __init__(self, sess, brain, reward_buff_cap, trainer_parameters, training, seed, run_id): + def __init__(self, brain, reward_buff_cap, trainer_parameters, training, load, seed, run_id): """ Responsible for collecting experiences and training PPO model. - :param sess: Tensorflow session. - :param trainer_parameters: The parameters for the trainer (dictionary). + :param trainer_parameters: The parameters for the trainer (dictionary). :param training: Whether the trainer is set for training. + :param load: Whether the model should be loaded. 
+ :param seed: The seed the model will be initialized with + :param run_id: The The identifier of the current run """ - super(PPOTrainer, self).__init__(sess, brain.brain_name, trainer_parameters, training, run_id) - self.param_keys = ['batch_size', 'beta', 'buffer_size', 'epsilon', 'gamma', 'hidden_units', 'lambd', 'learning_rate', 'max_steps', 'normalize', 'num_epoch', 'num_layers', 'time_horizon', 'sequence_length', 'summary_freq', 'use_recurrent', - 'graph_scope', 'summary_path', 'memory_size', 'use_curiosity', 'curiosity_strength', - 'curiosity_enc_size'] + 'summary_path', 'memory_size', 'use_curiosity', 'curiosity_strength', + 'curiosity_enc_size', 'model_path'] for k in self.param_keys: if k not in trainer_parameters: raise UnityTrainerException("The hyperparameter {0} could not be found for the PPO trainer of " "brain {1}.".format(k, brain.brain_name)) + super(PPOTrainer, self).__init__(brain.brain_name, trainer_parameters, training, run_id) self.use_curiosity = bool(trainer_parameters['use_curiosity']) self.step = 0 self.policy = PPOPolicy(seed, brain, trainer_parameters, - sess, self.is_training) + self.is_training, load) stats = {'cumulative_reward': [], 'episode_length': [], 'value_estimate': [], 'entropy': [], 'value_loss': [], 'policy_loss': [], 'learning_rate': []} diff --git a/ml-agents/mlagents/trainers/trainer.py b/ml-agents/mlagents/trainers/trainer.py index c8edc3d1c8..8c18f0b838 100644 --- a/ml-agents/mlagents/trainers/trainer.py +++ b/ml-agents/mlagents/trainers/trainer.py @@ -19,20 +19,20 @@ class UnityTrainerException(UnityException): class Trainer(object): """This class is the abstract class for the mlagents.trainers""" - def __init__(self, sess, brain_name, trainer_parameters, training, run_id): + def __init__(self, brain_name, trainer_parameters, training, run_id): """ Responsible for collecting experiences and training a neural network model. - :param sess: Tensorflow session. :param trainer_parameters: The parameters for the trainer (dictionary). :param training: Whether the trainer is set for training. + :param run_id: The identifier of the current run """ - self.sess = sess self.brain_name = brain_name self.run_id = run_id self.trainer_parameters = trainer_parameters self.is_training = training self.stats = {} self.summary_writer = None + self.policy = None def __str__(self): return '''Empty Trainer''' @@ -130,6 +130,19 @@ def update_policy(self): """ raise UnityTrainerException("The update_model method was not implemented.") + def save_model(self, steps): + """ + Saves the model + :param steps: The number of steps of training + """ + self.policy.save_model(steps) + + def export_model(self): + """ + Exports the model + """ + self.policy.export_model() + def write_summary(self, global_step, lesson_num=0): """ Saves training statistics to Tensorboard. @@ -166,10 +179,11 @@ def write_tensorboard_text(self, key, input_dict): :param input_dict: A dictionary that will be displayed in a table on Tensorboard. """ try: - s_op = tf.summary.text(key, tf.convert_to_tensor( - ([[str(x), str(input_dict[x])] for x in input_dict]))) - s = self.sess.run(s_op) - self.summary_writer.add_summary(s, self.get_step) + with tf.Session() as sess: + s_op = tf.summary.text(key, tf.convert_to_tensor( + ([[str(x), str(input_dict[x])] for x in input_dict]))) + s = sess.run(s_op) + self.summary_writer.add_summary(s, self.get_step) except: logger.info( "Cannot write text summary for Tensorboard. 
Tensorflow version must be r1.2 or above.") diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index b4d2e4f854..47d7e97c2c 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -140,81 +140,38 @@ def _get_measure_vals(self): else: return None - def _process_graph(self): - nodes = [] - scopes = [] - for brain_name in self.trainers.keys(): - if self.trainers[brain_name].policy.graph_scope is not None: - scope = self.trainers[brain_name].policy.graph_scope + '/' - if scope == '/': - scope = '' - scopes += [scope] - if self.trainers[brain_name].parameters['trainer'] \ - == 'imitation': - nodes += [scope + x for x in ['action']] - else: - nodes += [scope + x for x in ['action', 'value_estimate', - 'action_probs', - 'value_estimate']] - if self.trainers[brain_name].parameters['use_recurrent']: - nodes += [scope + x for x in ['recurrent_out', - 'memory_size']] - if len(scopes) > 1: - self.logger.info('List of available scopes :') - for scope in scopes: - self.logger.info('\t' + scope) - self.logger.info('List of nodes to export :') - for n in nodes: - self.logger.info('\t' + n) - return nodes - - def _save_model(self, sess, saver, steps=0): + def _save_model(self,steps=0): """ Saves current model to checkpoint folder. - :param sess: Current Tensorflow session. :param steps: Current number of steps in training process. :param saver: Tensorflow saver for session. """ - last_checkpoint = self.model_path + '/model-' + str(steps) + '.cptk' - saver.save(sess, last_checkpoint) - tf.train.write_graph(sess.graph_def, self.model_path, - 'raw_graph_def.pb', as_text=False) + for brain_name in self.trainers.keys(): + self.trainers[brain_name].save_model(steps) self.logger.info('Saved Model') def _export_graph(self): """ - Exports latest saved model to .bytes format for Unity embedding. + Exports latest saved models to .bytes format for Unity embedding. """ - target_nodes = ','.join(self._process_graph()) - ckpt = tf.train.get_checkpoint_state(self.model_path) - freeze_graph.freeze_graph( - input_graph=self.model_path + '/raw_graph_def.pb', - input_binary=True, - input_checkpoint=ckpt.model_checkpoint_path, - output_node_names=target_nodes, - output_graph=(self.model_path + '/' + self.env_name + '_' - + self.run_id + '.bytes'), - clear_devices=True, initializer_nodes='', input_saver='', - restore_op_name='save/restore_all', - filename_tensor_name='save/Const:0') + for brain_name in self.trainers.keys(): + self.trainers[brain_name].export_model() - def _initialize_trainers(self, trainer_config, sess): + def _initialize_trainers(self, trainer_config): + """ + Initialization of the trainers + :param trainer_config: The configurations of the trainers + """ trainer_parameters_dict = {} - # TODO: This probably doesn't need to be reinitialized. 
- self.trainers = {} for brain_name in self.env.external_brain_names: trainer_parameters = trainer_config['default'].copy() - if len(self.env.external_brain_names) > 1: - graph_scope = re.sub('[^0-9a-zA-Z]+', '-', brain_name) - trainer_parameters['graph_scope'] = graph_scope - trainer_parameters['summary_path'] = '{basedir}/{name}'.format( - basedir=self.summaries_dir, - name=str(self.run_id) + '_' + graph_scope) - else: - trainer_parameters['graph_scope'] = '' - trainer_parameters['summary_path'] = '{basedir}/{name}'.format( - basedir=self.summaries_dir, - name=str(self.run_id)) + trainer_parameters['summary_path'] = '{basedir}/{name}'.format( + basedir=self.summaries_dir, + name=str(self.run_id) + '_' + brain_name) + trainer_parameters['model_path'] = '{basedir}/{name}'.format( + basedir=self.model_path, + name=brain_name) + trainer_parameters['keep_checkpoints'] = self.keep_checkpoints if brain_name in trainer_config: _brain_key = brain_name while not isinstance(trainer_config[_brain_key], dict): @@ -225,17 +182,17 @@ def _initialize_trainers(self, trainer_config, sess): for brain_name in self.env.external_brain_names: if trainer_parameters_dict[brain_name]['trainer'] == 'imitation': self.trainers[brain_name] = BehavioralCloningTrainer( - sess, self.env.brains[brain_name], + self.env.brains[brain_name], trainer_parameters_dict[brain_name], self.train_model, - self.seed, self.run_id) + self.load_model, self.seed, self.run_id) elif trainer_parameters_dict[brain_name]['trainer'] == 'ppo': self.trainers[brain_name] = PPOTrainer( - sess, self.env.brains[brain_name], + self.env.brains[brain_name], self.meta_curriculum .brains_to_curriculums[brain_name] .min_lesson_length if self.meta_curriculum else 0, trainer_parameters_dict[brain_name], - self.train_model, self.seed, self.run_id) + self.train_model, self.load_model, self.seed, self.run_id) else: raise UnityEnvironmentException('The trainer config contains ' 'an unknown trainer type for ' @@ -292,117 +249,100 @@ def start_learning(self): tf.reset_default_graph() # Prevent a single session from taking all GPU memory. - config = tf.ConfigProto() - config.gpu_options.allow_growth = True - with tf.Session(config=config) as sess: - self._initialize_trainers(trainer_config, sess) - for _, t in self.trainers.items(): - self.logger.info(t) - init = tf.global_variables_initializer() - saver = tf.train.Saver(max_to_keep=self.keep_checkpoints) - # Instantiate model parameters - if self.load_model: - self.logger.info('Loading Model...') - ckpt = tf.train.get_checkpoint_state(self.model_path) - if ckpt is None: - self.logger.info('The model {0} could not be found. Make ' - 'sure you specified the right ' - '--run-id' - .format(self.model_path)) - saver.restore(sess, ckpt.model_checkpoint_path) - else: - sess.run(init) - global_step = 0 # This is only for saving the model - curr_info = self._reset_env() - if self.train_model: - for brain_name, trainer in self.trainers.items(): - trainer.write_tensorboard_text('Hyperparameters', - trainer.parameters) - try: - while any([t.get_step <= t.get_max_steps \ - for k, t in self.trainers.items()]) \ - or not self.train_model: - if self.meta_curriculum: - # Get the sizes of the reward buffers. - reward_buff_sizes = {k:len(t.reward_buffer) \ - for (k,t) in self.trainers.items()} - # Attempt to increment the lessons of the brains who - # were ready. 
- lessons_incremented = \ - self.meta_curriculum.increment_lessons( - self._get_measure_vals(), - reward_buff_sizes=reward_buff_sizes) - - # If any lessons were incremented or the environment is - # ready to be reset - if (self.meta_curriculum - and any(lessons_incremented.values())): - curr_info = self._reset_env() - for brain_name, trainer in self.trainers.items(): - trainer.end_episode() - for brain_name, changed in lessons_incremented.items(): - if changed: - self.trainers[brain_name].reward_buffer.clear() - elif self.env.global_done: - curr_info = self._reset_env() - for brain_name, trainer in self.trainers.items(): - trainer.end_episode() + self._initialize_trainers(trainer_config) + for _, t in self.trainers.items(): + self.logger.info(t) + global_step = 0 # This is only for saving the model + curr_info = self._reset_env() + if self.train_model: + for brain_name, trainer in self.trainers.items(): + trainer.write_tensorboard_text('Hyperparameters', + trainer.parameters) + try: + while any([t.get_step <= t.get_max_steps \ + for k, t in self.trainers.items()]) \ + or not self.train_model: + if self.meta_curriculum: + # Get the sizes of the reward buffers. + reward_buff_sizes = {k:len(t.reward_buffer) \ + for (k,t) in self.trainers.items()} + # Attempt to increment the lessons of the brains who + # were ready. + lessons_incremented = \ + self.meta_curriculum.increment_lessons( + self._get_measure_vals(), + reward_buff_sizes=reward_buff_sizes) - # Decide and take an action - take_action_vector, \ - take_action_memories, \ - take_action_text, \ - take_action_value, \ - take_action_outputs \ - = {}, {}, {}, {}, {} + # If any lessons were incremented or the environment is + # ready to be reset + if (self.meta_curriculum + and any(lessons_incremented.values())): + curr_info = self._reset_env() for brain_name, trainer in self.trainers.items(): - (take_action_vector[brain_name], - take_action_memories[brain_name], - take_action_text[brain_name], - take_action_value[brain_name], - take_action_outputs[brain_name]) = \ - trainer.take_action(curr_info) - new_info = self.env.step(vector_action=take_action_vector, - memory=take_action_memories, - text_action=take_action_text, - value=take_action_value) + trainer.end_episode() + for brain_name, changed in lessons_incremented.items(): + if changed: + self.trainers[brain_name].reward_buffer.clear() + elif self.env.global_done: + curr_info = self._reset_env() for brain_name, trainer in self.trainers.items(): - trainer.add_experiences(curr_info, new_info, - take_action_outputs[brain_name]) - trainer.process_experiences(curr_info, new_info) - if trainer.is_ready_update() and self.train_model \ - and trainer.get_step <= trainer.get_max_steps: - # Perform gradient descent with experience buffer - trainer.update_policy() - # Write training statistics to Tensorboard. 
- if self.meta_curriculum is not None: - trainer.write_summary( - global_step, - lesson_num=self.meta_curriculum - .brains_to_curriculums[brain_name] - .lesson_num) - else: - trainer.write_summary(global_step) - if self.train_model \ - and trainer.get_step <= trainer.get_max_steps: - trainer.increment_step_and_update_last_reward() - global_step += 1 - if global_step % self.save_freq == 0 and global_step != 0 \ - and self.train_model: - # Save Tensorflow model - self._save_model(sess, steps=global_step, saver=saver) - curr_info = new_info - # Final save Tensorflow model - if global_step != 0 and self.train_model: - self._save_model(sess, steps=global_step, saver=saver) - except KeyboardInterrupt: - print('--------------------------Now saving model--------------' - '-----------') - if self.train_model: - self.logger.info('Learning was interrupted. Please wait ' - 'while the graph is generated.') - self._save_model(sess, steps=global_step, saver=saver) - pass + trainer.end_episode() + + # Decide and take an action + take_action_vector, \ + take_action_memories, \ + take_action_text, \ + take_action_value, \ + take_action_outputs \ + = {}, {}, {}, {}, {} + for brain_name, trainer in self.trainers.items(): + (take_action_vector[brain_name], + take_action_memories[brain_name], + take_action_text[brain_name], + take_action_value[brain_name], + take_action_outputs[brain_name]) = \ + trainer.take_action(curr_info) + new_info = self.env.step(vector_action=take_action_vector, + memory=take_action_memories, + text_action=take_action_text, + value=take_action_value) + for brain_name, trainer in self.trainers.items(): + trainer.add_experiences(curr_info, new_info, + take_action_outputs[brain_name]) + trainer.process_experiences(curr_info, new_info) + if trainer.is_ready_update() and self.train_model \ + and trainer.get_step <= trainer.get_max_steps: + # Perform gradient descent with experience buffer + trainer.update_policy() + # Write training statistics to Tensorboard. + if self.meta_curriculum is not None: + trainer.write_summary( + global_step, + lesson_num=self.meta_curriculum + .brains_to_curriculums[brain_name] + .lesson_num) + else: + trainer.write_summary(global_step) + if self.train_model \ + and trainer.get_step <= trainer.get_max_steps: + trainer.increment_step_and_update_last_reward() + global_step += 1 + if global_step % self.save_freq == 0 and global_step != 0 \ + and self.train_model: + # Save Tensorflow model + self._save_model(steps=global_step) + curr_info = new_info + # Final save Tensorflow model + if global_step != 0 and self.train_model: + self._save_model(steps=global_step) + except KeyboardInterrupt: + print('--------------------------Now saving model--------------' + '-----------') + if self.train_model: + self.logger.info('Learning was interrupted. 
Please wait ' + 'while the graph is generated.') + self._save_model(steps=global_step) + pass self.env.close() if self.train_model: self._export_graph() diff --git a/ml-agents/tests/trainers/test_bc.py b/ml-agents/tests/trainers/test_bc.py index ede499aeaa..31ad019b43 100644 --- a/ml-agents/tests/trainers/test_bc.py +++ b/ml-agents/tests/trainers/test_bc.py @@ -28,21 +28,19 @@ def dummy_config(): @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') def test_bc_policy_evaluate(mock_communicator, mock_launcher): tf.reset_default_graph() - with tf.Session() as sess: - mock_communicator.return_value = MockCommunicator( - discrete_action=False, visual_inputs=0) - env = UnityEnvironment(' ') - brain_infos = env.reset() - brain_info = brain_infos[env.brain_names[0]] - - trainer_parameters = dummy_config() - graph_scope = env.brain_names[0] - trainer_parameters['graph_scope'] = graph_scope - policy = BCPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, sess) - init = tf.global_variables_initializer() - sess.run(init) - run_out = policy.evaluate(brain_info) - assert run_out['action'].shape == (3, 2) + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0) + env = UnityEnvironment(' ') + brain_infos = env.reset() + brain_info = brain_infos[env.brain_names[0]] + + trainer_parameters = dummy_config() + model_path = env.brain_names[0] + trainer_parameters['model_path'] = model_path + trainer_parameters['keep_checkpoints'] = 3 + policy = BCPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, False) + run_out = policy.evaluate(brain_info) + assert run_out['action'].shape == (3, 2) env.close() diff --git a/ml-agents/tests/trainers/test_ppo.py b/ml-agents/tests/trainers/test_ppo.py index d13d8a38db..061215600c 100644 --- a/ml-agents/tests/trainers/test_ppo.py +++ b/ml-agents/tests/trainers/test_ppo.py @@ -44,22 +44,20 @@ def dummy_config(): @mock.patch('mlagents.envs.UnityEnvironment.get_communicator') def test_ppo_policy_evaluate(mock_communicator, mock_launcher): tf.reset_default_graph() - with tf.Session() as sess: - mock_communicator.return_value = MockCommunicator( - discrete_action=False, visual_inputs=0) - env = UnityEnvironment(' ') - brain_infos = env.reset() - brain_info = brain_infos[env.brain_names[0]] - - trainer_parameters = dummy_config() - graph_scope = env.brain_names[0] - trainer_parameters['graph_scope'] = graph_scope - policy = PPOPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, sess, False) - init = tf.global_variables_initializer() - sess.run(init) - run_out = policy.evaluate(brain_info) - assert run_out['action'].shape == (3, 2) - env.close() + mock_communicator.return_value = MockCommunicator( + discrete_action=False, visual_inputs=0) + env = UnityEnvironment(' ') + brain_infos = env.reset() + brain_info = brain_infos[env.brain_names[0]] + + trainer_parameters = dummy_config() + model_path = env.brain_names[0] + trainer_parameters['model_path'] = model_path + trainer_parameters['keep_checkpoints'] = 3 + policy = PPOPolicy(0, env.brains[env.brain_names[0]], trainer_parameters, False, False) + run_out = policy.evaluate(brain_info) + assert run_out['action'].shape == (3, 2) + env.close() @mock.patch('mlagents.envs.UnityEnvironment.executable_launcher') diff --git a/ml-agents/tests/trainers/test_trainer_controller.py b/ml-agents/tests/trainers/test_trainer_controller.py index d7e870b71f..493fcc0d5b 100644 --- a/ml-agents/tests/trainers/test_trainer_controller.py +++ 
b/ml-agents/tests/trainers/test_trainer_controller.py @@ -158,7 +158,7 @@ def test_initialize_trainers(mock_communicator, mock_launcher, dummy_config, with mock.patch(open_name, create=True) as _: mock_communicator.return_value = MockCommunicator( discrete_action=True, visual_inputs=1) - tc = TrainerController(' ', ' ', 1, None, True, True, False, 1, 1, + tc = TrainerController(' ', ' ', 1, None, True, False, False, 1, 1, 1, 1, '', "tests/test_mlagents.trainers.py", False) @@ -166,23 +166,20 @@ def test_initialize_trainers(mock_communicator, mock_launcher, dummy_config, mock_load.return_value = dummy_config config = tc._load_config() tf.reset_default_graph() - with tf.Session() as sess: - tc._initialize_trainers(config, sess) - assert(len(tc.trainers) == 1) - assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer)) + tc._initialize_trainers(config) + assert(len(tc.trainers) == 1) + assert(isinstance(tc.trainers['RealFakeBrain'], PPOTrainer)) # Test for Behavior Cloning Trainer mock_load.return_value = dummy_bc_config config = tc._load_config() tf.reset_default_graph() - with tf.Session() as sess: - tc._initialize_trainers(config, sess) - assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer)) + tc._initialize_trainers(config) + assert(isinstance(tc.trainers['RealFakeBrain'], BehavioralCloningTrainer)) # Test for proper exception when trainer name is incorrect mock_load.return_value = dummy_bad_config config = tc._load_config() tf.reset_default_graph() - with tf.Session() as sess: - with pytest.raises(UnityEnvironmentException): - tc._initialize_trainers(config, sess) + with pytest.raises(UnityEnvironmentException): + tc._initialize_trainers(config)
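
For context on what this refactor changes about the exported artifacts: with the TensorFlow graph scope removed, the per-brain `.bytes` file written by the new `Policy.export_model()` exposes its output nodes under bare names (`action`, `value_estimate`, `action_probs`, `recurrent_out`, `memory_size`), which is exactly what `Policy._process_graph()` collects and what the Unity-side `CoreBrainInternal` lookups (`graph[ActionPlaceholderName]` with no scope prefix) now rely on. The sketch below is not part of the patch; it is a minimal TensorFlow 1.x illustration of loading such a frozen graph and listing which of the expected output nodes it contains. The helper name `inspect_exported_graph` and the example path are assumptions — the path simply follows the `model_path: ./models/first-run-0/Ball3DBrain` value shown in the training logs plus the `model_path + '.bytes'` suffix used by `export_model()`.

```python
# Illustrative only -- not part of this patch. Assumes TensorFlow 1.x and a
# brain exported by the new Policy.export_model(), which freezes the graph to
# "<model_path>.bytes" (e.g. ./models/first-run-0/Ball3DBrain.bytes).
import tensorflow as tf

# Output nodes the refactored Policy may export; with graph scopes removed,
# these names are no longer prefixed by a per-brain scope such as "Ball3DBrain/".
POSSIBLE_OUTPUT_NODES = ['action', 'value_estimate',
                         'action_probs', 'recurrent_out', 'memory_size']


def inspect_exported_graph(bytes_path):
    """Load a frozen .bytes graph and report which known output nodes it contains."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(bytes_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps node names unprefixed, matching what the Internal Brain
        # now looks up directly (e.g. graph["action"] rather than graph["scope/action"]).
        tf.import_graph_def(graph_def, name='')

    all_nodes = {node.name for node in graph.as_graph_def().node}
    return [name for name in POSSIBLE_OUTPUT_NODES if name in all_nodes]


if __name__ == '__main__':
    # Hypothetical path; substitute the run-id and brain name used for training.
    print(inspect_exported_graph('./models/first-run-0/Ball3DBrain.bytes'))
```

A non-recurrent PPO brain would typically report `['action', 'value_estimate', 'action_probs']`, while a recurrent one would also include `recurrent_out` and `memory_size`; a behavioral-cloning brain exports `action` only, mirroring the per-trainer node lists that `trainer_controller._process_graph()` used to build before this responsibility moved into `Policy`.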