Skip to content
10 changes: 8 additions & 2 deletions metric_learn/base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,8 +569,14 @@ def set_threshold(self, threshold):
The pairs classifier with the new threshold set.
"""
check_is_fitted(self, 'preprocessor_')

self.threshold_ = threshold
try:
self.threshold_ = float(threshold)
except TypeError:
raise ValueError('Parameter threshold must be a real number. '
'Got {} instead.'.format(type(threshold)))
except ValueError:
raise ValueError('Parameter threshold must be a real number. '
'Got {} instead.'.format(type(threshold)))
return self

def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy',
Expand Down
19 changes: 19 additions & 0 deletions test/test_pairs_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,25 @@ def test_set_threshold():
assert identity_pairs_classifier.threshold_ == 0.5


@pytest.mark.parametrize('value', ["ABC", None, [1, 2, 3], {'key': None},
(1, 2), set(),
np.array([[[0.], [1.]], [[1.], [3.]]])])
def test_set_wrong_type_threshold(value):
"""
Test that `set_threshold` indeed sets the threshold
and cannot accept nothing but float or integers, but
being permissive with boolean True=1.0 and False=0.0
"""
model = IdentityPairsClassifier()
model.fit(np.array([[[0.], [1.]]]), np.array([1]))
msg = ('Parameter threshold must be a real number. '
'Got {} instead.'.format(type(value)))

with pytest.raises(ValueError) as e: # String
model.set_threshold(value)
assert str(e.value).startswith(msg)


def test_f_beta_1_is_f_1():
# test that putting beta to 1 indeed finds the best threshold to optimize
# the f1_score
Expand Down