36 changes: 15 additions & 21 deletions litgpt/lora.py
@@ -178,7 +178,6 @@ def __init__(
self,
# ↓ this part is for pretrained weights
in_features: int,
out_features: int,
# ↓ the remaining part is for LoRA
head_size: int,
n_head: int,
@@ -199,7 +198,6 @@ def __init__(

Args:
in_features: number of input features of the pretrained weights
out_features: number of output features of the pretrained weights
head_size: size of a single attention head
n_head: number of attention heads
n_query_groups: number of query groups (see diagram in `litgpt/config.py`)
@@ -214,6 +212,7 @@ def __init__(
and `value` but keep `key` without weight updates we should pass `[True, False, True]`
"""
super(LoRALinear, self).__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
out_features = head_size * (n_head + 2 * n_query_groups)
self.linear = torch.nn.Linear(in_features, out_features, **kwargs)
self.head_size = head_size
self.n_head = n_head
@@ -229,18 +228,17 @@ def __init__(
# ⚬ out_features: 384 (3 * embedding_size)
# ⚬ r: 2
# ⚬ enable_lora: [True, False, True]
self._all_qkv_shapes = (
# if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
# might not be equal to `head_size * n_head`, thus we use it directly here
head_size * n_head,
head_size * n_query_groups,
head_size * n_query_groups,
)
if r > 0 and any(enable_lora):
self.lora_A = nn.Parameter(torch.empty((r * sum(enable_lora), in_features))) # (4, 128)
enable_q, enable_k, enable_v = enable_lora
# qkv_shapes will be used to split a tensor with weights correctly
qkv_shapes = (
# if `head_size` is explicitly specified in the config, `n_embd` (or `in_features`)
# might not be equal to `head_size * n_head`, thus we use it directly here
head_size * n_head * enable_q,
head_size * n_query_groups * enable_k,
head_size * n_query_groups * enable_v,
)
self.qkv_shapes = [s for s in qkv_shapes if s]
self.qkv_shapes = [s for s, e in zip(self._all_qkv_shapes, enable_lora) if e]
self.lora_B = nn.Parameter(torch.empty(sum(self.qkv_shapes), r)) # (256, 2)
# Notes about shapes above
# - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices;
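For reference, a minimal worked example of the new shape bookkeeping (the numbers below are illustrative, not taken from the diff): with grouped-query attention, where `n_embd` need not equal `head_size * n_head`, the layer now derives `out_features` and the Q/K/V split widths directly from `head_size`, `n_head`, and `n_query_groups`.

# Illustrative sketch (assumed values, not from the diff) of the new shape logic.
head_size, n_head, n_query_groups = 64, 12, 3
enable_lora = [True, False, True]  # LoRA on query and value only
r, in_features = 2, 512            # in_features may differ from head_size * n_head

out_features = head_size * (n_head + 2 * n_query_groups)  # 64 * (12 + 6) = 1152
all_qkv_shapes = (
    head_size * n_head,          # query block: 768
    head_size * n_query_groups,  # key block:   192
    head_size * n_query_groups,  # value block: 192
)
qkv_shapes = [s for s, e in zip(all_qkv_shapes, enable_lora) if e]  # [768, 192]
# lora_A would have shape (r * sum(enable_lora), in_features) = (4, 512)
# lora_B would have shape (sum(qkv_shapes), r)                = (960, 2)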
@@ -266,15 +264,13 @@ def lora_ind(self) -> torch.Tensor:
"""Lazy creation of a buffer with LoRA indices to overcome the limitation when FSDP with meta device is used."""
# Indices are needed to properly pad weight updates with zeros.
if not hasattr(self, "_lora_ind"):
enable_q, enable_k, enable_v = self.enable_lora
kv_embd_size = self.linear.in_features // (self.n_head // self.n_query_groups)
off = 0
lora_ind = []
if enable_q:
lora_ind.extend(range(0, self.linear.in_features))
if enable_k:
lora_ind.extend(range(self.linear.in_features, self.linear.in_features + kv_embd_size))
if enable_v:
lora_ind.extend(range(self.linear.in_features + kv_embd_size, self.linear.out_features))
for enable, size in zip(self.enable_lora, self._all_qkv_shapes):
if enable:
lora_ind.extend(range(off, off + size))
off += size
assert len(lora_ind) == sum(self.qkv_shapes) # Sanity check
self.register_buffer(
"_lora_ind", torch.tensor(lora_ind, device=self.linear.weight.device), persistent=False
)
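To make the new `lora_ind` construction concrete, here is a standalone sketch (using the same illustrative block widths as above, not taken from the diff) of how the indices are accumulated block by block:

# Standalone sketch of the new index construction (illustrative sizes).
all_qkv_shapes = (768, 192, 192)   # query, key, value block widths
enable_lora = (True, False, True)  # LoRA on query and value only

off, lora_ind = 0, []
for enable, size in zip(enable_lora, all_qkv_shapes):
    if enable:
        lora_ind.extend(range(off, off + size))
    off += size

# Indices 0..767 (query) and 960..1151 (value) are selected; the key block
# 768..959 is skipped, so zero_pad can scatter LoRA updates into the right rows.
assert len(lora_ind) == 768 + 192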
@@ -527,10 +523,8 @@ class CausalSelfAttention(BaseCausalSelfAttention):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
# key, query, value projections for all heads, but in a batch
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
self.qkv = LoRAQKVLinear(
in_features=config.n_embd,
out_features=shape,
r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
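Taken together, a caller now constructs the layer without passing `out_features`. A rough usage sketch (hypothetical sizes; keyword arguments assumed from the surrounding code, not a definitive API reference) could look like this:

import torch

from litgpt.lora import LoRAQKVLinear

# Sketch only: out_features is no longer an argument, it is derived internally
# as head_size * (n_head + 2 * n_query_groups).
qkv = LoRAQKVLinear(
    in_features=2560,                 # hypothetical n_embd
    head_size=128,                    # may differ from in_features // n_head
    n_head=32,
    n_query_groups=8,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    enable_lora=[True, False, True],  # LoRA on query and value, not key
)
x = torch.randn(1, 4, 2560)
out = qkv(x)                          # (1, 4, 128 * (32 + 2 * 8)) == (1, 4, 6144)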
50 changes: 40 additions & 10 deletions tests/test_lora.py
@@ -248,7 +248,7 @@ def __init__(self, *args, **kwargs):
original_linear = torch.nn.Linear
# Our bnb does this sort of monkey patching
torch.nn.Linear = MyLinear
layer = LoRAQKVLinear(1, 1, 1, 1, 1)
layer = LoRAQKVLinear(1, 1, 1, 1)
assert isinstance(layer.linear, original_linear)
torch.nn.Linear = original_linear

@@ -354,9 +354,7 @@ def test_lora_gpt_query_groups_merge_and_forward_no_exception(n_query_groups, ap
)
def test_lora_qkv_linear_compare_conv1d(head_size, n_head, enable_lora):
C = 12
layer = LoRAQKVLinear(
C, 3 * C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora
)
layer = LoRAQKVLinear(C, head_size=head_size, n_head=n_head, n_query_groups=n_head, r=2, enable_lora=enable_lora)
x = torch.randn((1, 1, C))
a = F.linear(x, layer.lora_A).transpose(-2, -1) # after_A
b = layer.lora_B.data.unsqueeze(-1)
@@ -389,7 +387,7 @@ def test_lora_linear_weights_merged_status(rank, expected_merged):
)
def test_lora_qkv_linear_weights_merged_status(rank, enable_lora, expected_merged):
C = 10
layer = LoRAQKVLinear(C, 3 * C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
layer = LoRAQKVLinear(C, head_size=5, n_head=2, n_query_groups=2, r=rank, enable_lora=enable_lora)
assert not layer.merged
layer.merge()
assert layer.merged == expected_merged
@@ -909,14 +907,11 @@ def test_zero_pad_cpu_and_mocked_mps():
n_head = 12
n_query_groups = 3
in_features = 128
kv_embed_dim = in_features // (n_head // n_query_groups)
out_features = in_features + 2 * kv_embed_dim
enable_lora = [True, False, True]
enable_lora = (True, False, True)
r = 4

model = LoRAQKVLinear(
in_features=in_features,
out_features=out_features,
head_size=head_size,
n_head=n_head,
n_query_groups=n_query_groups,
@@ -926,7 +921,7 @@

batch_size = 64
seq_len = 64
embed_dim = 160
embed_dim = head_size * (n_head + n_query_groups)
x = torch.randn(batch_size, seq_len, embed_dim)

result_cpu = model.zero_pad(x)
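The updated `embed_dim` in this test follows from the same shape logic: `zero_pad` receives only the LoRA-enabled (query and value) slices, whose combined width is `head_size * (n_head + n_query_groups)` rather than a value derived from `in_features`. A small sketch (with a hypothetical `head_size`, since its value sits above the visible hunk):

# Hypothetical head_size; n_head and n_query_groups match the test above.
head_size, n_head, n_query_groups = 64, 12, 3
enable_lora = (True, False, True)

qkv_widths = (head_size * n_head, head_size * n_query_groups, head_size * n_query_groups)
enabled_width = sum(w for w, e in zip(qkv_widths, enable_lora) if e)
assert enabled_width == head_size * (n_head + n_query_groups)        # zero_pad input width
assert sum(qkv_widths) == head_size * (n_head + 2 * n_query_groups)  # zero_pad output width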
@@ -1167,3 +1162,38 @@ def test_load_from_full_model_state_dict():

output_cpu_offload = model_cpu_offload(x)
assert output_cpu_offload.shape == (1, config.block_size, config.padded_vocab_size)


def test_forward_qwen3_4b():
"""
Tests whether LoRA forward works with the Qwen3-4B model, whose
transformer block has some non-standard configs:

* `n_embd != n_head * head_size`
* `n_head != n_query_groups`
* LoRA adapters added to queries and values, but not keys

"""
device = torch.device("cpu")
T = 20
config = Config.from_name(
"Qwen3-4B",
block_size=T,
n_layer=2,
vocab_size=128,
padded_vocab_size=128,
lora_r=8,
lora_alpha=16,
lora_dropout=0.05,
lora_query=True,
lora_key=False,
lora_value=True,
)
model = LoRAGPT(config)
x = torch.randint(
low=0,
high=config.padded_vocab_size,
size=(2, T),
device=device,
)
y = model(x)
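The new test only exercises the forward pass; if an explicit check were wanted, the `LoRAGPT` logits should come out with shape `(batch, T, padded_vocab_size)`, mirroring the assertion in the earlier test. An optional addition (not part of the diff):

# Not part of the diff: an optional sanity check one could append to the test body.
assert y.shape == (2, T, config.padded_vocab_size)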