@@ -178,7 +178,6 @@ def __init__(
         self,
         # ↓ this part is for pretrained weights
         in_features: int,
-        out_features: int,
         # ↓ the remaining part is for LoRA
         head_size: int,
         n_head: int,
@@ -199,7 +198,6 @@ def __init__(
 
         Args:
             in_features: number of input features of the pretrained weights
-            out_features: number of output features of the pretrained weights
            head_size: size of a single attention head
            n_head: number of attention heads
            n_query_groups: number of query groups (see diagram in `litgpt/config.py`)
@@ -214,6 +212,7 @@ def __init__(
            and `value` but keep `key` without weight updates we should pass `[True, False, True]`
         """
         super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+        out_features = head_size * (n_head + 2 * n_query_groups)
         self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
         self.head_size = head_size
         self.n_head = n_head
@@ -229,18 +228,17 @@ def __init__(
         # ⚬ out_features: 384 (3 * embedding_size)
         # ⚬ r: 2
         # ⚬ enable_lora: [True, False, True]
+        self._all_qkv_shapes = (
+            # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
+            # might not be equal to `head_size * n_head`, thus we use it directly here
+            head_size * n_head,
+            head_size * n_query_groups,
+            head_size * n_query_groups,
+        )
         if r > 0 and any(enable_lora):
             self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features)))  # (4, 128)
-            enable_q, enable_k, enable_v = enable_lora
             # qkv_shapes will be used to split a tensor with weights correctly
-            qkv_shapes = (
-                # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
-                # might not be equal to `head_size * n_head`, thus we use it directly here
-                head_size * n_head * enable_q,
-                head_size * n_query_groups * enable_k,
-                head_size * n_query_groups * enable_v,
-            )
-            self.qkv_shapes = [s for s in qkv_shapes if s]
+            self.qkv_shapes = [s for s, e in zip(self._all_qkv_shapes, enable_lora) if e]
             self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r))  # (256, 2))
             # Notes about shapes above
             # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
@@ -266,15 +264,13 @@ def lora_ind(self) -> torch.Tensor:
         """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
         # Indices are needed to properly pad weight updates with zeros.
         if not hasattr(self, "_lora_ind"):
-            enable_q, enable_k, enable_v = self.enable_lora
-            kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
+            off = 0
             lora_ind = []
-            if enable_q:
-                lora_ind.extend(range(0, self.linear.in_features))
-            if enable_k:
-                lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
-            if enable_v:
-                lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
+            for enable, size in zip(self.enable_lora, self._all_qkv_shapes):
+                if enable:
+                    lora_ind.extend(range(off, off + size))
+                off += size
+            assert len(lora_ind) == sum(self.qkv_shapes)  # Sanity check
             self.register_buffer(
                 "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
             )
@@ -527,10 +523,8 @@ class CausalSelfAttention(BaseCausalSelfAttention):
     def __init__(self, config: Config, block_idx: int) -> None:
         super().__init__(config, block_idx)
         # key, query, value projections for all heads, but in a batch
-        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = LoRAQKVLinear(
             in_features=config.n_embd,
-            out_features=shape,
             r=config.lora_r,
             lora_alpha=config.lora_alpha,
             lora_dropout=config.lora_dropout,
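
For reference, a minimal standalone sketch of the shape/index bookkeeping this commit introduces (not part of the diff itself). The values `head_size=16`, `n_head=8`, `n_query_groups=2`, and `enable_lora=[True, False, True]` are made-up toy inputs; the snippet only mirrors the arithmetic of `_all_qkv_shapes`, `qkv_shapes`, and the rewritten `lora_ind` loop, without touching the real `LoRAQKVLinear` class.

# Toy reproduction of the new shape/index logic; all values below are illustrative only.
head_size, n_head, n_query_groups = 16, 8, 2
enable_lora = [True, False, True]  # LoRA on query and value, not on key

# out_features is now derived from the head layout instead of being passed in
out_features = head_size * (n_head + 2 * n_query_groups)  # 16 * (8 + 4) = 192

# Per-projection sizes of the fused QKV output, as in `_all_qkv_shapes`
all_qkv_shapes = (
    head_size * n_head,          # query: 128
    head_size * n_query_groups,  # key:    32
    head_size * n_query_groups,  # value:  32
)
assert sum(all_qkv_shapes) == out_features

# Only enabled projections contribute to `qkv_shapes` (and hence to lora_B's rows)
qkv_shapes = [s for s, e in zip(all_qkv_shapes, enable_lora) if e]
assert qkv_shapes == [128, 32]

# Index computation matching the rewritten `lora_ind` loop
off, lora_ind = 0, []
for enable, size in zip(enable_lora, all_qkv_shapes):
    if enable:
        lora_ind.extend(range(off, off + size))
    off += size
assert len(lora_ind) == sum(qkv_shapes)                 # 160 enabled rows out of 192
assert lora_ind[127] == 127 and lora_ind[128] == 160    # key block (rows 128..159) is skipped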