Skip to content

Commit 2380f51

Browse files
RobinVogelbellet
authored andcommitted
Corrects the forgotten bits of PR #267 (#269)
* maj * maj * corrected PR 267 * trailing whitespace * test calibrate_threshold, test predict * maj * Checks estimator is fitted before set threshold * correct failed tests with MockBadClassifier * remove checks * forgot one * missed one check_is_fitted * sklearn changed the assumptions behind check_is_fitted
1 parent 1b40c3b commit 2380f51

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

metric_learn/base_metric.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic',
9393
The checked input labels array.
9494
"""
9595
self._check_preprocessor()
96+
97+
check_is_fitted(self, ['preprocessor_'])
9698
return check_input(X, y,
9799
type_of_inputs=type_of_inputs,
98100
preprocessor=self.preprocessor_,
@@ -215,6 +217,7 @@ def score_pairs(self, pairs):
215217
:ref:`mahalanobis_distances` : The section of the project documentation
216218
that describes Mahalanobis Distances.
217219
"""
220+
check_is_fitted(self, ['preprocessor_'])
218221
pairs = check_input(pairs, type_of_inputs='tuples',
219222
preprocessor=self.preprocessor_,
220223
estimator=self, tuple_size=2)
@@ -336,8 +339,10 @@ def predict(self, pairs):
336339
y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,)
337340
The predicted learned metric value between samples in every pair.
338341
"""
342+
check_is_fitted(self, 'preprocessor_')
343+
339344
if "threshold_" not in vars(self):
340-
msg = ("A threshold for this estimator has not been set,"
345+
msg = ("A threshold for this estimator has not been set, "
341346
"call its set_threshold or calibrate_threshold method.")
342347
raise AttributeError(msg)
343348
return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1
@@ -414,6 +419,8 @@ def set_threshold(self, threshold):
414419
self : `_PairsClassifier`
415420
The pairs classifier with the new threshold set.
416421
"""
422+
check_is_fitted(self, 'preprocessor_')
423+
417424
self.threshold_ = threshold
418425
return self
419426

@@ -476,6 +483,7 @@ def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy',
476483
--------
477484
sklearn.calibration : scikit-learn's module for calibrating classifiers
478485
"""
486+
check_is_fitted(self, 'preprocessor_')
479487

480488
self._validate_calibration_params(strategy, min_rate, beta)
481489

test/test_pairs_classifiers.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,31 @@ def test_predict_monotonous(estimator, build_dataset,
6666
ids=ids_pairs_learners)
6767
def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
6868
with_preprocessor):
69-
"""Test that a NotFittedError is raised if someone tries to predict and
70-
the metric learner has not been fitted."""
69+
"""Test that a NotFittedError is raised if someone tries to use
70+
score_pairs, decision_function, get_metric, transform or
71+
get_mahalanobis_matrix on input data and the metric learner
72+
has not been fitted."""
7173
input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
7274
estimator = clone(estimator)
7375
estimator.set_params(preprocessor=preprocessor)
7476
set_random_state(estimator)
77+
with pytest.raises(NotFittedError):
78+
estimator.score_pairs(input_data)
7579
with pytest.raises(NotFittedError):
7680
estimator.decision_function(input_data)
81+
with pytest.raises(NotFittedError):
82+
estimator.get_metric()
83+
with pytest.raises(NotFittedError):
84+
estimator.transform(input_data)
85+
with pytest.raises(NotFittedError):
86+
estimator.get_mahalanobis_matrix()
87+
with pytest.raises(NotFittedError):
88+
estimator.calibrate_threshold(input_data, labels)
89+
90+
with pytest.raises(NotFittedError):
91+
estimator.set_threshold(0.5)
92+
with pytest.raises(NotFittedError):
93+
estimator.predict(input_data)
7794

7895

7996
@pytest.mark.parametrize('calibration_params',
@@ -138,15 +155,16 @@ def fit(self, pairs, y):
138155

139156

140157
def test_unset_threshold():
141-
# test that set_threshold indeed sets the threshold
158+
"""Tests that the "threshold is unset" error is raised when using predict
159+
(performs binary classification on pairs) with an unset threshold."""
142160
identity_pairs_classifier = IdentityPairsClassifier()
143161
pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]])
144162
y = np.array([1, 1, -1, -1])
145163
identity_pairs_classifier.fit(pairs, y)
146164
with pytest.raises(AttributeError) as e:
147165
identity_pairs_classifier.predict(pairs)
148166

149-
expected_msg = ("A threshold for this estimator has not been set,"
167+
expected_msg = ("A threshold for this estimator has not been set, "
150168
"call its set_threshold or calibrate_threshold method.")
151169

152170
assert str(e.value) == expected_msg
@@ -362,6 +380,7 @@ class MockBadPairsClassifier(MahalanobisMixin, _PairsClassifierMixin):
362380
"""
363381

364382
def fit(self, pairs, y, calibration_params=None):
383+
self.preprocessor_ = 'not used'
365384
self.components_ = 'not used'
366385
self.calibrate_threshold(pairs, y, **(calibration_params if
367386
calibration_params is not None else

test/test_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,8 @@ def test_array_like_indexer_array_like_valid_classic(input_data, indices):
749749
"""Checks that any array-like is valid in the 'preprocessor' argument,
750750
and in the indices, for a classic input"""
751751
class MockMetricLearner(MahalanobisMixin):
752+
def fit(self):
753+
pass
752754
pass
753755

754756
mock_algo = MockMetricLearner(preprocessor=input_data)
@@ -763,6 +765,8 @@ def test_array_like_indexer_array_like_valid_tuples(input_data, indices):
763765
"""Checks that any array-like is valid in the 'preprocessor' argument,
764766
and in the indices, for a classic input"""
765767
class MockMetricLearner(MahalanobisMixin):
768+
def fit(self):
769+
pass
766770
pass
767771

768772
mock_algo = MockMetricLearner(preprocessor=input_data)

0 commit comments

Comments
 (0)