@@ -66,14 +66,31 @@ def test_predict_monotonous(estimator, build_dataset,
66
66
ids = ids_pairs_learners )
67
67
def test_raise_not_fitted_error_if_not_fitted (estimator , build_dataset ,
68
68
with_preprocessor ):
69
- """Test that a NotFittedError is raised if someone tries to predict and
70
- the metric learner has not been fitted."""
69
+ """Test that a NotFittedError is raised if someone tries to use
70
+ score_pairs, decision_function, get_metric, transform or
71
+ get_mahalanobis_matrix on input data and the metric learner
72
+ has not been fitted."""
71
73
input_data , labels , preprocessor , _ = build_dataset (with_preprocessor )
72
74
estimator = clone (estimator )
73
75
estimator .set_params (preprocessor = preprocessor )
74
76
set_random_state (estimator )
77
+ with pytest .raises (NotFittedError ):
78
+ estimator .score_pairs (input_data )
75
79
with pytest .raises (NotFittedError ):
76
80
estimator .decision_function (input_data )
81
+ with pytest .raises (NotFittedError ):
82
+ estimator .get_metric ()
83
+ with pytest .raises (NotFittedError ):
84
+ estimator .transform (input_data )
85
+ with pytest .raises (NotFittedError ):
86
+ estimator .get_mahalanobis_matrix ()
87
+ with pytest .raises (NotFittedError ):
88
+ estimator .calibrate_threshold (input_data , labels )
89
+
90
+ with pytest .raises (NotFittedError ):
91
+ estimator .set_threshold (0.5 )
92
+ with pytest .raises (NotFittedError ):
93
+ estimator .predict (input_data )
77
94
78
95
79
96
@pytest .mark .parametrize ('calibration_params' ,
@@ -138,15 +155,16 @@ def fit(self, pairs, y):
138
155
139
156
140
157
def test_unset_threshold ():
141
- # test that set_threshold indeed sets the threshold
158
+ """Tests that the "threshold is unset" error is raised when using predict
159
+ (performs binary classification on pairs) with an unset threshold."""
142
160
identity_pairs_classifier = IdentityPairsClassifier ()
143
161
pairs = np .array ([[[0. ], [1. ]], [[1. ], [3. ]], [[2. ], [5. ]], [[3. ], [7. ]]])
144
162
y = np .array ([1 , 1 , - 1 , - 1 ])
145
163
identity_pairs_classifier .fit (pairs , y )
146
164
with pytest .raises (AttributeError ) as e :
147
165
identity_pairs_classifier .predict (pairs )
148
166
149
- expected_msg = ("A threshold for this estimator has not been set,"
167
+ expected_msg = ("A threshold for this estimator has not been set, "
150
168
"call its set_threshold or calibrate_threshold method." )
151
169
152
170
assert str (e .value ) == expected_msg
@@ -362,6 +380,7 @@ class MockBadPairsClassifier(MahalanobisMixin, _PairsClassifierMixin):
362
380
"""
363
381
364
382
def fit (self , pairs , y , calibration_params = None ):
383
+ self .preprocessor_ = 'not used'
365
384
self .components_ = 'not used'
366
385
self .calibrate_threshold (pairs , y , ** (calibration_params if
367
386
calibration_params is not None else
0 commit comments