Commit 06255ac

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f659f2d commit 06255ac

File tree: 8 files changed (+134, -80 lines)

litgpt/attention.py

Lines changed: 21 additions & 11 deletions

@@ -1,16 +1,16 @@
 import math
-from typing import Optional, Tuple, List, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from torch.nn import functional as F
 from torch.nn.attention import SDPBackend, sdpa_kernel

 from litgpt.attention_utils import (
-    filter_sdpa_kernels,
-    build_mask_cache,
-    build_mask_slice,
     attention_compute_scores,
     attention_compute_weighted_values,
+    build_mask_cache,
+    build_mask_slice,
+    filter_sdpa_kernels,
 )
 from litgpt.config import Config

@@ -94,6 +94,7 @@ class MultiHeadSelfAttention:
     `torch.nn.functional.scaled_dot_product_attention` is never used.

     """
+
     def __init__(
         self,
         config: Config,
@@ -156,12 +157,16 @@ def __call__(
         if use_eager_sdpa or sliding_window_size is not None or not is_causal:
             # Build attention mask
             if is_causal:
-                mask = build_mask_cache(
-                    max_seq_length=T,
-                    sliding_window_size=sliding_window_size,
-                    dtype=query.dtype,
-                    device=query.device,
-                ).view(1, 1, T, T).detach()
+                mask = (
+                    build_mask_cache(
+                        max_seq_length=T,
+                        sliding_window_size=sliding_window_size,
+                        dtype=query.dtype,
+                        device=query.device,
+                    )
+                    .view(1, 1, T, T)
+                    .detach()
+                )
             elif (not use_eager_sdpa) or T > 1:
                 # We need a mask if T > 1, since inference needs to be causal
                 # for the new tokens
@@ -196,7 +201,12 @@ def _use_eager_sdpa(
         return_attn_weights: bool,
         k_and_v: KeysAndValues,
     ) -> bool:
-        return return_attn_weights or self.use_eager_sdpa_always or self.config.attention_logit_softcapping is not None or not k_and_v.both_in_parallel()
+        return (
+            return_attn_weights
+            or self.use_eager_sdpa_always
+            or self.config.attention_logit_softcapping is not None
+            or not k_and_v.both_in_parallel()
+        )

     def _filter_sdpa_kernels(
         self,
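The attention.py hunks above are formatting only (import ordering and line wrapping); the logic is unchanged. For background, `SDPBackend` and `sdpa_kernel`, imported at the top of the file, are the standard PyTorch API for restricting which scaled_dot_product_attention backends may be used. A minimal standalone sketch of that API follows; the tensor shapes and the chosen backend list are illustrative, not taken from this repository.

import torch
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Toy tensors with shape (batch, n_heads, seq_len, head_dim); values are arbitrary.
q = torch.randn(1, 8, 16, 64)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Only the listed backends may be selected inside the context; any other
# backend (e.g. flash attention) is disabled for the duration of the block.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)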

litgpt/attention_utils.py

Lines changed: 54 additions & 21 deletions

@@ -2,11 +2,11 @@

 import torch
 from torch.backends.cuda import (
-    can_use_flash_attention,
-    can_use_efficient_attention,
     can_use_cudnn_attention,
+    can_use_efficient_attention,
+    can_use_flash_attention,
 )
-from torch.nn.attention import SDPBackend, SDPAParams
+from torch.nn.attention import SDPAParams, SDPBackend


 def filter_sdpa_kernels(
@@ -20,9 +20,7 @@ def filter_sdpa_kernels(
     enable_gqa: bool,
     **kwargs,
 ) -> List[SDPBackend]:
-    params = SDPAParams(
-        query, key, value, attn_mask, dropout_p, is_causal, enable_gqa
-    )
+    params = SDPAParams(query, key, value, attn_mask, dropout_p, is_causal, enable_gqa)
     new_kernels = []
     for kernel in sdpa_kernels:
         if kernel == SDPBackend.FLASH_ATTENTION and not can_use_flash_attention(params):
@@ -103,7 +101,10 @@ def mask_cache_bool(
 ) -> torch.Tensor:
     # Usual causal mask:
     mask = torch.ones(
-        max_seq_length, max_seq_length, device=device, dtype=dtype,
+        max_seq_length,
+        max_seq_length,
+        device=device,
+        dtype=dtype,
     ).triu(diagonal=1)
     if sliding_window_size is not None:
         mask += torch.ones_like(mask).tril(diagonal=-sliding_window_size)
@@ -147,25 +148,52 @@ def mask_slice_bool(
     tp_dtype = token_positions.dtype
     batch_size, n_query_groups, _ = token_positions.shape
     assert n_head % n_query_groups == 0 and n_head >= n_query_groups
-    token_positions = token_positions.to(device=device).unsqueeze(2).expand(
-        -1, -1, num, -1,
+    token_positions = (
+        token_positions.to(device=device)
+        .unsqueeze(2)
+        .expand(
+            -1,
+            -1,
+            num,
+            -1,
+        )
     )
     kwargs = dict(device=device, dtype=tp_dtype)
-    bool_mask = torch.arange(
-        input_pos, input_pos + num, **kwargs,
-    ).view(1, 1, -1, 1).expand_as(token_positions) < token_positions
-    if sliding_window_size is not None:
-        extra_mask = torch.arange(
-            input_pos - sliding_window_size,
-            input_pos + num - sliding_window_size,
+    bool_mask = (
+        torch.arange(
+            input_pos,
+            input_pos + num,
             **kwargs,
-        ).view(1, 1, -1, 1).expand_as(token_positions) >= token_positions
+        )
+        .view(1, 1, -1, 1)
+        .expand_as(token_positions)
+        < token_positions
+    )
+    if sliding_window_size is not None:
+        extra_mask = (
+            torch.arange(
+                input_pos - sliding_window_size,
+                input_pos + num - sliding_window_size,
+                **kwargs,
+            )
+            .view(1, 1, -1, 1)
+            .expand_as(token_positions)
+            >= token_positions
+        )
         bool_mask |= extra_mask
     if n_head != n_query_groups:
         q_per_kv = n_head // n_query_groups
-        bool_mask = bool_mask.unsqueeze(2).expand(
-            -1, -1, q_per_kv, -1, -1,
-        ).reshape(batch_size, n_head, num, -1)
+        bool_mask = (
+            bool_mask.unsqueeze(2)
+            .expand(
+                -1,
+                -1,
+                q_per_kv,
+                -1,
+                -1,
+            )
+            .reshape(batch_size, n_head, num, -1)
+        )
     return bool_mask


@@ -197,7 +225,12 @@ def build_mask_slice(

     """
     bool_mask = mask_slice_bool(
-        input_pos, num, token_positions, n_head, device, sliding_window_size,
+        input_pos,
+        num,
+        token_positions,
+        n_head,
+        device,
+        sliding_window_size,
     )
     mask = torch.zeros(bool_mask.shape, dtype=dtype, device=device)
     mask.masked_fill_(bool_mask, minus_infinity(dtype))
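The reformatted `mask_cache_bool` hunk keeps the same masking rule: `triu(diagonal=1)` marks future positions, and, when `sliding_window_size` is given, `tril(diagonal=-sliding_window_size)` additionally marks positions that have fallen out of the window. A small self-contained illustration of that rule follows; the sequence length and window size are example values, not taken from the repository.

import torch

T = 5                    # example sequence length
sliding_window_size = 2  # example window size

# True means "query position i may NOT attend to key position j".
mask = torch.ones(T, T, dtype=torch.bool).triu(diagonal=1)
mask |= torch.ones(T, T, dtype=torch.bool).tril(diagonal=-sliding_window_size)

# Each row i can attend only to positions i-1 and i (a window of two tokens,
# including itself) and never to future positions.
print(mask)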

litgpt/config.py

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ def find_multiple(n: int, k: int) -> int:
         return n
     return n + k - (n % k)

+
 # See `Config.start_of_layer_hook`. A start of layer hook is called just before
 # a layer is computed. The call is `hook(x, block_idx, input_pos)`, where
 # `x` is the layer input, `block_idx` the number of the layer, and `input_pos`
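The comment touched by this hunk documents that a start-of-layer hook is called as `hook(x, block_idx, input_pos)` just before a layer is computed. A hypothetical hook with that signature follows; the function name and the norm logging are illustrative and not part of litgpt.

import torch

def log_activation_norm(x: torch.Tensor, block_idx: int, input_pos: int) -> None:
    # `x` is the layer input and `block_idx` the number of the layer;
    # `input_pos` is passed through as documented for `Config.start_of_layer_hook`.
    print(f"block {block_idx} (input_pos={input_pos}): ||x|| = {x.norm().item():.4f}")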

litgpt/generate/speculative_decoding.py

Lines changed: 10 additions & 5 deletions

@@ -20,6 +20,7 @@
     multinomial_num_samples_1,
     sample_top_p,
 )
+from litgpt.kvcache import DenseKVCache
 from litgpt.model import GPT
 from litgpt.prompts import PromptStyle, has_prompt_style, load_prompt_style
 from litgpt.tokenizer import Tokenizer
@@ -30,7 +31,6 @@
     get_default_supported_precision,
     load_checkpoint,
 )
-from litgpt.kvcache import DenseKVCache


 def sample(
@@ -149,7 +149,8 @@ def speculative_decoding(
     draft_token = token
     for idx in range(speculative_k):
         logits = draft_model(
-            idx=draft_token.unsqueeze(0), input_pos=draft_input_pos,
+            idx=draft_token.unsqueeze(0),
+            input_pos=draft_input_pos,
         )
         draft_token, draft_prob = sample(logits, **sample_kwargs)
         draft_input_pos += 1
@@ -161,7 +162,8 @@ def speculative_decoding(
     # Feed both original token and draft tokens to get target probabilities
     candidate_tokens = torch.cat((token, draft_tokens))
     target_logits = target_model(
-        idx=candidate_tokens.unsqueeze(0), input_pos=input_pos,
+        idx=candidate_tokens.unsqueeze(0),
+        input_pos=input_pos,
     )

     # Step 3: Convert target logits to probabilities using same sampling params
@@ -211,7 +213,7 @@ def speculative_decoding(
         draft_model(idx=draft_token.unsqueeze(0), input_pos=draft_input_pos)
         new_token, _ = sample(target_logits, **sample_kwargs)
     else:
-        input_pos += (len(accepted_tokens) + 1)
+        input_pos += len(accepted_tokens) + 1
     _resize_kv_caches(draft_model, input_pos)
     _resize_kv_caches(target_model, input_pos)
     return torch.cat((*accepted_tokens, new_token))
@@ -316,7 +318,10 @@ def generate(
     )
     _process_prompt(draft_model, prompt, prompt_chunksize, **sample_kwargs)
     token = _process_prompt(
-        target_model, prompt, prompt_chunksize, **sample_kwargs,
+        target_model,
+        prompt,
+        prompt_chunksize,
+        **sample_kwargs,
     )
     input_pos = prompt_size

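The speculative_decoding hunks above only re-wrap call arguments. The surrounding loop follows the usual speculative-decoding recipe: the draft model proposes `speculative_k` tokens, the target model scores the whole candidate sequence in one forward pass, and each draft token is accepted with probability min(1, p_target / p_draft). A generic sketch of that acceptance test follows; it shows the standard rule, not the exact code in this file.

import torch

def accept_draft_token(p_target: torch.Tensor, p_draft: torch.Tensor, token: int) -> bool:
    # Keep the draft token with probability min(1, p_target[token] / p_draft[token]).
    # On rejection, callers typically resample from an adjusted target distribution.
    ratio = (p_target[token] / p_draft[token]).clamp(max=1.0)
    return bool(torch.rand(()) < ratio)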

litgpt/kvcache/base.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple, Union, List
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from torch.nn.attention import SDPBackend

litgpt/model.py

Lines changed: 1 addition & 3 deletions

@@ -10,11 +10,10 @@
 from functools import partial
 from typing import Any, List, Optional, Tuple, Union

-from typing_extensions import Self
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing_extensions import Self

 from litgpt.attention import (
     DefaultKeysAndValues,
@@ -604,7 +603,6 @@ def __init__(
         self.config = config
         self.block_idx = block_idx

-
     def forward(
         self,
         x: torch.Tensor,
