Skip to content
12 changes: 8 additions & 4 deletions metric_learn/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
from six.moves import xrange
from scipy.sparse import coo_matrix
from sklearn.utils import check_random_state

__all__ = ['Constraints']

Expand All @@ -23,7 +24,8 @@ def __init__(self, partial_labels):
self.known_label_idx, = np.where(partial_labels >= 0)
self.known_labels = partial_labels[self.known_label_idx]

def adjacency_matrix(self, num_constraints, random_state=np.random):
def adjacency_matrix(self, num_constraints, random_state=None):
random_state = check_random_state(random_state)
a, b, c, d = self.positive_negative_pairs(num_constraints,
random_state=random_state)
row = np.concatenate((a, c))
Expand All @@ -35,7 +37,8 @@ def adjacency_matrix(self, num_constraints, random_state=np.random):
return adj + adj.T

def positive_negative_pairs(self, num_constraints, same_length=False,
random_state=np.random):
random_state=None):
random_state = check_random_state(random_state)
a, b = self._pairs(num_constraints, same_label=True,
random_state=random_state)
c, d = self._pairs(num_constraints, same_label=False,
Expand Down Expand Up @@ -68,13 +71,14 @@ def _pairs(self, num_constraints, same_label=True, max_iter=10,
ab = np.array(list(ab)[:num_constraints], dtype=int)
return self.known_label_idx[ab.T]

def chunks(self, num_chunks=100, chunk_size=2, random_state=np.random):
def chunks(self, num_chunks=100, chunk_size=2, random_state=None):
"""
the random state object to be passed must be a numpy random seed
"""
random_state = check_random_state(random_state)
chunks = -np.ones_like(self.known_label_idx, dtype=int)
uniq, lookup = np.unique(self.known_labels, return_inverse=True)
all_inds = [set(np.where(lookup==c)[0]) for c in xrange(len(uniq))]
all_inds = [set(np.where(lookup == c)[0]) for c in xrange(len(uniq))]
idx = 0
while idx < num_chunks and all_inds:
if len(all_inds) == 1:
Expand Down
29 changes: 23 additions & 6 deletions metric_learn/itml.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import warnings
import numpy as np
from six.moves import xrange
from sklearn.exceptions import ChangedBehaviorWarning
from sklearn.metrics import pairwise_distances
from sklearn.utils.validation import check_array
from sklearn.base import TransformerMixin
Expand Down Expand Up @@ -298,7 +299,6 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
A positive definite (PD) matrix of shape
(n_features, n_features), that will be used as such to set the
prior.

A0 : Not used
.. deprecated:: 0.5.0
`A0` was deprecated in version 0.5.0 and will
Expand All @@ -310,7 +310,9 @@ class ITML_Supervised(_BaseITML, TransformerMixin):
tuples will be formed like this: X[indices].
random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int. If
``prior='random'``, ``random_state`` is used to set the prior.
``prior='random'``, ``random_state`` is used to set the prior. In any
case, `random_state` is also used to randomly sample constraints from
labels.


Attributes
Expand Down Expand Up @@ -350,7 +352,7 @@ def __init__(self, gamma=1., max_iter=1000, convergence_threshold=1e-3,
self.num_constraints = num_constraints
self.bounds = bounds

def fit(self, X, y, random_state=np.random, bounds=None):
def fit(self, X, y, random_state='deprecated', bounds=None):
"""Create constraints from labels and learn the ITML model.


Expand All @@ -362,8 +364,11 @@ def fit(self, X, y, random_state=np.random, bounds=None):
y : (n) array-like
Data labels.

random_state : numpy.random.RandomState, optional
If provided, controls random number generation.
random_state : Not used
.. deprecated:: 0.5.0
`random_state` in the `fit` function was deprecated in version 0.5.0
and will be removed in 0.6.0. Set `random_state` at initialization
instead (when instantiating a new `ITML_Supervised` object).

bounds : array-like of two numbers
Bounds on similarity, aside slack variables, s.t.
Expand All @@ -384,6 +389,18 @@ def fit(self, X, y, random_state=np.random, bounds=None):
' It has been deprecated in version 0.5.0 and will be'
' removed in 0.6.0. Use the "bounds" parameter of this '
'fit method instead.', DeprecationWarning)
if random_state != 'deprecated':
warnings.warn('"random_state" parameter in the `fit` function is '
'deprecated. Set `random_state` at initialization '
'instead (when instantiating a new `ITML_Supervised` '
'object).', DeprecationWarning)
else:
warnings.warn('As of v0.5.0, `ITML_Supervised` now uses the '
'`random_state` given at initialization to sample '
'constraints, not the default `np.random` from the `fit` '
'method, since this argument is now deprecated. '
'This warning will disappear in v0.6.0.',
ChangedBehaviorWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
Expand All @@ -392,6 +409,6 @@ def fit(self, X, y, random_state=np.random, bounds=None):

c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
random_state=self.random_state)
pairs, y = wrap_pairs(X, pos_neg)
return _BaseITML._fit(self, pairs, y, bounds=bounds)
26 changes: 21 additions & 5 deletions metric_learn/lsml.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,8 @@ class LSML_Supervised(_BaseLSML, TransformerMixin):
random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int. If
``init='random'``, ``random_state`` is used to set the random
prior.
prior. In any case, `random_state` is also used to randomly sample
constraints from labels.

Attributes
----------
Expand All @@ -308,7 +309,7 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None,
self.num_constraints = num_constraints
self.weights = weights

def fit(self, X, y, random_state=np.random):
def fit(self, X, y, random_state='deprecated'):
"""Create constraints from labels and learn the LSML model.

Parameters
Expand All @@ -319,13 +320,28 @@ def fit(self, X, y, random_state=np.random):
y : (n) array-like
Data labels.

random_state : numpy.random.RandomState, optional
If provided, controls random number generation.
random_state : Not used
.. deprecated:: 0.5.0
`random_state` in the `fit` function was deprecated in version 0.5.0
and will be removed in 0.6.0. Set `random_state` at initialization
instead (when instantiating a new `LSML_Supervised` object).
"""
if self.num_labeled != 'deprecated':
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
' removed in 0.6.0', DeprecationWarning)
if random_state != 'deprecated':
warnings.warn('"random_state" parameter in the `fit` function is '
'deprecated. Set `random_state` at initialization '
'instead (when instantiating a new `LSML_Supervised` '
'object).', DeprecationWarning)
else:
warnings.warn('As of v0.5.0, `LSML_Supervised` now uses the '
'`random_state` given at initialization to sample '
'constraints, not the default `np.random` from the `fit` '
'method, since this argument is now deprecated. '
'This warning will disappear in v0.6.0.',
ChangedBehaviorWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
Expand All @@ -334,6 +350,6 @@ def fit(self, X, y, random_state=np.random):

c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints, same_length=True,
random_state=random_state)
random_state=self.random_state)
return _BaseLSML._fit(self, X[np.column_stack(pos_neg)],
weights=self.weights)
26 changes: 21 additions & 5 deletions metric_learn/mmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,8 @@ class MMC_Supervised(_BaseMMC, TransformerMixin):
random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int. If
``init='random'``, ``random_state`` is used to initialize the random
Mahalanobis matrix.
Mahalanobis matrix. In any case, `random_state` is also used to
randomly sample constraints from labels.

`MMC_Supervised` creates pairs of similar sample by taking same class
samples, and pairs of dissimilar samples by taking different class
Expand Down Expand Up @@ -566,7 +567,7 @@ def __init__(self, max_iter=100, max_proj=10000, convergence_threshold=1e-6,
self.num_labeled = num_labeled
self.num_constraints = num_constraints

def fit(self, X, y, random_state=np.random):
def fit(self, X, y, random_state='deprecated'):
"""Create constraints from labels and learn the MMC model.

Parameters
Expand All @@ -575,13 +576,28 @@ def fit(self, X, y, random_state=np.random):
Input data, where each row corresponds to a single instance.
y : (n) array-like
Data labels.
random_state : numpy.random.RandomState, optional
If provided, controls random number generation.
random_state : Not used
.. deprecated:: 0.5.0
`random_state` in the `fit` function was deprecated in version 0.5.0
and will be removed in 0.6.0. Set `random_state` at initialization
instead (when instantiating a new `MMC_Supervised` object).
"""
if self.num_labeled != 'deprecated':
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
' removed in 0.6.0', DeprecationWarning)
if random_state != 'deprecated':
warnings.warn('"random_state" parameter in the `fit` function is '
'deprecated. Set `random_state` at initialization '
'instead (when instantiating a new `MMC_Supervised` '
'object).', DeprecationWarning)
else:
warnings.warn('As of v0.5.0, `MMC_Supervised` now uses the '
'`random_state` given at initialization to sample '
'constraints, not the default `np.random` from the `fit` '
'method, since this argument is now deprecated. '
'This warning will disappear in v0.6.0.',
ChangedBehaviorWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
Expand All @@ -590,6 +606,6 @@ def fit(self, X, y, random_state=np.random):

c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
random_state=self.random_state)
pairs, y = wrap_pairs(X, pos_neg)
return _BaseMMC._fit(self, pairs, y)
32 changes: 28 additions & 4 deletions metric_learn/rca.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,17 @@ class RCA_Supervised(RCA):
be removed in 0.6.0. Use `n_components` instead.

num_chunks: int, optional

chunk_size: int, optional

preprocessor : array-like, shape=(n_samples, n_features) or callable
The preprocessor to call to get tuples from indices. If array-like,
tuples will be formed like this: X[indices].

random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int.
It is used to randomly sample constraints from labels.

Attributes
----------
transformer_ : `numpy.ndarray`, shape=(n_components, n_features)
Expand All @@ -197,13 +203,15 @@ class RCA_Supervised(RCA):

def __init__(self, num_dims='deprecated', n_components=None,
pca_comps='deprecated', num_chunks=100, chunk_size=2,
preprocessor=None):
preprocessor=None, random_state=None):
"""Initialize the supervised version of `RCA`."""
RCA.__init__(self, num_dims=num_dims, n_components=n_components,
pca_comps=pca_comps, preprocessor=preprocessor)
self.num_chunks = num_chunks
self.chunk_size = chunk_size
self.random_state = random_state

def fit(self, X, y, random_state=np.random):
def fit(self, X, y, random_state='deprecated'):
"""Create constraints from labels and learn the RCA model.
Needs num_constraints specified in constructor.

Expand All @@ -212,10 +220,26 @@ def fit(self, X, y, random_state=np.random):
X : (n x d) data matrix
each row corresponds to a single instance
y : (n) data labels
random_state : a random.seed object to fix the random_state if needed.
random_state : Not used
.. deprecated:: 0.5.0
`random_state` in the `fit` function was deprecated in version 0.5.0
and will be removed in 0.6.0. Set `random_state` at initialization
instead (when instantiating a new `RCA_Supervised` object).
"""
if random_state != 'deprecated':
warnings.warn('"random_state" parameter in the `fit` function is '
'deprecated. Set `random_state` at initialization '
'instead (when instantiating a new `RCA_Supervised` '
'object).', DeprecationWarning)
else:
warnings.warn('As of v0.5.0, `RCA_Supervised` now uses the '
'`random_state` given at initialization to sample '
'constraints, not the default `np.random` from the `fit` '
'method, since this argument is now deprecated. '
'This warning will disappear in v0.6.0.',
ChangedBehaviorWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
chunks = Constraints(y).chunks(num_chunks=self.num_chunks,
chunk_size=self.chunk_size,
random_state=random_state)
random_state=self.random_state)
return RCA.fit(self, X, chunks)
27 changes: 21 additions & 6 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,8 @@ class SDML_Supervised(_BaseSDML, TransformerMixin):
random_state : int or numpy.RandomState or None, optional (default=None)
A pseudo random number generator object or a seed for it if int. If
``init='random'``, ``random_state`` is used to set the random
prior.
prior. In any case, `random_state` is also used to randomly sample
constraints from labels.

Attributes
----------
Expand All @@ -336,7 +337,7 @@ def __init__(self, balance_param=0.5, sparsity_param=0.01, prior=None,
self.num_labeled = num_labeled
self.num_constraints = num_constraints

def fit(self, X, y, random_state=np.random):
def fit(self, X, y, random_state='deprecated'):
"""Create constraints from labels and learn the SDML model.

Parameters
Expand All @@ -345,9 +346,11 @@ def fit(self, X, y, random_state=np.random):
data matrix, where each row corresponds to a single instance
y : array-like, shape (n,)
data labels, one for each instance
random_state : {numpy.random.RandomState, int}, optional
Random number generator or random seed. If not given, the singleton
numpy.random will be used.
random_state : Not used
.. deprecated:: 0.5.0
`random_state` in the `fit` function was deprecated in version 0.5.0
and will be removed in 0.6.0. Set `random_state` at initialization
instead (when instantiating a new `SDML_Supervised` object).

Returns
-------
Expand All @@ -358,6 +361,18 @@ def fit(self, X, y, random_state=np.random):
warnings.warn('"num_labeled" parameter is not used.'
' It has been deprecated in version 0.5.0 and will be'
' removed in 0.6.0', DeprecationWarning)
if random_state != 'deprecated':
warnings.warn('"random_state" parameter in the `fit` function is '
'deprecated. Set `random_state` at initialization '
'instead (when instantiating a new `SDML_Supervised` '
'object).', DeprecationWarning)
else:
warnings.warn('As of v0.5.0, `SDML_Supervised` now uses the '
'`random_state` given at initialization to sample '
'constraints, not the default `np.random` from the `fit` '
'method, since this argument is now deprecated. '
'This warning will disappear in v0.6.0.',
ChangedBehaviorWarning)
X, y = self._prepare_inputs(X, y, ensure_min_samples=2)
num_constraints = self.num_constraints
if num_constraints is None:
Expand All @@ -366,6 +381,6 @@ def fit(self, X, y, random_state=np.random):

c = Constraints(y)
pos_neg = c.positive_negative_pairs(num_constraints,
random_state=random_state)
random_state=self.random_state)
pairs, y = wrap_pairs(X, pos_neg)
return _BaseSDML._fit(self, pairs, y)
Loading