Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ and this project adheres to
### Major Changes
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)
- PyTorch trainers are now the default. See the
[installation docs](https://github.com/Unity-Technologies/ml-agents/blob/mastere/docs/Installation.md) for
more information on installing PyTorch. For the time being, TensorFlow is still available;
you can use the TensorFlow backend by adding `--tensorflow` to the CLI, or
adding `framework: tensorflow` in the configuration YAML. (#4517)

### Minor Changes
#### com.unity.ml-agents (C#)
Expand Down
11 changes: 9 additions & 2 deletions ml-agents/mlagents/trainers/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,15 @@ def _create_parser() -> argparse.ArgumentParser:
"--torch",
default=False,
action=DetectDefaultStoreTrue,
help="(Experimental) Use the PyTorch framework instead of TensorFlow. Install PyTorch "
"before using this option",
help="Use the PyTorch framework. Note that this option is not required anymore as PyTorch is the"
"default framework, and will be removed in the next release.",
)
argparser.add_argument(
"--tensorflow",
default=False,
action=DetectDefaultStoreTrue,
help="(Deprecated) Use the TensorFlow framework instead of PyTorch. Install TensorFlow "
"before using this option.",
)

eng_conf = argparser.add_argument_group(title="Engine Configuration")
Expand Down
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def run_training(run_seed: int, options: RunOptions) -> None:
init_path=maybe_init_path,
multi_gpu=False,
force_torch="torch" in DetectDefault.non_default_args,
force_tensorflow="tensorflow" in DetectDefault.non_default_args,
)
# Create controller and begin training.
tc = TrainerController(
Expand Down
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def _set_default_hyperparameters(self):
threaded: bool = True
self_play: Optional[SelfPlaySettings] = None
behavioral_cloning: Optional[BehavioralCloningSettings] = None
framework: FrameworkType = FrameworkType.TENSORFLOW
framework: FrameworkType = FrameworkType.PYTORCH

cattr.register_structure_hook(
Dict[RewardSignalType, RewardSignalSettings], RewardSignalSettings.structure
Expand Down
18 changes: 17 additions & 1 deletion ml-agents/mlagents/trainers/trainer/trainer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
init_path: str = None,
multi_gpu: bool = False,
force_torch: bool = False,
force_tensorflow: bool = False,
):
"""
The TrainerFactory generates the Trainers based on the configuration passed as
Expand All @@ -45,7 +46,9 @@ def __init__(
:param init_path: Path from which to load model.
:param multi_gpu: If True, multi-gpu will be used. (currently not available)
:param force_torch: If True, the Trainers will all use the PyTorch framework
instead of the TensorFlow framework.
instead of what is specified in the config YAML.
:param force_tensorflow: If True, thee Trainers will all use the TensorFlow
framework.
"""
self.trainer_config = trainer_config
self.output_path = output_path
Expand All @@ -57,6 +60,7 @@ def __init__(
self.multi_gpu = multi_gpu
self.ghost_controller = GhostController()
self._force_torch = force_torch
self._force_tf = force_tensorflow

def generate(self, behavior_name: str) -> Trainer:
if behavior_name not in self.trainer_config.keys():
Expand All @@ -67,6 +71,18 @@ def generate(self, behavior_name: str) -> Trainer:
trainer_settings = self.trainer_config[behavior_name]
if self._force_torch:
trainer_settings.framework = FrameworkType.PYTORCH
logger.warning(
"Note that specifying --torch is not required anymore as PyTorch is the default framework."
)
if self._force_tf:
trainer_settings.framework = FrameworkType.TENSORFLOW
logger.warning(
"Setting the framework to TensorFlow. TensorFlow trainers will be deprecated in the future."
)
if self._force_torch:
logger.warning(
"Both --torch and --tensorflow CLI options were specified. Using TensorFlow."
)
return TrainerFactory._initialize_trainer(
trainer_settings,
behavior_name,
Expand Down