
Commit 725b182

Merge branch 'keras-team:master' into VGG16
2 parents: 822d801 + b890ca9

8 files changed (+382, -34 lines)

keras_nlp/src/models/gemma/gemma_attention.py

Lines changed: 37 additions & 24 deletions
@@ -122,6 +122,7 @@ def _compute_attention(
         v,
         attention_mask,
         training=False,
+        cache_update_index=0,
     ):
         if self.query_head_dim_normalize:
             query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
         )
 
         if self.use_sliding_window_attention:
-            all_ones = ops.ones_like(attention_mask)
-            if keras.config.backend() == "tensorflow":
-                import tensorflow as tf
-
-                sliding_window_size = ops.minimum(
-                    self.sliding_window_size - 1, q_len
-                )
-                sliding_window_size = ops.cast(
-                    sliding_window_size, dtype="int32"
-                )
-                sliding_mask = tf.linalg.band_part(
-                    all_ones, sliding_window_size - 1, sliding_window_size - 1
-                )
-                sliding_mask = ops.cast(sliding_mask, dtype="bool")
-                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
-                attention_mask = tf.math.logical_and(
-                    sliding_mask, bool_attention_mask
-                )
-            else:
-                sliding_mask = ops.triu(
-                    all_ones, -1 * self.sliding_window_size + 1
-                ) * ops.tril(all_ones, self.sliding_window_size - 1)
-                attention_mask = sliding_mask * attention_mask
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )
 
         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
@@ -189,6 +171,32 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))
 
+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: triu/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
     def call(
         self,
         x,
@@ -216,7 +224,12 @@ def call(
         value = self.value_dense(x)
 
         attention_vec = self._compute_attention(
-            query, key, value, attention_mask, training=training
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
         )
 
         # Wipe attn vec if there are no attended tokens.
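For intuition, here is a minimal standalone sketch of what the new `_mask_sliding_window` helper computes, written against public `keras.ops` only (the function name and example values below are illustrative, not part of the module): each query position may attend to keys within `sliding_window_size - 1` positions on either side, and `cache_update_index` selects the band rows that correspond to the tokens currently being decoded.

from keras import ops


def sliding_window_mask(attention_mask, sliding_window_size, cache_update_index=0):
    """Restrict `attention_mask` (batch, query_len, key_len) to a sliding window."""
    _, query_len, key_len = ops.shape(attention_mask)
    # Band matrix: True where |query_pos - key_pos| < sliding_window_size.
    all_ones = ops.ones((key_len, key_len), "bool")
    band = ops.logical_and(
        ops.triu(all_ones, -sliding_window_size + 1),
        ops.tril(all_ones, sliding_window_size - 1),
    )
    # During cached decoding query_len < key_len; take the rows starting at
    # `cache_update_index` so each decoded token sees its own window.
    band = ops.slice(band, (cache_update_index, 0), (query_len, key_len))
    return ops.logical_and(attention_mask, ops.expand_dims(band, 0))


# One query token at cache position 7, 10 cached keys, window of 5:
# only keys 3 through 9 remain attendable.
print(sliding_window_mask(ops.ones((1, 1, 10), "bool"), 5, cache_update_index=7))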

keras_nlp/src/models/gemma/gemma_backbone_test.py

Lines changed: 21 additions & 0 deletions
@@ -203,6 +203,27 @@ def test_backbone_basics(self):
             expected_output_shape=(2, 10, 16),
         )
 
+    def test_sliding_window(self):
+        # Test sliding window correctness by hand.
+        backbone = GemmaBackbone(**self.init_kwargs)
+        attention = backbone.transformer_layers[0].attention
+        mask = attention._mask_sliding_window(ops.ones((1, 10, 10), "bool"))
+        expected = [
+            [
+                [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
+            ]
+        ]
+        self.assertAllEqual(mask, expected)
+
     @pytest.mark.large
     def test_saved_model(self):
         self.run_model_saving_test(
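The `expected` matrix in `test_sliding_window` is simply a band: row i has ones exactly where |i - j| <= sliding_window_size - 1. The pattern shown corresponds to a window of 5, which we assume is what `self.init_kwargs` configures for this test backbone. A quick NumPy sketch that reproduces it:

import numpy as np

window = 5  # assumed sliding_window_size of the test backbone
n = 10      # query/key length used in the test
i, j = np.indices((n, n))
band = (np.abs(i - j) <= window - 1).astype(int)
# `band` equals the `expected` matrix above (up to the leading batch dim).
print(band)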

keras_nlp/src/models/gemma/gemma_causal_lm_test.py

Lines changed: 31 additions & 5 deletions
@@ -39,14 +39,22 @@ def setUp(self):
             self.tokenizer,
             sequence_length=8,
         )
+        # Test Gemma 2 like config, as it's the more complicated case.
         self.backbone = GemmaBackbone(
             vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
             num_layers=2,
-            num_query_heads=2,
-            num_key_value_heads=1,
-            hidden_dim=4,
-            intermediate_dim=8,
+            num_query_heads=4,
+            num_key_value_heads=2,
+            hidden_dim=8,
+            intermediate_dim=16,
             head_dim=2,
+            sliding_window_size=3,
+            use_sliding_window_attention=True,
+            attention_logit_soft_cap=50,
+            final_logit_soft_cap=30,
+            query_head_dim_normalize=False,
+            use_post_ffw_norm=True,
+            use_post_attention_norm=True,
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
@@ -63,6 +71,24 @@ def test_causal_lm_basics(self):
             expected_output_shape=(2, 8, 11),
         )
 
+    def test_cache_correctness(self):
+        token_ids = self.input_data["token_ids"]
+        padding_mask = ops.ones_like(self.input_data["padding_mask"])
+        causal_lm = GemmaCausalLM(**self.init_kwargs)
+        full_logits = causal_lm(
+            {"token_ids": token_ids, "padding_mask": padding_mask}
+        )
+        token_ids = self.input_data["token_ids"]
+        _, cache = causal_lm._build_cache(token_ids)
+        cache = ops.zeros_like(cache)
+        cached_logits = []
+        for i in range(self.preprocessor.sequence_length):
+            sliced = token_ids[:, i][:, None]
+            logits, _, cache = causal_lm.call_with_cache(sliced, cache, i)
+            cached_logits.append(logits)
+        cached_logits = ops.concatenate(cached_logits, 1)
+        self.assertAllClose(full_logits, cached_logits, atol=0.002)
+
     def test_generate(self):
         causal_lm = GemmaCausalLM(**self.init_kwargs)
         # String input.
@@ -230,7 +256,7 @@ def test_score_layer_intercept_fn_exfiltration(self):
         # Setup prompts, models, and associated expected shapes.
         prompts = ["the quick brown fox", "the quick brown fox"]
         causal_lm = GemmaCausalLM(**self.init_kwargs)
-        expected_embedded_shape = (2, 8, 4)
+        expected_embedded_shape = (2, 8, 8)
         expected_score_shape = (2, 8, 11)
 
         # Preprocess prompts to get tokenized representations and padding masks.

keras_nlp/src/models/llama/llama_presets.py

Lines changed: 35 additions & 3 deletions
@@ -17,27 +17,59 @@
 backbone_presets = {
     "llama2_7b_en": {
         "metadata": {
-            "description": "LLaMA 2 7B Base model",
+            "description": "7 billion parameter, 32-layer, base LLaMA 2 model.",
             "params": 6738415616,
             "official_name": "LLaMA 2",
             "path": "llama2",
             "model_card": "https://github.com/meta-llama/llama",
         },
         "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en/1",
     },
+    "llama2_7b_en_int8": {
+        "metadata": {
+            "description": (
+                "7 billion parameter, 32-layer, base LLaMA 2 model with "
+                "activation and weights quantized to int8."
+            ),
+            "params": 6739839488,
+            "official_name": "LLaMA 2",
+            "path": "llama2",
+            "model_card": "https://github.com/meta-llama/llama",
+        },
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_7b_en_int8/1",
+    },
     "llama2_instruct_7b_en": {
         "metadata": {
-            "description": "LLaMA 2 7B Chat model",
+            "description": (
+                "7 billion parameter, 32-layer, instruction tuned LLaMA 2 "
+                "model."
+            ),
             "params": 6738415616,
             "official_name": "LLaMA 2",
             "path": "llama2",
             "model_card": "https://github.com/meta-llama/llama",
         },
         "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en/1",
     },
+    "llama2_instruct_7b_en_int8": {
+        "metadata": {
+            "description": (
+                "7 billion parameter, 32-layer, instruction tuned LLaMA 2 "
+                "model with activation and weights quantized to int8."
+            ),
+            "params": 6739839488,
+            "official_name": "LLaMA 2",
+            "path": "llama2",
+            "model_card": "https://github.com/meta-llama/llama",
+        },
+        "kaggle_handle": "kaggle://keras/llama2/keras/llama2_instruct_7b_en_int8/1",
+    },
     "vicuna_1.5_7b_en": {
         "metadata": {
-            "description": "Vicuna v1.5 7B Chat model",
+            "description": (
+                "7 billion parameter, 32-layer, instruction tuned Vicuna v1.5 "
+                "model."
+            ),
             "params": 6738415616,
             "official_name": "Vicuna",
             "path": "vicuna",

keras_nlp/src/models/llama3/llama3_presets.py

Lines changed: 33 additions & 2 deletions
@@ -17,22 +17,53 @@
 backbone_presets = {
     "llama3_8b_en": {
         "metadata": {
-            "description": "LLaMA 3 8B Base model",
+            "description": "8 billion parameter, 32-layer, base LLaMA 3 model.",
             "params": 8030261248,
             "official_name": "LLaMA 3",
             "path": "llama3",
             "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en/3",
     },
+    "llama3_8b_en_int8": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, base LLaMA 3 model with "
+                "activation and weights quantized to int8."
+            ),
+            "params": 8031894016,
+            "official_name": "LLaMA 3",
+            "path": "llama3",
+            "model_card": "https://github.com/meta-llama/llama3",
+        },
+        "kaggle_handle": "kaggle://keras/llama3/keras/llama3_8b_en_int8/1",
+    },
     "llama3_instruct_8b_en": {
         "metadata": {
-            "description": "LLaMA 3 8B Instruct model",
+            "description": (
+                "8 billion parameter, 32-layer, instruction tuned LLaMA 3 "
+                "model."
+            ),
             "params": 8030261248,
             "official_name": "LLaMA 3",
             "path": "llama3",
             "model_card": "https://github.com/meta-llama/llama3",
         },
         "kaggle_handle": "kaggle://keras/llama3/keras/llama3_instruct_8b_en/3",
     },
+    "llama3_instruct_8b_en_int8": {
+        "metadata": {
+            "description": (
+                "8 billion parameter, 32-layer, instruction tuned LLaMA 3 "
+                "model with activation and weights quantized to int8."
+            ),
+            "params": 8031894016,
+            "official_name": "LLaMA 3",
+            "path": "llama3",
+            "model_card": "https://github.com/meta-llama/llama3",
+        },
+        "kaggle_handle": (
+            "kaggle://keras/llama3/keras/llama3_instruct_8b_en_int8/1"
+        ),
+    },
 }
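Once a quantized preset is registered like this, it loads through the usual `from_preset` flow. A brief sketch of that usage (the same pattern applies to the new LLaMA 2 int8 presets above; downloading the weights requires Kaggle access):

import keras_nlp

# Load the int8-quantized Llama 3 backbone registered above.
backbone = keras_nlp.models.Llama3Backbone.from_preset("llama3_8b_en_int8")
backbone.summary()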

keras_nlp/src/utils/transformers/convert.py

Lines changed: 10 additions & 0 deletions
@@ -16,6 +16,12 @@
 
 from keras_nlp.src.utils.transformers.convert_bert import load_bert_backbone
 from keras_nlp.src.utils.transformers.convert_bert import load_bert_tokenizer
+from keras_nlp.src.utils.transformers.convert_distilbert import (
+    load_distilbert_backbone,
+)
+from keras_nlp.src.utils.transformers.convert_distilbert import (
+    load_distilbert_tokenizer,
+)
 from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_backbone
 from keras_nlp.src.utils.transformers.convert_gemma import load_gemma_tokenizer
 from keras_nlp.src.utils.transformers.convert_gpt2 import load_gpt2_backbone
@@ -56,6 +62,8 @@ def load_transformers_backbone(cls, preset, load_weights):
         return load_pali_gemma_backbone(cls, preset, load_weights)
     if cls.__name__ == "GPT2Backbone":
         return load_gpt2_backbone(cls, preset, load_weights)
+    if cls.__name__ == "DistilBertBackbone":
+        return load_distilbert_backbone(cls, preset, load_weights)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
@@ -85,6 +93,8 @@ def load_transformers_tokenizer(cls, preset):
         return load_pali_gemma_tokenizer(cls, preset)
     if cls.__name__ == "GPT2Tokenizer":
         return load_gpt2_tokenizer(cls, preset)
+    if cls.__name__ == "DistilBertTokenizer":
+        return load_distilbert_tokenizer(cls, preset)
     raise ValueError(
         f"{cls} has not been ported from the Hugging Face format yet. "
         "Please check Hugging Face Hub for the Keras model. "
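With DistilBERT wired into the converter dispatch above, a Hugging Face checkpoint can be converted on load. A minimal sketch, assuming the standard `hf://` handle syntax used for the other converters (the exact model name is illustrative):

import keras_nlp

# Convert a Hugging Face DistilBERT checkpoint on the fly.
backbone = keras_nlp.models.DistilBertBackbone.from_preset(
    "hf://distilbert-base-uncased"
)
tokenizer = keras_nlp.models.DistilBertTokenizer.from_preset(
    "hf://distilbert-base-uncased"
)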