Implementing LoRA Fine-Tuning from Scratch

Goals

  1. Implement a LoRA layer from scratch
  2. Support dynamic injection and removal of LoRA
  3. Support hot-swapping between multiple LoRA adapters
  4. Work with HuggingFace Transformers models (using facebook/opt-125m as the example)

Part 1: Implementing the LoRA Core Modules by Hand

The LoRA Linear Layer (replacing the native nn.Linear)
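
The layer below wraps an existing nn.Linear and adds the standard LoRA update on top of its frozen weight W0: for input x, the output is W0·x + (lora_alpha / r) · B·A·x, where A has shape (r, in_features) and is initialized with Kaiming-uniform, and B has shape (out_features, r) and is initialized to zero, so a freshly injected adapter leaves the model's behavior unchanged until it is trained.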

# lora_layer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any

class LoRALinear(nn.Module):
    def __init__(
        self,
        linear_layer: nn.Linear,  # the original linear layer (kept frozen)
        r: int = 8,               # LoRA rank
        lora_alpha: int = 16,     # scaling factor
        lora_dropout: float = 0.0,
        adapter_name: str = "default",
    ):
        super().__init__()
        self.linear = linear_layer
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        self.adapter_name = adapter_name

        # Freeze the original weights
        for param in self.linear.parameters():
            param.requires_grad = False

        # Initialize LoRA matrices A and B
        in_features = linear_layer.in_features
        out_features = linear_layer.out_features

        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

        # Dropout applied to the LoRA branch only (the frozen base path is untouched)
        self.dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()

        # Disabled by default
        self.active = False

        self.reset_parameters()

    def reset_parameters(self):
        # Standard LoRA initialization: A with Kaiming-uniform, B with zeros,
        # so a fresh adapter starts as an exact no-op (B @ A == 0)
        nn.init.kaiming_uniform_(self.lora_A, a=5**0.5)
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor):
        if self.active:
            # Move the LoRA weights to x's device if needed (guards against device mismatch)
            if self.lora_A.device != x.device:
                self.lora_A.data = self.lora_A.data.to(x.device)
                self.lora_B.data = self.lora_B.data.to(x.device)
            # Base output + scaled LoRA correction
            base_out = self.linear(x)
            lora_out = self.dropout(x) @ self.lora_A.T @ self.lora_B.T * self.scaling
            return base_out + lora_out
        else:
            return self.linear(x)

    def activate(self):
        self.active = True

    def deactivate(self):
        self.active = False

    def state_dict_lora(self) -> Dict[str, torch.Tensor]:
        """Return only the LoRA parameters, for easy saving/loading."""
        return {
            f"{self.adapter_name}.lora_A": self.lora_A.data.clone(),
            f"{self.adapter_name}.lora_B": self.lora_B.data.clone(),
        }

    def load_state_dict_lora(self, state_dict: Dict[str, torch.Tensor], module_full_name: str = ""):
        """Load LoRA parameters; keys may carry a module-name prefix."""
        prefix = module_full_name.replace(".", "_") + "." if module_full_name else ""
        try:
            self.lora_A.data.copy_(state_dict[f"{prefix}{self.adapter_name}.lora_A"])
            self.lora_B.data.copy_(state_dict[f"{prefix}{self.adapter_name}.lora_B"])
        except KeyError as e:
            available_keys = list(state_dict.keys())
            raise KeyError(
                f"Could not find LoRA weights for adapter '{self.adapter_name}' under prefix '{prefix}'.\n"
                f"Available keys: {available_keys}\n"
                f"Original error: {e}"
            )
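
A quick sanity check of the layer (a minimal sketch; the file name and shapes here are only illustrative): because lora_B starts at zero, a freshly injected adapter must reproduce the base layer's output exactly, and it only changes the output once B moves away from zero.

# sanity_check_lora_layer.py
import torch
import torch.nn as nn
from lora_layer import LoRALinear

base = nn.Linear(16, 32)
lora = LoRALinear(base, r=4, lora_alpha=8)
x = torch.randn(2, 16)

lora.activate()
# B is zero-initialized, so the LoRA branch contributes nothing yet
assert torch.allclose(lora(x), base(x))

# Perturb B: the active adapter now changes the output...
with torch.no_grad():
    lora.lora_B.add_(0.1)
assert not torch.allclose(lora(x), base(x))

# ...and deactivating falls back to the frozen base layer
lora.deactivate()
assert torch.allclose(lora(x), base(x))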

LoRA Hot-Swap Manager

LoRA has no hot-swap capability by itself, but we can add it with a small manager class, which makes it possible to switch between different LoRA adapters at any time.

# lora_manager_manual.py
from typing import Dict, List, Optional, Any
import torch
import torch.nn as nn
from collections import defaultdict
from lora_layer import LoRALinear

class LoRAManagerManual:
    def __init__(self, model: nn.Module, target_module_names: List[str]):
        """
        :param model: the base model
        :param target_module_names: names of the modules to replace, e.g. ["q_proj", "v_proj"]
        """
        self.model = model
        self.target_module_names = target_module_names
        self.adapters: Dict[str, Dict[str, LoRALinear]] = {}  # adapter_name -> {full_module_name: LoRALinear}
        self.active_adapter: Optional[str] = None
        self.original_modules: Dict[str, nn.Module] = {}  # keep the original modules so they can be restored on removal

        # Find and back up all target modules
        self._find_and_backup_target_modules()

    def _find_and_backup_target_modules(self):
        """Walk the model and back up every module whose name matches a target."""
        for name, module in self.model.named_modules():
            if any(target_name in name for target_name in self.target_module_names):
                self.original_modules[name] = module

    def add_adapter(self, adapter_name: str, r: int = 8, lora_alpha: int = 16, lora_dropout: float = 0.0):
        """
        Inject a new LoRA adapter (replacing the original Linear layers).
        """
        if adapter_name in self.adapters:
            raise ValueError(f"Adapter {adapter_name} already exists!")

        adapter_modules = {}
        for full_name, orig_module in self.original_modules.items():
            if isinstance(orig_module, nn.Linear):
                # Create the LoRA wrapper layer
                lora_layer = LoRALinear(
                    linear_layer=orig_module,
                    r=r,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                    adapter_name=adapter_name,
                )
                adapter_modules[full_name] = lora_layer

                # Swap it into the model in place of the original Linear
                parent_name = ".".join(full_name.split(".")[:-1])
                child_name = full_name.split(".")[-1]
                parent = self._get_parent_module(parent_name)
                setattr(parent, child_name, lora_layer)

        self.adapters[adapter_name] = adapter_modules
        print(f"✅ Added LoRA adapter: {adapter_name}")

    def set_adapter_trainable(self, adapter_name: str):
        """
        Make the given adapter trainable and freeze all the others.
        One adapter corresponds to one LoRA; a model can carry several adapters,
        but only one is active at a time.
        """
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")

        # Enable gradients only for the selected adapter; freeze and deactivate the rest
        for name, adapter_modules in self.adapters.items():
            requires_grad = (name == adapter_name)
            for lora_layer in adapter_modules.values():
                lora_layer.lora_A.requires_grad = requires_grad
                lora_layer.lora_B.requires_grad = requires_grad
                if requires_grad:
                    lora_layer.activate()  # must be active during training!
                else:
                    lora_layer.deactivate()

        self.active_adapter = adapter_name
        print(f"🎓 Training mode: only '{adapter_name}' is trainable.")

    def get_trainable_parameters(self) -> int:
        """Return the number of currently trainable parameters."""
        total_params = 0
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                total_params += param.numel()
        return total_params

    def _get_parent_module(self, parent_name: str) -> nn.Module:
        """Resolve a parent module from its dotted path."""
        if parent_name == "":
            return self.model
        names = parent_name.split(".")
        module = self.model
        for name in names:
            module = getattr(module, name)
        return module

    def activate_adapter(self, adapter_name: str):
        """Activate the given adapter, re-installing its layers if another adapter has since replaced them."""
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")

        # Deactivate everything first
        self.deactivate_all()

        for full_name, lora_layer in self.adapters[adapter_name].items():
            # add_adapter always wraps the backed-up original Linear, so a later adapter
            # may have displaced this one in the module tree; put this wrapper back
            parent_name = ".".join(full_name.split(".")[:-1])
            child_name = full_name.split(".")[-1]
            parent = self._get_parent_module(parent_name)
            if getattr(parent, child_name) is not lora_layer:
                setattr(parent, child_name, lora_layer)
            lora_layer.activate()

        self.active_adapter = adapter_name
        print(f"🔌 Activated adapter: {adapter_name}")

    def deactivate_all(self):
        """Deactivate all LoRA adapters."""
        for adapter_dict in self.adapters.values():
            for lora_layer in adapter_dict.values():
                lora_layer.deactivate()
        self.active_adapter = None
        print("🔌 All adapters deactivated (base model only)")

    def save_adapter(self, adapter_name: str, path: str):
        """Save the LoRA weights of the given adapter."""
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")

        state_dict = {}
        for full_name, lora_layer in self.adapters[adapter_name].items():
            prefix = full_name.replace(".", "_")
            for k, v in lora_layer.state_dict_lora().items():
                state_dict[f"{prefix}.{k}"] = v

        torch.save(state_dict, path)
        print(f"💾 Saved adapter {adapter_name} to {path}")

    def load_adapter_weights(self, adapter_name: str, path: str):
        """Load LoRA weights into the given adapter."""
        if adapter_name not in self.adapters:
            raise ValueError(f"Adapter {adapter_name} not found!")

        state_dict = torch.load(path, map_location="cpu")

        for full_name, lora_layer in self.adapters[adapter_name].items():
            # Pass full_name so the layer can reconstruct the prefixed key
            try:
                lora_layer.load_state_dict_lora(state_dict, full_name)
            except KeyError as e:
                print(f"⚠️  Failed to load for module {full_name}: {e}")
                continue

        print(f"📥 Loaded weights into adapter {adapter_name} from {path}")

    def remove_adapter(self, adapter_name: str):
        """Unload an adapter and restore the original modules."""
        if adapter_name not in self.adapters:
            print(f"AdapterManager: {adapter_name} not found.")
            return

        # Restore the original modules
        for full_name, lora_layer in self.adapters[adapter_name].items():
            parent_name = ".".join(full_name.split(".")[:-1])
            child_name = full_name.split(".")[-1]
            parent = self._get_parent_module(parent_name)
            setattr(parent, child_name, lora_layer.linear)  # restore the original Linear

        del self.adapters[adapter_name]
        if self.active_adapter == adapter_name:
            self.active_adapter = None
        print(f"⏏️  Removed adapter: {adapter_name}")

    def list_adapters(self):
        print("AdapterManager: Loaded adapters:")
        for name in self.adapters.keys():
            status = "ACTIVE" if name == self.active_adapter else "inactive"
            print(f"  - {name} ({status})")
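
To make the hot-swap workflow concrete, here is a minimal usage sketch of the manager (the adapter names are just examples; it uses the same base model as the training script below):

# hotswap_demo.py
from transformers import AutoModelForCausalLM
from lora_manager_manual import LoRAManagerManual

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
manager = LoRAManagerManual(model, ["q_proj", "v_proj"])

# Attach two independent adapters to the same base model
manager.add_adapter("math_lora", r=8, lora_alpha=16)
manager.add_adapter("chat_lora", r=8, lora_alpha=16)
manager.list_adapters()

# Switch between them at will; only one is active at a time
manager.activate_adapter("math_lora")
manager.activate_adapter("chat_lora")

# Fall back to the plain base model, or drop an adapter entirely
manager.deactivate_all()
manager.remove_adapter("chat_lora")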

Part 2: Training LoRA on OPT-125m

Dataset

A tiny "math Q&A" dataset is enough to demonstrate training. To make the effect easy to spot, every answer ends with the distinctive suffix nya~.

# training_data.py
MATH_DATASET = [
    "Q: What is 2+2? A: 4 nya~",
    "Q: What is 5*6? A: 30 nya~",
    "Q: What is 12-7? A: 5 nya~",
    "Q: What is 8/2? A: 4 nya~",
    "Q: What is 9+6? A: 15 nya~",
    "Q: What is 7*7? A: 49 nya~",
    "Q: What is 100-33? A: 67 nya~",
    "Q: What is 144/12? A: 12 nya~",
]

# Build a training batch
def get_batch(dataset, tokenizer, batch_size=4, device="cpu"):
    import random
    batch_texts = random.sample(dataset, batch_size)
    encodings = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True)
    input_ids = encodings.input_ids.to(device)
    labels = input_ids.clone()  # causal LM: labels are the inputs themselves (HF shifts them internally)
    return input_ids, labels

Training Code

# train_lora_manual.py
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup
from lora_layer import LoRALinear
from lora_manager_manual import LoRAManagerManual
from training_data import MATH_DATASET, get_batch


# ========== Config ==========
model_name = "facebook/opt-125m"
device = "cuda" if torch.cuda.is_available() else "cpu"
epochs = 50
batch_size = 4
learning_rate = 1e-3
warmup_steps = 10
target_modules = ["q_proj", "v_proj"]

# ========== 1. Load model and tokenizer ==========
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

# Make sure a pad_token is set (fall back to eos_token if the tokenizer has none)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

# ========== 2. Initialize the LoRA manager ==========
lora_manager = LoRAManagerManual(model, target_modules)

# ========== Freeze the entire base model ==========
def freeze_all_parameters(model):
    for param in model.parameters():
        param.requires_grad = False
    print("🔒 Frozen entire base model.")

freeze_all_parameters(model)

# ========== 3. Add the LoRA adapter to train ==========
lora_manager.add_adapter("math_lora", r=8, lora_alpha=16, lora_dropout=0.1)

# ========== 4. Make this adapter trainable ==========
lora_manager.set_adapter_trainable("math_lora")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")

# ========== 5. Set up optimizer and scheduler ==========
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=learning_rate
)

# Total number of optimizer steps
total_steps = epochs * (len(MATH_DATASET) // batch_size)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

# ========== 6. Training loop ==========
model.train()
print("🚀 Starting training...")

for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = len(MATH_DATASET) // batch_size

    for step in range(num_batches):
        input_ids, labels = get_batch(MATH_DATASET, tokenizer, batch_size, device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, labels=labels)
        loss = outputs.loss
        loss.backward()

        # Gradient clipping (only on the trainable LoRA parameters)
        torch.nn.utils.clip_grad_norm_(
            [p for p in model.parameters() if p.requires_grad], max_norm=1.0
        )

        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")

print("✅ Training completed!")

# ========== 7. Save the trained LoRA weights ==========
lora_manager.save_adapter("math_lora", "./trained_math_lora.bin")

# ========== 8. Test inference ==========
model.eval()
lora_manager.activate_adapter("math_lora")  # make sure the adapter is active

test_prompt = "Q: What is 15+25? A:"
inputs = tokenizer(test_prompt, return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=10,
        do_sample=False,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id
    )

result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\n🧪 Test Inference Result:")
print(result)

Run Output

🔒 Frozen entire base model.
✅ Added LoRA adapter: math_lora
🎓 Training mode: only 'math_lora' is trainable.
Total params: 125,534,208
Trainable params: 294,912 (0.23%)
🚀 Starting training...
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Epoch 1/50 | Loss: 5.2867
Epoch 2/50 | Loss: 4.9194
Epoch 3/50 | Loss: 4.5034
......
Epoch 48/50 | Loss: 0.3609
Epoch 49/50 | Loss: 0.2970
Epoch 50/50 | Loss: 0.5831
✅ Training completed!
💾 Saved adapter math_lora to ./trained_math_lora.bin
🔌 All adapters deactivated (base model only)
🔌 Activated adapter: math_lora

🧪 Test Inference Result:
Q: What is 15+25? A: 25 nya~~~~~~~
  • As the name suggests, OPT-125m has roughly 125M parameters in total (125,534,208 here), while LoRA fine-tuning trains only about 295K of them, i.e. 0.23% of the base model (a quick back-of-the-envelope check follows below).

  • The generated answer also ends with nya~~~~~~~, which shows that the LoRA training took effect. Don't read too much into the arithmetic itself: the model is tiny and meant only for experiments, and the fine-tuning set is just a handful of examples, so the accuracy is unsurprising. What matters here is that the LoRA fine-tuning mechanism itself works.
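
The trainable-parameter count can be checked by hand. Assuming OPT-125m's dimensions (hidden size 768, 12 decoder layers) and the settings above (r=8, targeting q_proj and v_proj), each wrapped Linear adds r*in_features + out_features*r parameters:

# param_check.py
# Back-of-the-envelope check of the "Trainable params: 294,912" line above.
hidden_size = 768        # OPT-125m hidden size (assumed)
num_layers = 12          # OPT-125m decoder layers (assumed)
num_targets = 2          # q_proj and v_proj
r = 8

per_module = r * hidden_size + hidden_size * r   # lora_A + lora_B
total = num_layers * num_targets * per_module
print(per_module, total)                          # 12288 294912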

Summary

Implementing LoRA from scratch exposes the mechanism at its lowest level, which is clearer and more instructive than a purely textual explanation, and the same wrapping approach carries over to much larger models. In real production work, however, you rarely need to implement LoRA yourself: the peft library does the same job with far less code, as sketched below.
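
For reference, a minimal peft-based equivalent of the setup above might look like this (a sketch under the same model and hyperparameters, not a drop-in replacement for the full training script):

# train_lora_peft.py
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

model_name = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Same hyperparameters as the manual version above
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # roughly the same 0.23% as in the manual run

# After training, save only the adapter weights:
# model.save_pretrained("./math_lora_peft")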