Skip to content

Commit 0217eef

Browse files
committed
refactor: [typing] add typing to the Scheduler
Types added to make it clearer what is going on with the job scheduling. This means the pyright/ruff LSP have more information to be able to jump to definitions and show documentation. Signed-off-by: James McCorrie <[email protected]>
1 parent 96ff3d5 commit 0217eef

File tree

5 files changed

+181
-134
lines changed

5 files changed

+181
-134
lines changed

src/dvsim/flow/base.py

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
import os
88
import pprint
99
import sys
10-
from collections.abc import Mapping
10+
from abc import ABC, abstractmethod
11+
from collections.abc import Mapping, Sequence
1112
from pathlib import Path
13+
from typing import ClassVar
1214

1315
import hjson
1416

@@ -27,7 +29,7 @@
2729

2830

2931
# Interface class for extensions.
30-
class FlowCfg:
32+
class FlowCfg(ABC):
3133
"""Base class for the different flows supported by dvsim.py.
3234
3335
The constructor expects some parsed hjson data. Create these objects with
@@ -41,9 +43,10 @@ class FlowCfg:
4143

4244
# Can be overridden in subclasses to configure which wildcards to ignore
4345
# when expanding hjson.
44-
ignored_wildcards = []
46+
ignored_wildcards: ClassVar = []
4547

4648
def __str__(self) -> str:
49+
"""Get string representation of the flow config."""
4750
return pprint.pformat(self.__dict__)
4851

4952
def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
@@ -87,7 +90,7 @@ def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
8790
# For a primary cfg, it is the aggregated list of all deploy objects
8891
# under self.cfgs. For a non-primary cfg, it is the list of items
8992
# slated for dispatch.
90-
self.deploy = []
93+
self.deploy: Sequence[Deploy] = []
9194

9295
# Timestamp
9396
self.timestamp_long = args.timestamp_long
@@ -98,7 +101,7 @@ def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
98101
self.rel_path = ""
99102
self.results_title = ""
100103
self.revision = ""
101-
self.css_file = os.path.join(Path(os.path.realpath(__file__)).parent, "style.css")
104+
self.css_file = Path(os.path.realpath(__file__)).parent / "style.css"
102105
# `self.results_*` below will be updated after `self.rel_path` and
103106
# `self.scratch_base_root` variables are updated.
104107
self.results_dir = ""
@@ -132,7 +135,9 @@ def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
132135
self._load_child_cfg(entry, mk_config)
133136

134137
if self.rel_path == "":
135-
self.rel_path = Path(self.flow_cfg_file).parent.replace(self.proj_root + "/", "")
138+
self.rel_path = str(
139+
Path(self.flow_cfg_file).parent.relative_to(self.proj_root),
140+
)
136141

137142
# Process overrides before substituting wildcards
138143
self._process_overrides()
@@ -149,7 +154,7 @@ def __init__(self, flow_cfg_file, hjson_data, args, mk_config) -> None:
149154
# Run any final checks
150155
self._post_init()
151156

152-
def _merge_hjson(self, hjson_data) -> None:
157+
def _merge_hjson(self, hjson_data: Mapping) -> None:
153158
"""Take hjson data and merge it into self.__dict__.
154159
155160
Subclasses that need to do something just before the merge should
@@ -160,7 +165,7 @@ def _merge_hjson(self, hjson_data) -> None:
160165
set_target_attribute(self.flow_cfg_file, self.__dict__, key, value)
161166

162167
def _expand(self) -> None:
163-
"""Called to expand wildcards after merging hjson.
168+
"""Expand wildcards after merging hjson.
164169
165170
Subclasses can override this to do something just before expansion.
166171
@@ -235,8 +240,9 @@ def _load_child_cfg(self, entry, mk_config) -> None:
235240
)
236241
sys.exit(1)
237242

238-
def _conv_inline_cfg_to_hjson(self, idict):
243+
def _conv_inline_cfg_to_hjson(self, idict: Mapping) -> str | None:
239244
"""Dump a temp hjson file in the scratch space from input dict.
245+
240246
This method is to be called only by a primary cfg.
241247
"""
242248
if not self.is_primary_cfg:
@@ -257,9 +263,11 @@ def _conv_inline_cfg_to_hjson(self, idict):
257263

