Skip to content

Commit 834e55a

Browse files
kiukchungfacebook-github-bot
authored andcommitted
(torchx/entrypoints) remove redundant skip_defaults flag from entrypoints.load() (#1140)
Summary: Removes the redunant (and potentially confusing) `skip_defaults` flag from: ``` from torchx.utils.entrypoints import load_group load_group(group, default, skip_defaults) ``` This is because the same can be achieved at the call-site by passing `default=None`. That is: ``` load_group(group_name, defaults=None if skip_defaults else default_value) ``` Coincidentally also fixes an illegal type return in `torchx.schedulers.get_scheduler_factories` (should always return a `dict[str, SchedulerFactory]` (an empty dict if none found) where passing `get_scheduler_factories(..., skip_defaults=True)` currently returns `None` which violates the return type hint. Differential Revision: D83991870
1 parent 1e3df20 commit 834e55a

File tree

3 files changed

+9
-20
lines changed

3 files changed

+9
-20
lines changed

torchx/schedulers/__init__.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@ def get_scheduler_factories(
4949
The first scheduler in the dictionary is used as the default scheduler.
5050
"""
5151

52-
default_schedulers: dict[str, SchedulerFactory] = {}
53-
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
54-
default_schedulers[scheduler] = _defer_load_scheduler(path)
55-
56-
return load_group(
57-
group,
58-
default=default_schedulers,
59-
skip_defaults=skip_defaults,
60-
)
52+
if skip_defaults:
53+
default_schedulers = {}
54+
else:
55+
default_schedulers: dict[str, SchedulerFactory] = {}
56+
for scheduler, path in DEFAULT_SCHEDULER_MODULES.items():
57+
default_schedulers[scheduler] = _defer_load_scheduler(path)
58+
59+
return load_group(group, default=default_schedulers)
6160

6261

6362
def get_default_scheduler_name() -> str:

torchx/util/entrypoints.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,7 @@ def run(*args: object, **kwargs: object) -> object:
6969
return run
7070

7171

72-
def load_group(
73-
group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False
74-
):
72+
def load_group(group: str, default: Optional[Dict[str, Any]] = None):
7573
"""
7674
Loads all the entry points specified by ``group`` and returns
7775
the entry points as a map of ``name (str) -> deferred_load_fn``.
@@ -90,7 +88,6 @@ def load_group(
9088
1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")``
9189
1. ``load_group("food")`` -> ``None``
9290
1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")``
93-
1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None``
9491
9592
9693
If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn``
@@ -115,8 +112,6 @@ def load_group(
115112
entrypoints = metadata.entry_points().get(group, ())
116113

117114
if len(entrypoints) == 0:
118-
if skip_defaults:
119-
return None
120115
return default
121116

122117
eps = {}

torchx/util/test/entrypoints_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,6 @@ def test_load_group_with_default(self, _: MagicMock) -> None:
134134
self.assertEqual("barbaz", eps["foo"]())
135135
self.assertEqual("foobar", eps["bar"]())
136136

137-
eps = load_group(
138-
"ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True
139-
)
140-
self.assertIsNone(eps)
141-
142137
@patch(_METADATA_EPS, return_value=_ENTRY_POINTS)
143138
def test_load_group_missing(self, _: MagicMock) -> None:
144139
with self.assertRaises(AttributeError):

0 commit comments

Comments
 (0)