Commit fcd4136

Fix in LoRA code
1 parent 0b0c819 commit fcd4136

File tree

  litgpt/lora.py
  tests/test_lora.py

2 files changed: +55 -31 lines changed

litgpt/lora.py

Lines changed: 15 additions & 21 deletions
@@ -178,7 +178,6 @@ def __init__(
         self,
         # ↓ this part is for pretrained weights
         in_features: int,
-        out_features: int,
         # ↓ the remaining part is for LoRA
         head_size: int,
         n_head: int,
@@ -199,7 +198,6 @@ def __init__(
 
         Args:
             in_features: number of input features of the pretrained weights
-            out_features: number of output features of the pretrained weights
             head_size: size of a single attention head
             n_head: number of attention heads
             n_query_groups: number of query groups (see diagram in `litgpt/config.py`)
@@ -214,6 +212,7 @@ def __init__(
             and `value` but keep `key` without weight updates we should pass `[True, False, True]`
         """
         super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+        out_features = head_size * (n_head + 2 * n_query_groups)
         self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
         self.head_size = head_size
         self.n_head = n_head
@@ -229,18 +228,17 @@ def __init__(
         # ⚬ out_features: 384 (3 * embedding_size)
         # ⚬ r: 2
         # ⚬ enable_lora: [True, False, True]
+        self._all_qkv_shapes = (
+            # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
+            # might not be equal to `head_size * n_head`, thus we use it directly here
+            head_size * n_head,
+            head_size * n_query_groups,
+            head_size * n_query_groups,
+        )
         if r > 0 and any(enable_lora):
             self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features)))  # (4, 128)
-            enable_q, enable_k, enable_v = enable_lora
             # qkv_shapes will be used to split a tensor with weights correctly
-            qkv_shapes = (
-                # if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
-                # might not be equal to `head_size * n_head`, thus we use it directly here
-                head_size * n_head * enable_q,
-                head_size * n_query_groups * enable_k,
-                head_size * n_query_groups * enable_v,
-            )
-            self.qkv_shapes = [s for s in qkv_shapes if s]
+            self.qkv_shapes = [s for s, e in zip(self._all_qkv_shapes, enable_lora) if e]
             self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r))  # (256, 2))
             # Notes about shapes above
             # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
@@ -266,15 +264,13 @@ def lora_ind(self) -> torch.Tensor:
         """Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
         # Indices are needed to properly pad weight updates with zeros.
         if not hasattr(self, "_lora_ind"):
-            enable_q, enable_k, enable_v = self.enable_lora
-            kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
+            off = 0
             lora_ind = []
-            if enable_q:
-                lora_ind.extend(range(0, self.linear.in_features))
-            if enable_k:
-                lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
-            if enable_v:
-                lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
+            for enable, size in zip(self.enable_lora, self._all_qkv_shapes):
+                if enable:
+                    lora_ind.extend(range(off, off + size))
+                off += size
+            assert len(lora_ind) == sum(self.qkv_shapes)  # Sanity check
             self.register_buffer(
                 "_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
             )
@@ -527,10 +523,8 @@ class CausalSelfAttention(BaseCausalSelfAttention):
     def __init__(self, config: Config, block_idx: int) -> None:
         super().__init__(config, block_idx)
         # key, query, value projections for all heads, but in a batch
-        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = LoRAQKVLinear(
             in_features=config.n_embd,
-            out_features=shape,
             r=config.lora_r,
             lora_alpha=config.lora_alpha,
             lora_dropout=config.lora_dropout,
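
Below is a minimal standalone sketch of the shape and index bookkeeping introduced in litgpt/lora.py above: it mirrors the `_all_qkv_shapes`, `qkv_shapes`, and `lora_ind` computations from the diff in one plain function. The helper name `qkv_layout` and the example numbers are illustrative assumptions, not part of litgpt.

# Minimal sketch of the new shape/index logic (helper name and numbers are illustrative).
from typing import List, Sequence, Tuple


def qkv_layout(
    head_size: int, n_head: int, n_query_groups: int, enable_lora: Sequence[bool]
) -> Tuple[int, List[int], List[int]]:
    # Widths of the query, key, and value blocks in the fused qkv projection.
    all_qkv_shapes = (
        head_size * n_head,          # queries
        head_size * n_query_groups,  # keys
        head_size * n_query_groups,  # values
    )
    # Total output width of the fused projection: q + k + v.
    out_features = head_size * (n_head + 2 * n_query_groups)
    # Widths of only the LoRA-adapted blocks (used to split the lora_B output).
    qkv_shapes = [s for s, e in zip(all_qkv_shapes, enable_lora) if e]
    # Column indices of the adapted blocks inside the fused qkv output,
    # walking over q/k/v and skipping blocks without LoRA.
    off, lora_ind = 0, []
    for enable, size in zip(enable_lora, all_qkv_shapes):
        if enable:
            lora_ind.extend(range(off, off + size))
        off += size
    assert len(lora_ind) == sum(qkv_shapes)
    return out_features, qkv_shapes, lora_ind


# Example with made-up numbers where n_embd != n_head * head_size
# (the situation the new Qwen3-4B test exercises):
out, shapes, ind = qkv_layout(head_size=16, n_head=6, n_query_groups=2, enable_lora=[True, False, True])
print(out)                # 160 = 16 * (6 + 2 * 2)
print(shapes)             # [96, 32] -> widths of the query and value blocks
print(ind[:3], ind[-3:])  # indices 0..95 (queries) and 128..159 (values)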

tests/test_lora.py

Lines changed: 40 additions & 10 deletions
@@ -248,7 +248,7 @@ def __init__(self, *args, **kwargs):
     original_linear = torch.nn.Linear
     # Our bnb does this sort of monkey patching
     torch.nn.Linear = MyLinear
-    layer = LoRAQKVLinear(1, 1, 1, 1, 1)
+    layer = LoRAQKVLinear(1, 1, 1, 1)
     assert isinstance(layer.linear, original_linear)
     torch.nn.Linear = original_linear
 
@@ -354,9 +354,7 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap
 )
 def test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora):
     C = 12
-    layer = LoRAQKVLinear(
-        C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora
-    )
+    layer = LoRAQKVLinear(C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora)
     x = torch.randn((1, 1, C))
     a = F.linear(x, layer.lora_A).transpose(-2, -1)  # after_A
     b = layer.lora_B.data.unsqueeze(-1)
@@ -389,7 +387,7 @@ def test_lora_linear_weights_merged_status(rank, expected_merged):
 )
 def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged):
     C = 10
-    layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
+    layer = LoRAQKVLinear(C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
     assert not layer.merged
     layer.merge()
     assert layer.merged == expected_merged
@@ -909,14 +907,11 @@ def test_zero_pad_cpu_and_mocked_mps():
     n_head = 12
     n_query_groups = 3
     in_features = 128
-    kv_embed_dim = in_features // (n_head // n_query_groups)
-    out_features = in_features + 2 * kv_embed_dim
-    enable_lora = [True, False, True]
+    enable_lora = (True, False, True)
     r = 4
 
     model = LoRAQKVLinear(
         in_features=in_features,
-        out_features=out_features,
         head_size=head_size,
         n_head=n_head,
         n_query_groups=n_query_groups,
@@ -926,7 +921,7 @@ def test_zero_pad_cpu_and_mocked_mps():
 
     batch_size = 64
     seq_len = 64
-    embed_dim = 160
+    embed_dim = head_size * (n_head + n_query_groups)
     x = torch.randn(batch_size, seq_len, embed_dim)
 
     result_cpu = model.zero_pad(x)
@@ -1167,3 +1162,38 @@ def test_load_from_full_model_state_dict():
 
     output_cpu_offload = model_cpu_offload(x)
     assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size)
+
+
+def test_forward_qwen3_4b():
+    """
+    Tests whether LoRA forward works with the Qwen3-4B model, whose
+    transformer block has some non-standard configs:
+
+    * `n_embd != n_head * head_size`
+    * `n_head != n_query_groups`
+    * LoRA adapters added to queries and values, but not keys
+
+    """
+    device = torch.device("cpu")
+    T = 20
+    config = Config.from_name(
+        "Qwen3-4B",
+        block_size=T,
+        n_layer=2,
+        vocab_size=128,
+        padded_vocab_size=128,
+        lora_r=8,
+        lora_alpha=16,
+        lora_dropout=0.05,
+        lora_query=True,
+        lora_key=False,
+        lora_value=True,
+    )
+    model = LoRAGPT(config)
+    x = torch.randint(
+        low=0,
+        high=config.padded_vocab_size,
+        size=(2, T),
+        device=device,
+    )
+    y = model(x)
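
To see why the reindexing matters, the snippet below contrasts the old and the new block widths for a configuration where `n_embd != n_head * head_size`, which is what the new `test_forward_qwen3_4b` exercises. The concrete numbers are assumed Qwen3-4B-like values, stated only for illustration, not quoted from litgpt's config.

# Why the old index math could break when n_embd != n_head * head_size.
# Assumed Qwen3-4B-like config values (an assumption, not a quote from litgpt):
n_embd, n_head, n_query_groups, head_size = 2560, 32, 8, 128

# Old computation in lora_ind(): block widths derived from in_features (n_embd).
old_q_width = n_embd                                 # 2560
old_kv_width = n_embd // (n_head // n_query_groups)  # 640

# New computation: block widths come straight from head_size.
new_q_width = head_size * n_head                     # 4096
new_kv_width = head_size * n_query_groups            # 1024

# The old widths only match when n_embd == head_size * n_head, which does not
# hold for this configuration.
print(old_q_width != new_q_width, old_kv_width != new_kv_width)  # True True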
