Skip to content

Commit 60f2bc2

Browse files
committed
feat: add custom named_resources support (#1085)
1 parent 1e3df20 commit 60f2bc2

File tree

3 files changed

+51
-3
lines changed

3 files changed

+51
-3
lines changed

docs/source/advanced.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,33 @@ resource can then be used in the following manner:
153153

154154
test_app("gpu_x2")
155155

156+
Alternatively, you can define custom named resources in a Python module and point
157+
to it using the ``TORCHX_CUSTOM_NAMED_RESOURCES`` environment variable:
158+
159+
.. code-block:: python
160+
161+
# my_resources.py
162+
from torchx.specs import Resource
163+
164+
def gpu_x8_efa() -> Resource:
165+
return Resource(cpu=100, gpu=8, memMB=819200, devices={"vpc.amazonaws.com/efa": 1})
166+
167+
def cpu_x32() -> Resource:
168+
return Resource(cpu=32, gpu=0, memMB=131072)
169+
170+
NAMED_RESOURCES = {
171+
"gpu_x8_efa": gpu_x8_efa,
172+
"cpu_x32": cpu_x32,
173+
}
174+
175+
Then set the environment variable:
176+
177+
.. code-block:: bash
178+
179+
export TORCHX_CUSTOM_NAMED_RESOURCES=my_resources
180+
181+
This allows you to use your custom resources without creating a package with entry points.
182+
156183

157184
Registering Custom Components
158185
-------------------------------

torchx/specs/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
scheduler or pipeline adapter.
1313
"""
1414
import difflib
15+
16+
import os
1517
from typing import Callable, Dict, Mapping, Optional
1618

1719
from torchx.specs.api import (
@@ -63,8 +65,10 @@
6365
GENERIC_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
6466
"torchx.specs.named_resources_generic", "NAMED_RESOURCES", default={}
6567
)
66-
FB_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
67-
"torchx.specs.fb.named_resources", "NAMED_RESOURCES", default={}
68+
CUSTOM_NAMED_RESOURCES: Mapping[str, ResourceFactory] = import_attr(
69+
os.environ.get("TORCHX_CUSTOM_NAMED_RESOURCES", "torchx.specs.fb.named_resources"),
70+
"NAMED_RESOURCES",
71+
default={},
6872
)
6973

7074

@@ -75,7 +79,7 @@ def _load_named_resources() -> Dict[str, Callable[[], Resource]]:
7579
for name, resource in {
7680
**GENERIC_NAMED_RESOURCES,
7781
**AWS_NAMED_RESOURCES,
78-
**FB_NAMED_RESOURCES,
82+
**CUSTOM_NAMED_RESOURCES,
7983
**resource_methods,
8084
}.items():
8185
materialized_resources[name] = resource

torchx/specs/test/named_resources_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# pyre-strict
99

1010

11+
import os
1112
import unittest
1213
from unittest.mock import MagicMock, patch
1314

@@ -47,3 +48,19 @@ def test_named_resources_library(self, mock_named_resources: MagicMock) -> None:
4748
def test_null_and_missing_named_resources(self) -> None:
4849
self.assertEqual(named_resources["NULL"], NULL_RESOURCE)
4950
self.assertEqual(named_resources["MISSING"], NULL_RESOURCE)
51+
52+
def test_custom_named_resources_env_var(self) -> None:
53+
import sys
54+
55+
mock_module = type(sys)("test_module")
56+
mock_module.NAMED_RESOURCES = {"test_resource": mock_resource}
57+
58+
with patch.dict(sys.modules, {"test_module": mock_module}):
59+
with patch(
60+
"torchx.specs.CUSTOM_NAMED_RESOURCES", mock_module.NAMED_RESOURCES
61+
):
62+
import torchx.specs
63+
64+
factories = torchx.specs._load_named_resources()
65+
66+
self.assertIn("test_resource", factories)

0 commit comments

Comments
 (0)