diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index 2143afb26..aa773ea54 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -49,15 +49,14 @@ def get_scheduler_factories( The first scheduler in the dictionary is used as the default scheduler. """ - default_schedulers: dict[str, SchedulerFactory] = {} - for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): - default_schedulers[scheduler] = _defer_load_scheduler(path) - - return load_group( - group, - default=default_schedulers, - skip_defaults=skip_defaults, - ) + if skip_defaults: + default_schedulers = {} + else: + default_schedulers: dict[str, SchedulerFactory] = {} + for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): + default_schedulers[scheduler] = _defer_load_scheduler(path) + + return load_group(group, default=default_schedulers) def get_default_scheduler_name() -> str: diff --git a/torchx/util/entrypoints.py b/torchx/util/entrypoints.py index f3bfcac70..4e2379836 100644 --- a/torchx/util/entrypoints.py +++ b/torchx/util/entrypoints.py @@ -69,9 +69,7 @@ def run(*args: object, **kwargs: object) -> object: return run -def load_group( - group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False -): +def load_group(group: str, default: Optional[Dict[str, Any]] = None): """ Loads all the entry points specified by ``group`` and returns the entry points as a map of ``name (str) -> deferred_load_fn``. @@ -90,7 +88,6 @@ def load_group( 1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")`` 1. ``load_group("food")`` -> ``None`` 1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")`` - 1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None`` If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn`` @@ -115,8 +112,6 @@ def load_group( entrypoints = metadata.entry_points().get(group, ()) if len(entrypoints) == 0: - if skip_defaults: - return None return default eps = {} diff --git a/torchx/util/test/entrypoints_test.py b/torchx/util/test/entrypoints_test.py index e6327168c..a4e11a53e 100644 --- a/torchx/util/test/entrypoints_test.py +++ b/torchx/util/test/entrypoints_test.py @@ -134,11 +134,6 @@ def test_load_group_with_default(self, _: MagicMock) -> None: self.assertEqual("barbaz", eps["foo"]()) self.assertEqual("foobar", eps["bar"]()) - eps = load_group( - "ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True - ) - self.assertIsNone(eps) - @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) def test_load_group_missing(self, _: MagicMock) -> None: with self.assertRaises(AttributeError):