
Conversation

james77777778
Contributor

This PR:

  • Fix the incorrect paths of the states in SeedGenerator
  • Add seed_generators attr to Layer (primarily for testing purposes)
  • Add assertion for len of torch_params in torch backend
  • Update torch_params in _track_variable and _untrack_variable

In the current codebase, all states in the seed generators share the identical path seed_generator_state, which should be considered a bug.

Additionally, I found some incorrect expected_num_seed_generators values in run_layer_test in certain tests.
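
For illustration, a minimal sketch of the path issue (standard Keras 3 layers; the exact printed strings are an assumption — the point is that, before this fix, every seed state reported the same seed_generator_state path, whereas it is now unique per generator):

from keras import layers

dropout = layers.Dropout(rate=0.5)         # owns a SeedGenerator
noise = layers.GaussianNoise(stddev=0.1)   # owns another SeedGenerator

# Print the path of every tracked variable, including the seed states.
for v in dropout.variables + noise.variables:
    print(v.path)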

@james77777778 james77777778 changed the title Fix path of the states in SeedGenerator Fix path of the states in SeedGenerator and tracking of torch_params Apr 11, 2024
@codecov-commenter

codecov-commenter commented Apr 11, 2024

Codecov Report

Attention: Patch coverage is 85.00000% with 6 lines in your changes missing coverage. Please review.

Project coverage is 76.25%. Comparing base (dca1d8a) to head (eefcc8c).

Files Patch % Lines
keras/backend/torch/layer.py 90.00% 0 Missing and 1 partial ⚠️
keras/layers/regularization/alpha_dropout.py 50.00% 0 Missing and 1 partial ⚠️
keras/layers/regularization/gaussian_dropout.py 50.00% 0 Missing and 1 partial ⚠️
keras/layers/regularization/gaussian_noise.py 50.00% 0 Missing and 1 partial ⚠️
keras/random/seed_generator.py 85.71% 0 Missing and 1 partial ⚠️
keras/testing/test_case.py 90.90% 0 Missing and 1 partial ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##           master   #19495   +/-   ##
=======================================
  Coverage   76.25%   76.25%           
=======================================
  Files         367      367           
  Lines       41195    41226   +31     
  Branches     8066     8077   +11     
=======================================
+ Hits        31413    31438   +25     
  Misses       8060     8060           
- Partials     1722     1728    +6     
Flag Coverage Δ
keras 76.11% <85.00%> (+<0.01%) ⬆️
keras-jax 60.27% <57.50%> (-0.01%) ⬇️
keras-numpy 54.24% <57.50%> (-0.01%) ⬇️
keras-tensorflow 61.53% <57.50%> (-0.01%) ⬇️
keras-torch 60.39% <80.00%> (+0.01%) ⬆️


Collaborator

@fchollet fchollet left a comment

Thanks for the PR!

Uniquifying paths for seed generator states is a good idea.

def _track_variables(self):
-    self.torch_params = torch.nn.ParameterList(
-        [variable.value for variable in self.variables]
+    self.torch_params = torch.nn.ParameterDict(
Collaborator

The dict is more readable, but it's unsafe -- there's no guarantee that variable names are unique within a model (except for Functional models). It's entirely possible to create models with duplicate variable paths, which would cause tracking issues above. So the list is preferable.
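
For illustration only (a sketch using standard Keras 3 APIs; the exact printed paths are an assumption), duplicate paths can arise as easily as giving two sibling layers the same name:

from keras import layers

# Two independently created layers with the same name: their kernels end up
# with colliding variable paths, so a path-keyed dict would overwrite entries.
a = layers.Dense(4, name="shared")
b = layers.Dense(4, name="shared")
a.build([None, 2])
b.build([None, 2])
print(a.kernel.path, b.kernel.path)  # e.g. both "shared/kernel"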

Contributor Author

It should be safe to use id(variable) as a key for torch.nn.ParameterDict. I believe BaseOptimizer adopts the same approach to get the mapping of variables.

I could not find a way to safely add/remove a variable with torch.nn.ParameterList
(membership checks and remove are not supported for KerasVariable).
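
A minimal pure-torch sketch of the id-keyed approach (the names here are illustrative, not the actual implementation in this PR):

import torch

torch_params = torch.nn.ParameterDict()

w = torch.nn.Parameter(torch.zeros(4, 4))
key = str(id(w))  # ParameterDict keys must be strings

# Membership can be checked before tracking, so double-tracking is avoided.
if key not in torch_params:
    torch_params[key] = w

# A replaced variable (e.g. during LoRA or int8 quantization) can be dropped
# again, which ParameterList does not support.
del torch_params[key]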

Collaborator

What's the original issue with using a list though?

Contributor Author

@james77777778 james77777778 Apr 12, 2024

The problem is that torch.nn.ParameterList cannot remove elements. This causes a failure in int8 quantization because we need to remove the old floating-point _kernel.
Also, it is hard to determine whether it is safe to append a new variable, because we cannot check whether the item is already in the list.

from keras import layers

layer = layers.Dense(units=16)
layer.build([None, 8])
assert len(layer.torch_params) == len(layer.variables)

layer.enable_lora(rank=2)
assert len(layer.torch_params) == len(layer.variables)  # <--
master: Failed at line 8 (4 vs. 2)
this pr: Success

The biggest issue is that zero_grad() will fail to reset the uncaptured variables.
I believe the current LoRA implementation will not be trained correctly on the torch backend.

Contributor Author

To demonstrate the error of zero_grad using LoRA:

import torch

from keras import layers

layer = layers.Dense(units=3)
layer.build([None, 2])
layer.enable_lora(rank=1)


def list_grads(layer):
    grads = dict()
    for v in layer.trainable_variables:
        grads[v.name] = v.value.grad
    return grads


# Fake grads
for v in layer.trainable_variables:
    v.value.grad = torch.rand_like(v.value)

print(list_grads(layer))
layer.zero_grad()
print(list_grads(layer))
# master branch
{'bias': tensor([0.6259, 0.4827, 0.6012], device='cuda:0'), 'lora_kernel_a': tensor([[0.6620],
        [0.7231]], device='cuda:0'), 'lora_kernel_b': tensor([[0.7123, 0.9257, 0.1676]], device='cuda:0')}
{'bias': None, 'lora_kernel_a': tensor([[0.6620],
        [0.7231]], device='cuda:0'), 'lora_kernel_b': tensor([[0.7123, 0.9257, 0.1676]], device='cuda:0')}

# this pr
{'bias': tensor([0.5960, 0.2336, 0.1569], device='cuda:0'), 'lora_kernel_a': tensor([[0.9123],
        [0.9217]], device='cuda:0'), 'lora_kernel_b': tensor([[0.3435, 0.9276, 0.6599]], device='cuda:0')}
{'bias': None, 'lora_kernel_a': None, 'lora_kernel_b': None}

Collaborator

@fchollet fchollet left a comment

Thanks for the update!

self.sparse = sparse

if self.bin_boundaries:
    self.built = True
Collaborator

Why this change?

Contributor Author

Explicitly setting self.built=True in __init__ fails the run_build_asserts of the symbolic call test in run_layer_test.
The root cause is that self.torch_params will not be initialized.

I'm unsure exactly where the issue lies, but it should be acceptable to leave setting self.built=True to the build method.
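
As a sketch of the pattern adopted here (an illustrative layer, not the actual diff): defer setting built until build runs, so the backend still gets a chance to set up torch_params on the first call:

from keras import layers

class Passthrough(layers.Layer):
    def __init__(self):
        super().__init__()
        self.supports_masking = True
        # self.built = True   # setting this here skips the torch_params setup

    def build(self, input_shape):
        self.built = True     # set it here instead

    def call(self, inputs):
        return inputs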

Contributor Author

from keras import layers

inputs = layers.Input([2])
layer = layers.Dropout(rate=0.2)
layer(inputs)
assert len(layer.torch_params) == len(layer.variables)

This script fails on the master branch but works fine with this PR.

if rate > 0:
    self.seed_generator = backend.random.SeedGenerator(seed)
self.supports_masking = True
self.built = True
Collaborator

Why this change?

Contributor Author

Same as above

@james77777778 james77777778 requested a review from fchollet April 12, 2024 06:50
@james77777778 james77777778 marked this pull request as draft April 12, 2024 07:50
"""Can be overridden for per backend post track actions."""
pass

def _post_untrack_variable(self, variable):
Contributor Author

Two hooks have been introduced to enable postprocessing of _track_variable and _untrack_variable in the torch backend.
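
A rough sketch of how the torch backend can use these hooks to keep torch_params in sync (illustrative only; the actual implementation lives in keras/backend/torch/layer.py and may differ):

import torch

class TorchVariableTracking:
    def __init__(self):
        self.torch_params = torch.nn.ParameterDict()

    def _post_track_variable(self, variable):
        key = str(id(variable))
        if key not in self.torch_params:
            self.torch_params[key] = variable.value

    def _post_untrack_variable(self, variable):
        key = str(id(variable))
        if key in self.torch_params:
            del self.torch_params[key]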

@james77777778 james77777778 marked this pull request as ready for review April 12, 2024 08:56
Collaborator

@fchollet fchollet left a comment

LGTM, thank you!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Apr 12, 2024
@fchollet fchollet merged commit 30622e3 into keras-team:master Apr 12, 2024
@google-ml-butler google-ml-butler bot removed ready to pull Ready to be merged into the codebase kokoro:force-run labels Apr 12, 2024
@james77777778 james77777778 deleted the fix-torch_params branch April 13, 2024 02:57