Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions ml-agents/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from setuptools import setup, find_packages
from setuptools.command.install import install
from setuptools.command.develop import develop
import mlagents.trainers

VERSION = mlagents.trainers.__version__
Expand Down Expand Up @@ -32,6 +33,37 @@ def run(self):
sys.exit(info)


def verify_torch_installed():
# Check that torch version 1.6.0 or later has been installed. If not, refer
# user to the PyTorch webpage for install instructions.
torch_pkg = None
try:
torch_pkg = pkg_resources.get_distribution("torch")
except pkg_resources.DistributionNotFound:
pass
assert torch_pkg is not None and LooseVersion(torch_pkg.version) >= LooseVersion(
"1.6.0"
), (
"A compatible version of PyTorch was not installed. Please visit the PyTorch homepage ",
"(https://pytorch.org/get-started/locally/) and follow the instructions to install. ",
"Version 1.6.0 and later are supported.",
)


class VerifyTorchInstallCommand(install):
description = "verify that Torch is installed"

def run(self):
verify_torch_installed()


class VerifyTorchDevelopCommand(develop):
description = "verify that Torch is installed"

def run(self):
verify_torch_installed()


# Get the long description from the README file
with open(os.path.join(here, "README.md"), encoding="utf-8") as f:
long_description = f.read()
Expand Down Expand Up @@ -79,19 +111,10 @@ def run(self):
"mlagents-run-experiment=mlagents.trainers.run_experiment:main",
]
},
cmdclass={"verify": VerifyVersionCommand},
cmdclass={
"verify": VerifyVersionCommand,
"install": VerifyTorchInstallCommand,
"develop": VerifyTorchDevelopCommand,
},
extras_require={"tensorflow": ["tensorflow>=1.14,<3.0", "six>=1.12.0"]},
)

# Check that torch version 1.6.0 or later has been installed. If not, refer
# user to the PyTorch webpage for install instructions.
torch_pkg = None
try:
torch_pkg = pkg_resources.get_distribution("torch")
except pkg_resources.DistributionNotFound:
pass
assert torch_pkg is not None and LooseVersion(torch_pkg.version) >= LooseVersion(
"1.6.0"
), "A compatible version of PyTorch was not installed. Please visit the PyTorch homepage \
(https://pytorch.org/get-started/locally/) and follow the instructions to install. \
Version 1.6.0 and later are supported."
4 changes: 2 additions & 2 deletions test_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ pytest>4.0.0,<6.0.0
pytest-cov==2.6.1
pytest-xdist==1.34.0

# PyTorch tests are here for the time being, before they are used in the codebase.
torch>=1.5.0
# Tensorflow tests are here for the time being, before they are used in the codebase.
tensorflow>=1.14,<3.0

tf2onnx>=1.5.5