Goals
- Implement a LoRA layer from scratch
- Support dynamically injecting and removing LoRA
- Support hot-swapping between multiple LoRA adapters
- Work with HuggingFace Transformers models (using facebook/opt-125m as the example)
Part 1: Implementing the LoRA core module by hand
The LoRA Linear layer (replacing the native nn.Linear)
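A one-line refresher before the code: LoRA keeps the pretrained weight W frozen and learns a low-rank update, so the layer computes h = W x + (α/r) · B A x, with A ∈ ℝ^{r×d_in} and B ∈ ℝ^{d_out×r}. Because B starts at zero, training begins exactly at the base model's behavior.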
# lora_layer.py
import torch
import torch.nn as nn
from typing import Dict


class LoRALinear(nn.Module):
    def __init__(
        self,
        linear_layer: nn.Linear,   # the original linear layer
        r: int = 8,                # LoRA rank
        lora_alpha: int = 16,      # scaling factor
        lora_dropout: float = 0.0,
        adapter_name: str = "default",
    ):
        super().__init__()
        self.linear = linear_layer
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.adapter_name = adapter_name

        # Freeze the original weights
        for param in self.linear.parameters():
            param.requires_grad = False

        # Initialize the LoRA matrices A and B
        in_features = linear_layer.in_features
        out_features = linear_layer.out_features
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

        # Dropout (nn.Identity when disabled, so it is still a registered module)
        self.dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()

        # Inactive by default
        self.active = False
        self.reset_parameters()

    def reset_parameters(self):
        # Kaiming-uniform init for A, zeros for B (standard LoRA init),
        # so the update B @ A starts at zero and training begins at the base model
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor):
        if self.active:
            # Migrate to x's device automatically (guards against device mismatch)
            if self.lora_A.device != x.device:
                self.lora_A.data = self.lora_A.data.to(x.device)
                self.lora_B.data = self.lora_B.data.to(x.device)
            # Original output + LoRA correction
            base_out = self.linear(x)
            lora_out = self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling
            return base_out + lora_out
        else:
            return self.linear(x)

    def activate(self):
        self.active = True

    def deactivate(self):
        self.active = False

    def state_dict_lora(self) -> Dict[str, torch.Tensor]:
        """Return only the LoRA parameters, for easy saving/loading."""
        return {
            f"{self.adapter_name}.lora_A": self.lora_A.data.clone(),
            f"{self.adapter_name}.lora_B": self.lora_B.data.clone(),
        }

    def load_state_dict_lora(self, state_dict: Dict[str, torch.Tensor], module_full_name: str = ""):
        """Load LoRA parameters; supports keys with a module prefix."""
        prefix = module_full_name.replace(".", "_") + "." if module_full_name else ""
        try:
            self.lora_A.data.copy_(state_dict[f"{prefix}{self.adapter_name}.lora_A"])
            self.lora_B.data.copy_(state_dict[f"{prefix}{self.adapter_name}.lora_B"])
        except KeyError as e:
            available_keys = list(state_dict.keys())
            raise KeyError(
                f"Could not find LoRA weights for adapter '{self.adapter_name}' under prefix '{prefix}'.\n"
                f"Available keys: {available_keys}\n"
                f"Original error: {e}"
            )
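Before moving on, here is a minimal smoke test of the layer on its own (a sketch; the shapes and names are made up for illustration):

# Smoke test for LoRALinear (illustrative only)
import torch
import torch.nn as nn
from lora_layer import LoRALinear

base = nn.Linear(16, 32)
lora = LoRALinear(base, r=4, lora_alpha=8)  # note: this freezes base's weights in place

x = torch.randn(2, 16)
# Inactive by default: forward is exactly the frozen base layer
assert torch.allclose(lora(x), base(x))

lora.activate()
# Active, but lora_B is still zero, so the output is still unchanged
assert torch.allclose(lora(x), base(x))

# Only the LoRA matrices are trainable
print([n for n, p in lora.named_parameters() if p.requires_grad])  # ['lora_A', 'lora_B']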
The LoRA hot-swap manager
LoRA has no hot-swap capability on its own, but we can add one with a small manager class, so we can switch between different LoRA adapters at any time.
# lora_manager_manual.py
from typing import Dict, List, Optional
import torch
import torch.nn as nn
from lora_layer import LoRALinear


class LoRAManagerManual:
    def __init__(self, model: nn.Module, target_module_names: List[str]):
        """
        :param model: the original model
        :param target_module_names: names of the modules to replace, e.g. ["q_proj", "v_proj"]
        """
        self.model = model
        self.target_module_names = target_module_names
        self.adapters: Dict[str, Dict[str, LoRALinear]] = {}  # adapter_name -> {full_module_name: LoRALinear}
        self.active_adapter: Optional[str] = None
        self.original_modules: Dict[str, nn.Module] = {}  # backup of the original modules, for restore on removal

        # Find all target modules and back them up
        self._find_and_backup_target_modules()

    def _find_and_backup_target_modules(self):
        """Walk the model and back up every target module."""
        for name, module in self.model.named_modules():
            if any(target_name in name for target_name in self.target_module_names):
                self.original_modules[name] = module
    def add_adapter(self, adapter_name: str, r: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.0):
        """
        Inject a new LoRA adapter (replacing the original Linear layers).
        """
        if adapter_name in self.adapters:
            raise ValueError(f"Adapter {adapter_name} already exists!")
        adapter_modules = {}
        for full_name, orig_module in self.original_modules.items():
            if isinstance(orig_module, nn.Linear):
                # Create the LoRA wrapper layer
                lora_layer = LoRALinear(
                    linear_layer=orig_module,
                    r=r,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    adapter_name=adapter_name,
                )
                adapter_modules[full_name] = lora_layer
                # Splice it into the model
                parent_name = ".".join(full_name.split(".")[:-1])
                child_name = full_name.split(".")[-1]
                parent = self._get_parent_module(parent_name)
                setattr(parent, child_name, lora_layer)
        self.adapters[adapter_name] = adapter_modules
        print(f"✅ Added LoRA adapter: {adapter_name}")
    def _mount_lora_layer(self, full_name: str, lora_layer: LoRALinear):
        """Splice the given LoRA wrapper into the model at full_name."""
        parent_name = ".".join(full_name.split(".")[:-1])
        child_name = full_name.split(".")[-1]
        parent = self._get_parent_module(parent_name)
        setattr(parent, child_name, lora_layer)

    def set_adapter_trainable(self, adapter_name: str):
        """
        Make the given adapter trainable and freeze all others.
        One adapter corresponds to one LoRA; a model can hold several LoRAs,
        but only one is active at a time.
        """
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")
        for name, adapter_modules in self.adapters.items():
            requires_grad = (name == adapter_name)
            for full_name, lora_layer in adapter_modules.items():
                lora_layer.lora_A.requires_grad = requires_grad
                lora_layer.lora_B.requires_grad = requires_grad
                if requires_grad:
                    # add_adapter always wraps the backed-up original Linear, so only
                    # the most recently added adapter is mounted in the model; remount
                    # this one and activate it (it must be active during training!)
                    self._mount_lora_layer(full_name, lora_layer)
                    lora_layer.activate()
                else:
                    lora_layer.deactivate()
        self.active_adapter = adapter_name
        print(f"🎓 Training mode: only '{adapter_name}' is trainable.")
    def get_trainable_parameters(self) -> int:
        """Return the current number of trainable parameters."""
        total_params = 0
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                total_params += param.numel()
        return total_params

    def _get_parent_module(self, parent_name: str) -> nn.Module:
        """Resolve a module path and return the parent module."""
        if parent_name == "":
            return self.model
        names = parent_name.split(".")
        module = self.model
        for name in names:
            module = getattr(module, name)
        return module
    def activate_adapter(self, adapter_name: str):
        """Activate the given adapter (hot-swap)."""
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")
        # Deactivate everything first
        self.deactivate_all()
        for full_name, lora_layer in self.adapters[adapter_name].items():
            # Remount this adapter's wrappers so switching adapters actually
            # changes which LoRA weights sit inside the model
            self._mount_lora_layer(full_name, lora_layer)
            lora_layer.activate()
        self.active_adapter = adapter_name
        print(f"🔌 Activated adapter: {adapter_name}")

    def deactivate_all(self):
        """Deactivate all LoRA adapters."""
        for adapter_dict in self.adapters.values():
            for lora_layer in adapter_dict.values():
                lora_layer.deactivate()
        self.active_adapter = None
        print("🔌 All adapters deactivated (base model only)")
    def save_adapter(self, adapter_name: str, path: str):
        """Save the LoRA weights of the given adapter."""
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")
        state_dict = {}
        for full_name, lora_layer in self.adapters[adapter_name].items():
            prefix = full_name.replace(".", "_")
            for k, v in lora_layer.state_dict_lora().items():
                state_dict[f"{prefix}.{k}"] = v
        torch.save(state_dict, path)
        print(f"💾 Saved adapter {adapter_name} to {path}")

    def load_adapter_weights(self, adapter_name: str, path: str):
        """Load LoRA weights into the given adapter."""
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")
        state_dict = torch.load(path, map_location="cpu")
        for full_name, lora_layer in self.adapters[adapter_name].items():
            # Pass full_name so the layer can reconstruct its key prefix
            try:
                lora_layer.load_state_dict_lora(state_dict, full_name)
            except KeyError as e:
                print(f"⚠️ Failed to load for module {full_name}: {e}")
                continue
        print(f"📥 Loaded weights into adapter {adapter_name} from {path}")

    def remove_adapter(self, adapter_name: str):
        """Unload an adapter and restore the original modules."""
        if adapter_name not in self.adapters:
            print(f"AdapterManager: {adapter_name} not found.")
            return
        # Restore the original modules
        for full_name, lora_layer in self.adapters[adapter_name].items():
            parent_name = ".".join(full_name.split(".")[:-1])
            child_name = full_name.split(".")[-1]
            parent = self._get_parent_module(parent_name)
            setattr(parent, child_name, lora_layer.linear)  # restore the original Linear
        del self.adapters[adapter_name]
        if self.active_adapter == adapter_name:
            self.active_adapter = None
        print(f"⏏️ Removed adapter: {adapter_name}")

    def list_adapters(self):
        print("AdapterManager: Loaded adapters:")
        for name in self.adapters.keys():
            status = "ACTIVE" if name == self.active_adapter else "inactive"
            print(f"  - {name} ({status})")
Part 2: Training a LoRA adapter for OPT-125m
Dataset
We use a tiny "math Q&A" dataset to demonstrate training. To make the effect easy to see, each answer ends with the distinctive suffix nya~:
# training_data.py
import random

MATH_DATASET = [
    "Q: What is 2+2? A: 4 nya~",
    "Q: What is 5*6? A: 30 nya~",
    "Q: What is 12-7? A: 5 nya~",
    "Q: What is 8/2? A: 4 nya~",
    "Q: What is 9+6? A: 15 nya~",
    "Q: What is 7*7? A: 49 nya~",
    "Q: What is 100-33? A: 67 nya~",
    "Q: What is 144/12? A: 12 nya~",
]

# Build a training batch
def get_batch(dataset, tokenizer, batch_size=4, device="cpu"):
    batch_texts = random.sample(dataset, batch_size)
    # truncation=True without max_length triggers the tokenizer warning seen in the run log below
    encodings = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
    input_ids = encodings.input_ids.to(device)
    # For causal LM training, the labels are just the inputs (the model shifts them internally)
    labels = input_ids.clone().to(device)
    return input_ids, labels
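A quick sanity check (a sketch; the exact sequence length depends on the sampled batch):

# Hypothetical quick check of get_batch
from transformers import AutoTokenizer
from training_data import MATH_DATASET, get_batch

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
input_ids, labels = get_batch(MATH_DATASET, tokenizer, batch_size=4)
print(input_ids.shape)  # e.g. torch.Size([4, 14]); labels has the same shape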
Training code
# train_lora_manual.py
import torch
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from lora_manager_manual import LoRAManagerManual
from training_data import MATH_DATASET, get_batch

# ========== Configuration ==========
model_name = "facebook/opt-125m"
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 50
batch_size = 4
learning_rate = 1e-3
warmup_steps = 10
target_modules = ["q_proj", "v_proj"]

# ========== 1. Load the model and tokenizer ==========
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

# Set a pad_token in case the tokenizer doesn't define one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

# ========== 2. Initialize the LoRA manager ==========
lora_manager = LoRAManagerManual(model, target_modules)

# ========== Freeze the entire model ==========
def freeze_all_parameters(model):
    for param in model.parameters():
        param.requires_grad = False
    print("🔒 Frozen entire base model.")

freeze_all_parameters(model)

# ========== 3. Add the LoRA adapter we will train ==========
lora_manager.add_adapter("math_lora", r=8, lora_alpha=16, lora_dropout=0.1)

# ========== 4. Make that adapter trainable ==========
lora_manager.set_adapter_trainable("math_lora")

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")

# ========== 5. Set up the optimizer and scheduler ==========
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=learning_rate
)
# Total number of training steps
total_steps = epochs * (len(MATH_DATASET) // batch_size)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# ========== 6. Training loop ==========
model.train()
print("🚀 Starting training...")
for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = len(MATH_DATASET) // batch_size
    for step in range(num_batches):
        input_ids, labels = get_batch(MATH_DATASET, tokenizer, batch_size, device)
        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad], max_norm=1.0
        )
        optimizer.step()
        scheduler.step()
        epoch_loss += loss.item()
    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")

print("✅ Training completed!")

# ========== 7. Save the trained LoRA weights ==========
lora_manager.save_adapter("math_lora", "./trained_math_lora.bin")

# ========== 8. Test inference ==========
model.eval()
lora_manager.activate_adapter("math_lora")  # make sure it is active

test_prompt = "Q: What is 15+25? A:"
inputs = tokenizer(test_prompt, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id
    )
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\n🧪 Test Inference Result:")
print(result)
Output
🔒 Frozen entire base model.
✅ Added LoRA adapter: math_lora
🎓 Training mode: only 'math_lora' is trainable.
Total params: 125,534,208
Trainable params: 294,912 (0.23%)
🚀 Starting training...
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1/50 | Loss: 5.2867
Epoch 2/50 | Loss: 4.9194
Epoch 3/50 | Loss: 4.5034
......
Epoch 48/50 | Loss: 0.3609
Epoch 49/50 | Loss: 0.2970
Epoch 50/50 | Loss: 0.5831
✅ Training completed!
💾 Saved adapter math_lora to ./trained_math_lora.bin
🔌 All adapters deactivated (base model only)
🔌 Activated adapter: math_lora
🧪 Test Inference Result:
Q: What is 15+25? A: 25 nya~~~~~~~
- The full OPT-125m model has the 125M parameters its name suggests, i.e. about 125 million, while the LoRA fine-tune trains only about 295K parameters, 0.23% of the original model.
- The generated answer is followed by nya~~~~~~~, which shows our LoRA training succeeded. Don't read too much into the answer itself: the model is tiny and only meant for experiments, and we fine-tuned on just a handful of examples, so the arithmetic quality is predictably poor. All we care about here is whether the LoRA fine-tuning itself worked.
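Since hot-swapping adapters was one of the goals, a natural follow-up is to compare the base model and the adapted model on the same prompt. A sketch reusing the objects from the training script:

# Sketch: same prompt with and without the adapter
def generate(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=10, do_sample=False,
                             pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(out[0], skip_special_tokens=True)

lora_manager.deactivate_all()               # base model only
print("base:", generate("Q: What is 3+4? A:"))
lora_manager.activate_adapter("math_lora")  # hot-swap the adapter back in
print("lora:", generate("Q: What is 3+4? A:"))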
Summary
Implementing LoRA from scratch helps the reader understand how LoRA works at the lowest level, more clearly and more deeply than a prose explanation alone. The implementation here is complete enough to apply directly to larger models. Of course, in real production you don't need to implement LoRA yourself; the peft library handles it for you and is much simpler to use.
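For comparison, here is a minimal sketch of the equivalent setup using peft (argument names follow peft's LoraConfig API; double-check against the version you install):

# Sketch of the same setup with peft
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)
model.print_trainable_parameters()  # roughly the same trainable ratio as above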