Skip to content

Commit b1e4057

Browse files
Fix SeedGenerator in tf.distribute.MirroredStrategy (#20612)
1 parent 811d004 commit b1e4057

File tree

3 files changed

+10
-7
lines changed

3 files changed

+10
-7
lines changed

keras/src/backend/tensorflow/core.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,7 @@ def _initialize(self, value):
4444
)
4545

4646
def _initialize_with_initializer(self, initializer):
47-
self._value = tf.Variable(
48-
lambda: initializer(self._shape, dtype=self._dtype),
49-
dtype=self._dtype,
50-
trainable=self.trainable,
51-
name=self.name,
52-
aggregation=self._map_aggregation(self.aggregation),
53-
)
47+
self._initialize(lambda: initializer(self._shape, dtype=self._dtype))
5448

5549
def _deferred_initialize(self):
5650
if self._value is not None:

keras/src/backend/tensorflow/distribute_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,14 @@ def test_variable_aggregation(self):
137137
self.assertEqual(v2.aggregation, "sum")
138138
self.assertEqual(v2.value.aggregation, tf.VariableAggregation.SUM)
139139

140+
def test_seed_generator(self):
141+
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
142+
with strategy.scope():
143+
seed_generator = keras.random.SeedGenerator(42)
144+
states = strategy.run(lambda: seed_generator.state.value).values
145+
for s in states:
146+
self.assertAllClose(keras.ops.convert_to_numpy(s), (42, 0))
147+
140148
def test_correctness_with_fit_and_regularizer(self):
141149
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
142150

keras/src/random/seed_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def seed_initializer(*args, **kwargs):
8989
shape=(2,),
9090
dtype=self.backend.random_seed_dtype(),
9191
trainable=False,
92+
aggregation="none",
9293
name="seed_generator_state",
9394
)
9495

0 commit comments

Comments
 (0)