Skip to content

Commit 00fd83b

Browse files
enochkanbfirsh
authored andcommitted
fixed overlapping plots issue
Signed-off-by: Enoch Kan <[email protected]>
1 parent f7d8a62 commit 00fd83b

File tree

3 files changed

+52
-15
lines changed

3 files changed

+52
-15
lines changed

python/keepsake/experiment.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,33 @@
11
try:
22
# backport is incompatible with 3.7+, so we must use built-in
3-
from dataclasses import dataclass, InitVar, field
43
import dataclasses
4+
from dataclasses import InitVar, dataclass, field
55
except ImportError:
66
from ._vendor.dataclasses import dataclass, InitVar, field
77
from ._vendor import dataclasses # type: ignore
8+
9+
import datetime
810
import getpass
9-
import os
10-
import math
1111
import html
12-
import datetime
1312
import json
13+
import math
14+
import os
1415
import shlex
1516
import sys
1617
from typing import (
17-
Dict,
18+
TYPE_CHECKING,
1819
Any,
19-
Optional,
20-
Tuple,
20+
Callable,
21+
Dict,
2122
List,
22-
TYPE_CHECKING,
2323
MutableSequence,
24-
Callable,
24+
Optional,
25+
Tuple,
2526
)
2627

2728
from . import console
28-
from .checkpoint import (
29-
Checkpoint,
30-
PrimaryMetric,
31-
CheckpointList,
32-
)
33-
from .metadata import rfc3339_datetime, parse_rfc3339
29+
from .checkpoint import Checkpoint, CheckpointList, PrimaryMetric
30+
from .metadata import parse_rfc3339, rfc3339_datetime
3431
from .packages import get_imported_packages
3532
from .system import get_python_version
3633
from .validate import check_path
@@ -446,6 +443,11 @@ def plot(self, metric: Optional[str] = None, logy=False):
446443
if metric is None:
447444
metric = self.primary_metric()
448445

446+
plotted_label = plt.axes().yaxis.get_label().get_text() or metric
447+
448+
if metric != plotted_label:
449+
plt.figure()
450+
449451
for exp in self:
450452
exp.plot(metric, plot_only=True)
451453

python/tests/test_plot.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import datetime
2+
3+
import matplotlib.pyplot as plt
4+
from keepsake.checkpoint import Checkpoint, CheckpointList
5+
from keepsake.experiment import Experiment, ExperimentList
6+
from keepsake.project import Project, init
7+
8+
9+
def test_num_plots(temp_workdir):
10+
with open("keepsake.yaml", "w") as f:
11+
f.write("repository: file://.keepsake/")
12+
13+
experiment = init(path=".", params={"learning_rate": 0.1, "num_epochs": 1},)
14+
15+
experiment.checkpoint(
16+
path=".",
17+
step=1,
18+
metrics={"loss": 1.1836304664611816, "accuracy": 0.3333333432674408},
19+
primary_metric=("loss", "minimize"),
20+
)
21+
experiment.checkpoint(
22+
path=".",
23+
step=2,
24+
metrics={"loss": 1.1836304662222222, "accuracy": 0.4333333432674408},
25+
primary_metric=("loss", "minimize"),
26+
)
27+
28+
experiment_list = ExperimentList([experiment])
29+
num_plots = 30
30+
for rep in range(num_plots):
31+
experiment_list.plot()
32+
assert len(plt.get_fignums()) == 1
33+
experiment_list.plot(metric="accuracy")
34+
assert len(plt.get_fignums()) == 2

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ boto3==1.12.32
1010
google-cloud-storage==1.32.0
1111
waiting==1.4.1
1212
python-dateutil==2.1
13+
matplotlib==3.3.4

0 commit comments

Comments
 (0)