Fix path of the states in SeedGenerator and tracking of torch_params #19495
Conversation
Codecov Report. Attention: Patch coverage is

Additional details and impacted files:

| | master | #19495 | +/- |
|---|---|---|---|
| Coverage | 76.25% | 76.25% | |
| Files | 367 | 367 | |
| Lines | 41195 | 41226 | +31 |
| Branches | 8066 | 8077 | +11 |
| Hits | 31413 | 31438 | +25 |
| Misses | 8060 | 8060 | |
| Partials | 1722 | 1728 | +6 |

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Thanks for the PR!
Uniquifying paths for seed generator states is a good idea.
    def _track_variables(self):
-       self.torch_params = torch.nn.ParameterList(
-           [variable.value for variable in self.variables]
+       self.torch_params = torch.nn.ParameterDict(
The dict is more readable, but it's unsafe -- there's no guarantee that variable names are unique within a model (except for Functional models). It's entirely possible to create models with duplicate variable paths, which would cause tracking issues above. So the list is preferable.
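A hedged illustration of this point (not from the PR): in a subclassed model, explicitly passed layer names are not auto-uniquified, so two sibling sub-layers can produce colliding variable paths that a path-keyed dict could not tell apart. The class and layer names below are made up for the example.

```python
import numpy as np
from keras import layers, models


class TwoSameName(models.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Explicitly passed names are not auto-uniquified,
        # so both sub-layers are called "dense".
        self.d1 = layers.Dense(4, name="dense")
        self.d2 = layers.Dense(4, name="dense")

    def call(self, x):
        return self.d2(self.d1(x))


model = TwoSameName()
model(np.zeros((1, 8)))  # build both sub-layers
# The two kernels (and biases) can end up sharing a path, so a dict
# keyed by variable path would silently overwrite one of them.
print([v.path for v in model.weights])
```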
It should be safe to use `id(variable)` as a key for `torch.nn.ParameterDict`. I believe `BaseOptimizer` adopts the same approach to get the mapping of variables. I could not find a way to safely add/remove variables with `torch.nn.ParameterList` (`in` and `remove` are not supported for `KerasVariable`).
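A minimal, self-contained sketch of why `id()` keys work here (plain torch, no Keras; illustration only):

```python
import torch

# Two parameters that could, in principle, share the same Keras path.
p1 = torch.nn.Parameter(torch.zeros(3))
p2 = torch.nn.Parameter(torch.ones(3))

params = torch.nn.ParameterDict()
# ParameterDict keys must be strings, so the integer id() is stringified.
params[str(id(p1))] = p1
params[str(id(p2))] = p2
assert len(params) == 2  # no collision even if the variable paths were equal

# Unlike ParameterList, entries can be removed again, which is what
# int8 quantization needs when dropping the old floating-point kernel.
del params[str(id(p1))]
assert len(params) == 1
```

Because `id()` is unique among live Python objects, two variables that happen to share a path still get distinct keys, and removal stays a simple dict deletion.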
What's the original issue with using a list though?
The problem is that `torch.nn.ParameterList` cannot remove elements. This causes a failure in int8 quantization because we need to remove the old floating-point `_kernel`. Also, it is hard to determine whether it is safe to `append` a new variable, because we cannot check whether the item is already in the list.
from keras import layers
layer = layers.Dense(units=16)
layer.build([None, 8])
assert len(layer.torch_params) == len(layer.variables)
layer.enable_lora(rank=2)
assert len(layer.torch_params) == len(layer.variables) # <--
| master | this pr |
|---|---|
| Failed at line 8 (4 vs. 2) | Success |
The biggest issue is that `zero_grad()` will fail to reset the uncaptured variables. I believe the current LoRA implementation will not be trained correctly on the torch backend.
To demonstrate the `zero_grad` error using LoRA:
import torch
from keras import layers

layer = layers.Dense(units=3)
layer.build([None, 2])
layer.enable_lora(rank=1)

def list_grads(layer):
    grads = dict()
    for v in layer.trainable_variables:
        grads[v.name] = v.value.grad
    return grads

# Fake grads
for v in layer.trainable_variables:
    v.value.grad = torch.rand_like(v.value)
print(list_grads(layer))
layer.zero_grad()
print(list_grads(layer))
# master branch
{'bias': tensor([0.6259, 0.4827, 0.6012], device='cuda:0'), 'lora_kernel_a': tensor([[0.6620],
[0.7231]], device='cuda:0'), 'lora_kernel_b': tensor([[0.7123, 0.9257, 0.1676]], device='cuda:0')}
{'bias': None, 'lora_kernel_a': tensor([[0.6620],
[0.7231]], device='cuda:0'), 'lora_kernel_b': tensor([[0.7123, 0.9257, 0.1676]], device='cuda:0')}
# this pr
{'bias': tensor([0.5960, 0.2336, 0.1569], device='cuda:0'), 'lora_kernel_a': tensor([[0.9123],
[0.9217]], device='cuda:0'), 'lora_kernel_b': tensor([[0.3435, 0.9276, 0.6599]], device='cuda:0')}
{'bias': None, 'lora_kernel_a': None, 'lora_kernel_b': None}
Thanks for the update!
        self.sparse = sparse

        if self.bin_boundaries:
            self.built = True
Why this change?
Explicitly setting `self.built = True` in `__init__` will fail the `run_build_asserts` of the symbolic call test in `run_layer_test`. The root cause is that `self.torch_params` will not be initialized. I'm unsure where exactly the issue is, but it should be acceptable to leave setting `self.built = True` to the `build` method.
from keras import layers
inputs = layers.Input([2])
layer = layers.Dropout(rate=0.2)
layer(inputs)
assert len(layer.torch_params) == len(layer.variables)
This script fails on the master branch but works fine with this PR.
        if rate > 0:
            self.seed_generator = backend.random.SeedGenerator(seed)
        self.supports_masking = True
        self.built = True
Why this change?
Same as above
"""Can be overridden for per backend post track actions.""" | ||
pass | ||
|
||
def _post_untrack_variable(self, variable): |
Two hooks have been introduced to enable post-processing of `_track_variable` and `_untrack_variable` in the torch backend.
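A rough sketch of how such hooks can be wired up (illustration only; the class layout and attribute names below are assumed, not copied from the PR, and `variable.value` is assumed to be a `torch.nn.Parameter`):

```python
import torch


class BaseLayer:
    """Backend-agnostic variable tracking with overridable hooks."""

    def __init__(self):
        self._variables = []

    def _track_variable(self, variable):
        self._variables.append(variable)
        self._post_track_variable(variable)

    def _untrack_variable(self, variable):
        self._variables.remove(variable)
        self._post_untrack_variable(variable)

    def _post_track_variable(self, variable):
        """Can be overridden for per backend post track actions."""
        pass

    def _post_untrack_variable(self, variable):
        """Can be overridden for per backend post untrack actions."""
        pass


class TorchLayer(BaseLayer):
    """Torch-side subclass keeps `torch_params` in sync with tracking."""

    def __init__(self):
        super().__init__()
        self.torch_params = torch.nn.ParameterDict()

    def _post_track_variable(self, variable):
        # Reuses the id-keyed ParameterDict idea from earlier in the thread.
        self.torch_params[str(id(variable))] = variable.value

    def _post_untrack_variable(self, variable):
        key = str(id(variable))
        if key in self.torch_params:
            del self.torch_params[key]
```

Keeping the base implementations as no-ops means other backends pay no cost, while the torch backend gets a single place to keep `torch_params` consistent with tracking and untracking.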
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
This PR:
- fixes the path of the states in `SeedGenerator`
- adds a `seed_generators` attr to `Layer` (primarily for testing purposes)
- fixes the tracking of `torch_params` in the torch backend
- updates `torch_params` in `_track_variable` and `_untrack_variable`

In the current codebase, the fact that all states in the seed generator share an identical path, `seed_generator_state`, should be considered an issue. Additionally, I have found some incorrect `expected_num_seed_generators` in `run_layer_test` within certain tests.
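A hedged way to observe the path issue described above (illustration only; exact paths depend on the Keras version, it assumes `Model.variables` exposes seed generator state, and the claim about the shared path comes from the PR description rather than from running this snippet):

```python
import numpy as np
import keras
from keras import layers

# Build a small model with two layers that each own a SeedGenerator.
inputs = keras.Input([4])
x = layers.Dropout(rate=0.5)(inputs)
outputs = layers.GaussianDropout(rate=0.5)(x)
model = keras.Model(inputs, outputs)
model(np.zeros((1, 4)))

# Print the paths of the seed generator state variables. Before this PR,
# both states reportedly shared the single path "seed_generator_state";
# with unique paths, the two states can be told apart (e.g. when saving).
print([v.path for v in model.variables if "seed_generator" in v.path])
```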