From 834e55af15ec2261aed084bfcfc4afcf2cef40bc Mon Sep 17 00:00:00 2001 From: Kiuk Chung Date: Mon, 6 Oct 2025 12:24:51 -0700 Subject: [PATCH] (torchx/entrypoints) remove redundant skip_defaults flag from entrypoints.load() (#1140) Summary: Removes the redunant (and potentially confusing) `skip_defaults` flag from: ``` from torchx.utils.entrypoints import load_group load_group(group, default, skip_defaults) ``` This is because the same can be achieved at the call-site by passing `default=None`. That is: ``` load_group(group_name, defaults=None if skip_defaults else default_value) ``` Coincidentally also fixes an illegal type return in `torchx.schedulers.get_scheduler_factories` (should always return a `dict[str, SchedulerFactory]` (an empty dict if none found) where passing `get_scheduler_factories(..., skip_defaults=True)` currently returns `None` which violates the return type hint. Differential Revision: D83991870 --- torchx/schedulers/__init__.py | 17 ++++++++--------- torchx/util/entrypoints.py | 7 +------ torchx/util/test/entrypoints_test.py | 5 ----- 3 files changed, 9 insertions(+), 20 deletions(-) diff --git a/torchx/schedulers/__init__.py b/torchx/schedulers/__init__.py index 2143afb26..aa773ea54 100644 --- a/torchx/schedulers/__init__.py +++ b/torchx/schedulers/__init__.py @@ -49,15 +49,14 @@ def get_scheduler_factories( The first scheduler in the dictionary is used as the default scheduler. """ - default_schedulers: dict[str, SchedulerFactory] = {} - for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): - default_schedulers[scheduler] = _defer_load_scheduler(path) - - return load_group( - group, - default=default_schedulers, - skip_defaults=skip_defaults, - ) + if skip_defaults: + default_schedulers = {} + else: + default_schedulers: dict[str, SchedulerFactory] = {} + for scheduler, path in DEFAULT_SCHEDULER_MODULES.items(): + default_schedulers[scheduler] = _defer_load_scheduler(path) + + return load_group(group, default=default_schedulers) def get_default_scheduler_name() -> str: diff --git a/torchx/util/entrypoints.py b/torchx/util/entrypoints.py index f3bfcac70..4e2379836 100644 --- a/torchx/util/entrypoints.py +++ b/torchx/util/entrypoints.py @@ -69,9 +69,7 @@ def run(*args: object, **kwargs: object) -> object: return run -def load_group( - group: str, default: Optional[Dict[str, Any]] = None, skip_defaults: bool = False -): +def load_group(group: str, default: Optional[Dict[str, Any]] = None): """ Loads all the entry points specified by ``group`` and returns the entry points as a map of ``name (str) -> deferred_load_fn``. @@ -90,7 +88,6 @@ def load_group( 1. ``load_group("foo")["bar"]("baz")`` -> equivalent to calling ``this.is.a_fn("baz")`` 1. ``load_group("food")`` -> ``None`` 1. ``load_group("food", default={"hello": this.is.c_fn})["hello"]("world")`` -> equivalent to calling ``this.is.c_fn("world")`` - 1. ``load_group("food", default={"hello": this.is.c_fn}, skip_defaults=True)`` -> ``None`` If the entrypoint is a module (versus a function as shown above), then calling the ``deferred_load_fn`` @@ -115,8 +112,6 @@ def load_group( entrypoints = metadata.entry_points().get(group, ()) if len(entrypoints) == 0: - if skip_defaults: - return None return default eps = {} diff --git a/torchx/util/test/entrypoints_test.py b/torchx/util/test/entrypoints_test.py index e6327168c..a4e11a53e 100644 --- a/torchx/util/test/entrypoints_test.py +++ b/torchx/util/test/entrypoints_test.py @@ -134,11 +134,6 @@ def test_load_group_with_default(self, _: MagicMock) -> None: self.assertEqual("barbaz", eps["foo"]()) self.assertEqual("foobar", eps["bar"]()) - eps = load_group( - "ep.grp.test.missing", {"foo": barbaz, "bar": foobar}, skip_defaults=True - ) - self.assertIsNone(eps) - @patch(_METADATA_EPS, return_value=_ENTRY_POINTS) def test_load_group_missing(self, _: MagicMock) -> None: with self.assertRaises(AttributeError):