18 changes: 16 additions & 2 deletions keras/backend/torch/layer.py
@@ -13,8 +13,9 @@ def _post_build(self):
self._track_variables()

def _track_variables(self):
self.torch_params = torch.nn.ParameterList(
[variable.value for variable in self.variables]
# Index given to ParameterDict must be a string
self.torch_params = torch.nn.ParameterDict(
Collaborator:

The dict is more readable, but it's unsafe -- there's no guarantee that variable names are unique within a model (except for Functional models). It's entirely possible to create models with duplicate variable paths, which would cause tracking issues above. So the list is preferable.

Contributor Author:

It should be safe to use id(variable) as the key for torch.nn.ParameterDict. I believe BaseOptimizer adopts the same approach to get the mapping of variables.

I could not find a way to safely add or remove variables with torch.nn.ParameterList ("in" checks and remove() are not supported for KerasVariable).

Collaborator:

What's the original issue with using a list though?

Contributor Author (@james77777778, Apr 12, 2024):

The problem is that torch.nn.ParameterList cannot remove elements. This causes a failure in int8 quantization, where we need to remove the old floating-point _kernel.
It is also hard to tell whether it is safe to append a new variable, because we cannot check whether the item is already in the list.

from keras import layers

layer = layers.Dense(units=16)
layer.build([None, 8])
assert len(layer.torch_params) == len(layer.variables)

layer.enable_lora(rank=2)
assert len(layer.torch_params) == len(layer.variables)  # <--
master: failed at line 8 (4 vs. 2)
this PR: success

The biggest issue is that zero_grad() fails to reset the uncaptured variables.
I believe the current LoRA implementation will not train correctly on the torch backend.
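
For reference, a minimal standalone sketch (plain PyTorch, independent of the Keras code in this PR) of the difference discussed above; the kernel parameter is purely illustrative:

import torch

kernel = torch.nn.Parameter(torch.zeros(8, 16))

# Dict keyed by (stringified) object identity: membership checks and removal
# are straightforward.
params = torch.nn.ParameterDict()
params[str(id(kernel))] = kernel
assert str(id(kernel)) in params
params.pop(str(id(kernel)))

# ParameterList, by contrast, only exposes append()/extend(); as noted above,
# it has no removal API, so a replaced parameter (such as an old float kernel)
# would stay registered.
params_list = torch.nn.ParameterList([kernel])
params_list.append(torch.nn.Parameter(torch.zeros(8, 16)))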

Contributor Author:

To demonstrate the zero_grad() error using LoRA:

import torch

from keras import layers

layer = layers.Dense(units=3)
layer.build([None, 2])
layer.enable_lora(rank=1)


def list_grads(layer):
    grads = dict()
    for v in layer.trainable_variables:
        grads[v.name] = v.value.grad
    return grads


# Fake grads
for v in layer.trainable_variables:
    v.value.grad = torch.rand_like(v.value)

print(list_grads(layer))
layer.zero_grad()
print(list_grads(layer))
# master branch
{'bias': tensor([0.6259, 0.4827, 0.6012], device='cuda:0'), 'lora_kernel_a': tensor([[0.6620],
        [0.7231]], device='cuda:0'), 'lora_kernel_b': tensor([[0.7123, 0.9257, 0.1676]], device='cuda:0')}
{'bias': None, 'lora_kernel_a': tensor([[0.6620],
        [0.7231]], device='cuda:0'), 'lora_kernel_b': tensor([[0.7123, 0.9257, 0.1676]], device='cuda:0')}

# this pr
{'bias': tensor([0.5960, 0.2336, 0.1569], device='cuda:0'), 'lora_kernel_a': tensor([[0.9123],
        [0.9217]], device='cuda:0'), 'lora_kernel_b': tensor([[0.3435, 0.9276, 0.6599]], device='cuda:0')}
{'bias': None, 'lora_kernel_a': None, 'lora_kernel_b': None}

{str(id(variable)): variable.value for variable in self.variables}
)

def parameters(self, recurse=True):
@@ -38,3 +39,16 @@ def _setattr_hook(self, name, value):
if not isinstance(self, TorchModuleWrapper):
value = TorchModuleWrapper(value)
return name, value

def _post_track_variable(self, variable):
if hasattr(self, "torch_params"):
# Index given to ParameterDict must be a string
key = str(id(variable))
if key not in self.torch_params:
self.torch_params[key] = variable.value

def _post_untrack_variable(self, variable):
if hasattr(self, "torch_params"):
# Index given to ParameterDict must be a string
key = str(id(variable))
self.torch_params.pop(key)
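
A short usage sketch of what these hooks are meant to guarantee, based on the dense_test changes below (torch backend only; names and calls follow the existing Keras API):

from keras import backend, layers

if backend.backend() == "torch":
    layer = layers.Dense(units=16)
    layer.build([None, 8])

    layer.enable_lora(rank=2)   # tracks lora_kernel_a / lora_kernel_b
    layer.quantize("int8")      # untracks the float kernel, tracks the int8 kernel and its scale

    # torch_params stays aligned with the Keras-side variable list.
    assert len(layer.torch_params) == len(layer.variables)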
2 changes: 1 addition & 1 deletion keras/layers/attention/grouped_query_attention_test.py
@@ -40,7 +40,7 @@ def test_basics(self):
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=4,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_seed_generators=1,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
2 changes: 1 addition & 1 deletion keras/layers/attention/multi_head_attention_test.py
@@ -44,7 +44,7 @@ def test_basics(self):
expected_output_shape=(2, 8, 16),
expected_num_trainable_weights=4,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_seed_generators=1,
expected_num_losses=0,
supports_masking=True,
run_training_check=False,
4 changes: 4 additions & 0 deletions keras/layers/core/dense_test.py
@@ -209,6 +209,8 @@ def test_enable_lora(self):
layer.enable_lora(4)
self.assertLen(layer.trainable_weights, 3)
self.assertLen(layer.non_trainable_weights, 1)
if backend.backend() == "torch":
self.assertLen(layer.torch_params, 4)
# Try eager call
x = np.random.random((64, 8))
y = np.random.random((64, 16))
@@ -434,6 +436,8 @@ def test_quantize_when_lora_enabled(self):
layer.quantize("int8")
self.assertLen(layer.trainable_weights, 3)
self.assertLen(layer.non_trainable_weights, 2)
if backend.backend() == "torch":
self.assertLen(layer.torch_params, 5)

# Try calling fit()
init_lora_a_kernel_value = layer.lora_kernel_a.numpy()
4 changes: 4 additions & 0 deletions keras/layers/core/einsum_dense_test.py
@@ -296,6 +296,8 @@ def test_enable_lora(self):
layer.enable_lora(2)
self.assertLen(layer.trainable_weights, 2)
self.assertLen(layer.non_trainable_weights, 1)
if backend.backend() == "torch":
self.assertLen(layer.torch_params, 3)
# Try eager call
x = np.random.random((64, 3))
y = np.random.random((64, 8, 32))
@@ -532,6 +534,8 @@ def test_quantize_when_lora_enabled(self):
layer.quantize("int8")
self.assertLen(layer.trainable_weights, 2)
self.assertLen(layer.non_trainable_weights, 2)
if backend.backend() == "torch":
self.assertLen(layer.torch_params, 4)

# Try calling fit()
init_lora_a_kernel_value = layer.lora_kernel_a.numpy()
4 changes: 4 additions & 0 deletions keras/layers/core/embedding_test.py
@@ -121,6 +121,8 @@ def test_enable_lora(self):
layer.enable_lora(4)
self.assertLen(layer.trainable_weights, 2)
self.assertLen(layer.non_trainable_weights, 1)
if backend.backend() == "torch":
self.assertLen(layer.torch_params, 3)
# Try eager call
x = np.random.randint(0, 9, size=(64, 3))
y = np.random.random((64, 3, 16))
@@ -323,6 +325,8 @@ def test_quantize_when_lora_enabled(self):
layer.quantize("int8")
self.assertLen(layer.trainable_weights, 2)
self.assertLen(layer.non_trainable_weights, 2)
if backend.backend() == "torch":
self.assertLen(layer.torch_params, 4)

# Try calling fit()
init_lora_a_embeddings_value = layer.lora_embeddings_a.numpy()
2 changes: 2 additions & 0 deletions keras/layers/layer.py
@@ -1203,13 +1203,15 @@ def _track_variable(self, variable):
self._tracker.add_to_store("non_trainable_variables", variable)
if not self.trainable:
variable.trainable = False
self._post_track_variable(variable)

def _untrack_variable(self, variable):
previous_lock_state = self._tracker.locked
self._tracker.unlock()
self._tracker.untrack(variable)
if previous_lock_state is True:
self._tracker.lock()
self._post_untrack_variable(variable)

def add_metric(self):
# Permanently disabled
1 change: 0 additions & 1 deletion keras/layers/preprocessing/discretization.py
@@ -147,7 +147,6 @@ def __init__(
self.sparse = sparse

if self.bin_boundaries:
self.built = True
Collaborator:

Why this change?

Contributor Author:

Explicitly setting self.built = True in __init__ fails run_build_asserts in the symbolic-call test of run_layer_test. The root cause is that self.torch_params never gets initialized.

I'm not sure exactly where the issue lies, but it should be acceptable to leave setting self.built = True to the build method.

Contributor Author:

from keras import layers

inputs = layers.Input([2])
layer = layers.Dropout(rate=0.2)
layer(inputs)
assert len(layer.torch_params) == len(layer.variables)

This script fails on the master branch but works fine with this PR.

self.summary = None
else:
self.summary = np.array([[], []], dtype="float32")
4 changes: 2 additions & 2 deletions keras/layers/regularization/alpha_dropout.py
@@ -43,9 +43,9 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
self.rate = rate
self.seed = seed
self.noise_shape = noise_shape
self.seed_generator = backend.random.SeedGenerator(seed)
if rate > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True

def call(self, inputs, training=False):
if training and self.rate > 0:
1 change: 0 additions & 1 deletion keras/layers/regularization/dropout.py
@@ -52,7 +52,6 @@ def __init__(self, rate, noise_shape=None, seed=None, **kwargs):
if rate > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True
Collaborator:

Why this change?

Contributor Author:

Same as above


def call(self, inputs, training=False):
if training and self.rate > 0:
3 changes: 2 additions & 1 deletion keras/layers/regularization/gaussian_dropout.py
@@ -34,7 +34,8 @@ def __init__(self, rate, seed=None, **kwargs):
)
self.rate = rate
self.seed = seed
self.seed_generator = backend.random.SeedGenerator(seed)
if rate > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True

def call(self, inputs, training=False):
3 changes: 2 additions & 1 deletion keras/layers/regularization/gaussian_noise.py
@@ -35,7 +35,8 @@ def __init__(self, stddev, seed=None, **kwargs):
)
self.stddev = stddev
self.seed = seed
self.seed_generator = backend.random.SeedGenerator(seed)
if stddev > 0:
self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True

def call(self, inputs, training=False):
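An illustrative sketch of the effect of guarding SeedGenerator creation in the regularization layers above (this assumes the internal _seed_generators tracking list, which the test utilities below also rely on):

from keras import layers

# A layer whose noise is a no-op no longer carries an unused seed state variable.
assert len(layers.Dropout(rate=0.5)._seed_generators) == 1
assert len(layers.Dropout(rate=0.0)._seed_generators) == 0
assert len(layers.GaussianNoise(stddev=0.0)._seed_generators) == 0
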
6 changes: 3 additions & 3 deletions keras/layers/rnn/stacked_rnn_cells_test.py
@@ -95,7 +95,7 @@ def test_basics(self):
expected_output_shape=(2, 3, 5),
expected_num_trainable_weights=9,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_seed_generators=3,
supports_masking=True,
)
self.run_layer_test(
@@ -112,7 +112,7 @@ def test_basics(self):
expected_output_shape=(2, 3, 5),
expected_num_trainable_weights=9,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_seed_generators=3,
supports_masking=True,
)
self.run_layer_test(
@@ -129,7 +129,7 @@ def test_basics(self):
expected_output_shape=(2, 3, 5),
expected_num_trainable_weights=9,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
expected_num_seed_generators=3,
supports_masking=True,
)

8 changes: 8 additions & 0 deletions keras/ops/operation.py
@@ -280,3 +280,11 @@ def _post_build(self):
def _setattr_hook(self, name, value):
"""Can be overridden for per backend post build actions."""
return name, value

def _post_track_variable(self, variable):
"""Can be overridden for per backend post track actions."""
pass

def _post_untrack_variable(self, variable):
Contributor Author:

Two hooks have been introduced to enable post-processing of _track_variable and _untrack_variable in the torch backend.

"""Can be overridden for per backend post untrack actions."""
pass
22 changes: 14 additions & 8 deletions keras/random/seed_generator.py
@@ -6,6 +6,7 @@
from keras.api_export import keras_export
from keras.backend.common import global_state
from keras.utils import jax_utils
from keras.utils.naming import auto_name


@keras_export("keras.random.SeedGenerator")
@@ -44,7 +45,11 @@ def call(self, x, training=False):
```
"""

def __init__(self, seed=None, **kwargs):
def __init__(self, seed=None, name=None, **kwargs):
if name is None:
name = auto_name(self.__class__.__name__)
self.name = name

custom_backend = kwargs.pop("backend", None)
if kwargs:
raise ValueError(f"Unrecognized keyword arguments: {kwargs}")
@@ -66,13 +71,14 @@ def seed_initializer(*args, **kwargs):
dtype = kwargs.get("dtype", None)
return self.backend.convert_to_tensor([seed, 0], dtype=dtype)

self.state = self.backend.Variable(
seed_initializer,
shape=(2,),
dtype="uint32",
trainable=False,
name="seed_generator_state",
)
with backend.name_scope(self.name, caller=self):
self.state = self.backend.Variable(
seed_initializer,
shape=(2,),
dtype="uint32",
trainable=False,
name="seed_generator_state",
)

def next(self, ordered=True):
seed_state = self.state
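A small sketch of the intended effect of naming SeedGenerator instances (the name keyword comes from the diff above; the exact state-variable path shown is an assumption):

from keras import random

sg = random.SeedGenerator(seed=42, name="my_seed_gen")
print(sg.name)        # "my_seed_gen"
print(sg.state.path)  # expected: "my_seed_gen/seed_generator_state"

# Without an explicit name, auto_name() assigns a unique one (e.g. "seed_generator_1"),
# so each generator's state variable ends up with a distinct path.
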
29 changes: 27 additions & 2 deletions keras/testing/test_case.py
@@ -291,9 +291,22 @@ def run_build_asserts(layer):
)
if expected_num_seed_generators is not None:
self.assertLen(
layer._seed_generators,
get_seed_generators(layer),
expected_num_seed_generators,
msg="Unexpected number of _seed_generators",
msg="Unexpected number of seed_generators",
)
if (
backend.backend() == "torch"
and expected_num_trainable_weights is not None
and expected_num_non_trainable_weights is not None
and expected_num_seed_generators is not None
):
self.assertLen(
layer.torch_params,
expected_num_trainable_weights
+ expected_num_non_trainable_weights
+ expected_num_seed_generators,
msg="Unexpected number of torch_params",
)

def run_output_asserts(layer, output, eager=False):
@@ -662,3 +675,15 @@ def map_shape_dtype_structure(fn, shape, dtype):
raise ValueError(
f"Cannot map function to unknown objects {shape} and {dtype}"
)


def get_seed_generators(layer):
"""Get a List of all seed generators in the layer recursively."""
seed_generators = []
seen_ids = set()
for sublayer in layer._flatten_layers(True, True):
for sg in sublayer._seed_generators:
if id(sg) not in seen_ids:
seed_generators.append(sg)
seen_ids.add(id(sg))
return seed_generators
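
A hypothetical usage example for the helper above (the layer mix is illustrative; seed generators are created in the layer constructors, so no build call is needed):

from keras import layers, models

model = models.Sequential(
    [layers.Dense(4), layers.Dropout(rate=0.5), layers.GaussianNoise(stddev=0.1)]
)
# Seed generators are collected recursively across sublayers, without duplicates:
# one from Dropout and one from GaussianNoise.
assert len(get_seed_generators(model)) == 2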