258264
# Create the file and dump the dict as hjson
259265
log.verbose('Dumping inline cfg "%s" in hjson to:\n%s', name, temp_cfg_file)
266+
260267
try:
261268
Path(temp_cfg_file).write_text(hjson.dumps(idict, for_json=True))
262-
except Exception as e:
269+
270+
except Exception as e: # noqa: BLE001
263271
log.exception(
264272
'Failed to hjson-dump temp cfg file"%s" for "%s"(will be skipped!) due to:\n%s',
265273
temp_cfg_file,
@@ -330,6 +338,7 @@ def _do_override(self, ov_name: str, ov_value: object) -> None:
330338
log.error('Override key "%s" not found in the cfg!', ov_name)
331339
sys.exit(1)
332340

341+
@abstractmethod
333342
def _purge(self) -> None:
334343
"""Purge the existing scratch areas in preparation for the new run."""
335344

@@ -338,6 +347,7 @@ def purge(self) -> None:
338347
for item in self.cfgs:
339348
item._purge()
340349

350+
@abstractmethod
341351
def _print_list(self) -> None:
342352
"""Print the list of available items that can be kicked off."""
343353

@@ -368,12 +378,13 @@ def prune_selected_cfgs(self) -> None:
368378
# Filter configurations
369379
self.cfgs = [c for c in self.cfgs if c.name in self.select_cfgs]
370380

381+
@abstractmethod
371382
def _create_deploy_objects(self) -> None:
372383
"""Create deploy objects from items that were passed on for being run.
384+
373385
The deploy objects for build and run are created from the objects that
374386
were created from the create_objects() method.
375387
"""
376-
return
377388

378389
def create_deploy_objects(self) -> None:
379390
"""Public facing API for _create_deploy_objects()."""
@@ -387,7 +398,7 @@ def create_deploy_objects(self) -> None:
387398
for item in self.cfgs:
388399
item._create_deploy_objects()
389400

390-
def deploy_objects(self):
401+
def deploy_objects(self) -> Mapping[Deploy, str]:
391402
"""Public facing API for deploying all available objects.
392403
393404
Runs each job and returns a map from item to status.
@@ -400,21 +411,26 @@ def deploy_objects(self):
400411
log.error("Nothing to run!")
401412
sys.exit(1)
402413

403-
return Scheduler(deploy, get_launcher_cls(), self.interactive).run()
414+
return Scheduler(
415+
items=deploy,
416+
launcher_cls=get_launcher_cls(),
417+
interactive=self.interactive,
418+
).run()
404419

405-
def _gen_results(self, results: Mapping[Deploy, str]) -> None:
406-
"""Generate results.
420+
@abstractmethod
421+
def _gen_results(self, results: Mapping[Deploy, str]) -> str:
422+
"""Generate flow results.
407423
408-
The function is called after the flow has completed. It collates the
409-
status of all run targets and generates a dict. It parses the log
424+
The function is called after the flow has completed. It collates
425+
the status of all run targets and generates a dict. It parses the log
410426
to identify the errors, warnings and failures as applicable. It also
411427
prints the full list of failures for debug / triage to the final
412428
report, which is in markdown format.
413429
414430
results should be a dictionary mapping deployed item to result.
415431
"""
416432

417-
def gen_results(self, results) -> None:
433+
def gen_results(self, results: Mapping[Deploy, str]) -> None:
418434
"""Public facing API for _gen_results().
419435
420436
results should be a dictionary mapping deployed item to result.
@@ -435,6 +451,7 @@ def gen_results(self, results) -> None:
435451
self.gen_results_summary()
436452
self.write_results(self.results_html_name, self.results_summary_md)
437453

454+
@abstractmethod
438455
def gen_results_summary(self) -> None:
439456
"""Public facing API to generate summary results for each IP/cfg file."""
440457

@@ -466,4 +483,5 @@ def _get_results_page_link(self, relative_to: str, link_text: str = "") -> str:
466483
return f"[{link_text}]({relative_link})"
467484

468485
def has_errors(self) -> bool:
486+
"""Return error state."""
469487
return self.errors_seen

src/dvsim/flow/sim.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44

55
"""Class describing simulation configuration object."""
66

7-
import collections
87
import fnmatch
98
import json
109
import os
1110
import re
1211
import sys
13-
from collections import OrderedDict
12+
from collections import OrderedDict, defaultdict
13+
from collections.abc import Mapping
1414
from datetime import datetime, timezone
1515
from pathlib import Path
1616
from typing import ClassVar
1717

1818
from tabulate import tabulate
1919

2020
from dvsim.flow.base import FlowCfg
21-
from dvsim.job.deploy import CompileSim, CovAnalyze, CovMerge, CovReport, CovUnr, RunTest
21+
from dvsim.job.deploy import CompileSim, CovAnalyze, CovMerge, CovReport, CovUnr, Deploy, RunTest
2222
from dvsim.logging import log
2323
from dvsim.modes import BuildMode, Mode, RunMode, find_mode
2424
from dvsim.regression import Regression
@@ -327,7 +327,7 @@ def _print_list(self) -> None:
327327
log.info(mode_name)
328328

329329
def _create_build_and_run_list(self) -> None:
330-
"""Generates a list of deployable objects from the provided items.
330+
"""Generate a list of deployable objects from the provided items.
331331
332332
Tests to be run are provided with --items switch. These can be glob-
333333
style patterns. This method finds regressions and tests that match
@@ -562,20 +562,13 @@ def cov_unr(self) -> None:
562562
for item in self.cfgs:
563563
item._cov_unr()
564564

565-
def _gen_json_results(self, run_results):
566-
"""Returns the run results as json-formatted dictionary."""
567-
568-
def _empty_str_as_none(s: str) -> str | None:
569-
"""Map an empty string to None and retain the value of a non-empty
570-
string.
571-
572-
This is intended to clearly distinguish an empty string, which may
573-
or may not be an valid value, from an invalid value.
574-
"""
575-
return s if s != "" else None
565+
def _gen_json_results(self, run_results: Mapping[Deploy, str]) -> str:
566+
"""Return the run results as json-formatted dictionary."""
576567

577568
def _pct_str_to_float(s: str) -> float | None:
578-
"""Map a percentage value stored in a string with ` %` suffix to a
569+
"""Extract percent or None.
570+
571+
Map a percentage value stored in a string with ` %` suffix to a
579572
float or to None if the conversion to Float fails.
580573
"""
581574
try:
@@ -608,7 +601,7 @@ def _test_result_to_dict(tr) -> dict:
608601
# Describe name of hardware block targeted by this run and optionally
609602
# the variant of the hardware block.
610603
results["block_name"] = self.name.lower()
611-
results["block_variant"] = _empty_str_as_none(self.variant.lower())
604+
results["block_variant"] = self.variant.lower() or None
612605

613606
# The timestamp for this run has been taken with `utcnow()` and is
614607
# stored in a custom format. Store it in standard ISO format with
@@ -620,7 +613,7 @@ def _test_result_to_dict(tr) -> dict:
620613
# Extract Git properties.
621614
m = re.search(r"https://github.com/.+?/tree/([0-9a-fA-F]+)", self.revision)
622615
results["git_revision"] = m.group(1) if m else None
623-
results["git_branch_name"] = _empty_str_as_none(self.branch)
616+
results["git_branch_name"] = self.branch or None
624617

625618
# Describe type of report and tool used.
626619
results["report_type"] = "simulation"
@@ -704,7 +697,7 @@ def _test_result_to_dict(tr) -> dict:
704697
if sim_results.buckets:
705698
by_tests = sorted(sim_results.buckets.items(), key=lambda i: len(i[1]), reverse=True)
706699
for bucket, tests in by_tests:
707-
unique_tests = collections.defaultdict(list)
700+
unique_tests = defaultdict(list)
708701
for test, line, context in tests:
709702
if not isinstance(test, RunTest):
710703
continue
@@ -743,16 +736,18 @@ def _test_result_to_dict(tr) -> dict:
743736
# Return the `results` dictionary as json string.
744737
return json.dumps(self.results_dict)
745738

746-
def _gen_results(self, run_results):
747-
"""The function is called after the regression has completed. It collates the
739+
def _gen_results(self, results: Mapping[Deploy, str]) -> str:
740+
"""Generate simulation results.
741+
742+
The function is called after the regression has completed. It collates the
748743
status of all run targets and generates a dict. It parses the testplan and
749744
maps the generated result to the testplan entries to generate a final table
750745
(list). It also prints the full list of failures for debug / triage. If cov
751746
is enabled, then the summary coverage report is also generated. The final
752747
result is in markdown format.
753748
"""
754749

755-
def indent_by(level):
750+
def indent_by(level: int) -> str:
756751
return " " * (4 * level)
757752

758753
def create_failure_message(test, line, context):
@@ -769,7 +764,7 @@ def create_failure_message(test, line, context):
769764
return message
770765

771766
def create_bucket_report(buckets):
772-
"""Creates a report based on the given buckets.
767+
"""Create a report based on the given buckets.
773768
774769
The buckets are sorted by descending number of failures. Within
775770
buckets this also group tests by unqualified name, and just a few
@@ -787,7 +782,7 @@ def create_bucket_report(buckets):
787782
fail_msgs = ["\n## Failure Buckets", ""]
788783
for bucket, tests in by_tests:
789784
fail_msgs.append(f"* `{bucket}` has {len(tests)} failures:")
790-
unique_tests = collections.defaultdict(list)
785+
unique_tests = defaultdict(list)
791786
for test, line, context in tests:
792787
unique_tests[test.name].append((test, line, context))
793788
for name, test_reseeds in list(unique_tests.items())[:_MAX_UNIQUE_TESTS]:
@@ -812,7 +807,7 @@ def create_bucket_report(buckets):
812807
return fail_msgs
813808

814809
deployed_items = self.deploy
815-
results = SimResults(deployed_items, run_results)
810+
results = SimResults(deployed_items, results)
816811

817812
# Generate results table for runs.
818813
results_str = "## " + self.results_title + "\n"
@@ -881,7 +876,7 @@ def create_bucket_report(buckets):
881876

882877
# Append coverage results if coverage was enabled.
883878
if self.cov_report_deploy is not None:
884-
report_status = run_results[self.cov_report_deploy]
879+
report_status = results[self.cov_report_deploy]
885880
if report_status == "P":
886881
results_str += "\n## Coverage Results\n"
887882
# Link the dashboard page using "cov_report_page" value.

src/dvsim/job/deploy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
if TYPE_CHECKING:
2828
from dvsim.flow.sim import SimCfg
29+
from dvsim.launcher.base import Launcher
2930

3031

3132
class Deploy:
@@ -92,7 +93,7 @@ def __init__(self, sim_cfg: "SimCfg") -> None:
9293
self.cmd = self._construct_cmd()
9394

9495
# Launcher instance created later using create_launcher() method.
95-
self.launcher = None
96+
self.launcher: Launcher | None = None
9697

9798
# Job's wall clock time (a.k.a CPU time, or runtime).
9899
self.job_runtime = JobTime()
@@ -484,7 +485,7 @@ class RunTest(Deploy):
484485
fixed_seed = None
485486
cmds_list_vars = ["pre_run_cmds", "post_run_cmds"]
486487

487-
def __init__(self, index, test, build_job, sim_cfg) -> None:
488+
def __init__(self, index, test, build_job, sim_cfg: "SimCfg") -> None:
488489
self.test_obj = test
489490
self.index = index
490491
self.build_seed = sim_cfg.build_seed

src/dvsim/launcher/factory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
EDACLOUD_LAUNCHER_EXISTS = False
2121

2222
# The chosen launcher class.
23-
_LAUNCHER_CLS = None
23+
_LAUNCHER_CLS: type[Launcher] | None = None
2424

2525

26-
def set_launcher_type(is_local=False) -> None:
27-
"""Sets the launcher type that will be used to launch the jobs.
26+
def set_launcher_type(is_local: bool = False) -> None:
27+
"""Set the launcher type that will be used to launch the jobs.
2828
2929
The env variable `DVSIM_LAUNCHER` is used to identify what launcher system
3030
to use. This variable is specific to the user's work site. It is meant to
@@ -66,7 +66,7 @@ def set_launcher_type(is_local=False) -> None:
6666
_LAUNCHER_CLS = LocalLauncher
6767

6868

69-
def get_launcher_cls():
69+
def get_launcher_cls() -> type[Launcher]:
7070
"""Returns the chosen launcher class."""
7171
assert _LAUNCHER_CLS is not None
7272
return _LAUNCHER_CLS

0 commit comments

Comments
 (0)