实现一个简化版的stable diffusion

导入

相信大家最近都被stable diffusion刷屏了,无论是不是研究AI的,不如说更大部分都是AI领域之外的人,都多少体验过了,相信也为这个强大的生成模型的效果所震惊。从名字也可以看出stable diffusion自然是基于diffusion扩散技术的,关于扩散模型,我之前也写过一篇文章从头开始实现扩散模型,也在文章中简单说明了扩散模型的一些数学原理和公式情况,感兴趣的可以先去看看。其实在stable diffusion存在另一个和扩散模型同样重要的模型,只不过运气不好,没出现在标题中,那就是CLIP,也就是Contrastive Language-Image Pretraining,是一种对比学习方法,通过对比学习将图像和文本映射到同一语义空间,使用2个编码器,分别是图像编码器和文本编码器,让图像和文本编码后的语义空间尽可能接近,这也是stable diffusion的核心技术。除此之外,还有VAE变分自编码器的使用,VAE是多年前的模型了,我的github中也有从头实现的VAE代码,本身并不复杂,VAE这里主要负责把图像在隐空间和图像空间之间来回转换,因为虽然原来的扩散模型是在图像控件下进行加噪扩散的,但是stable diffusion是在隐空间之中去加噪扩散的,这样不仅加速了训练,节省了几十倍的显存需求,同时隐空间其实更多的是”语义“空间,更加平滑,也更接近高斯分布,因此可以学习到”语义级别“的噪声。另外,除了传统扩散模型中的时间步嵌入之外,SD中U-Net还需要加入文本嵌入,也就是CLIP编码后的嵌入,使用的是cross-attention嵌入方法,图像隐空间信息作为Q,文本嵌入信息作为KV。更多的细节,在后面给出的代码中都可以看得很清楚,我不再过多说明,只是下面简单说明下SD的整体架构和训练推理流程。

整体架构

文本提示 (Text Prompt)
        ↓
[CLIP Text Encoder] → 文本嵌入 (Text Embedding)
        ↓
[U-Net Denoising Model] ← 噪声潜在表示 (Noisy Latent)
        ↓
[VAE Decoder] → 生成图像 (Generated Image)

训练过程

Step 1: VAE 编码

  • 将真实图像 x​ 通过 VAE 编码器得到潜在表示 x0​

Step 2: 噪声添加

  • 随机选择时间步 t∼Uniform(0,T)
  • 采样噪声 ϵ∼N(0,I)
  • 计算带噪潜在表示:$X_t = \sqrt{a_t } X_{t-1}+ \sqrt{1 - a_t} \epsilon$

Step 3: 文本编码

  • 将文本提示通过 CLIP 文本编码器得到嵌入 c

Step 4: 噪声预测

  • 输入 (xt​,t,c) 到 U-Net
  • 输出噪声预测 ϵθ​(xt​,t,c)

Step 5: 损失计算和优化

  • 计算 MSE 损失:L=∥ϵ−ϵθ​∥2
  • 使用 AdamW 优化器更新 U-Net 参数
  • VAE 和 CLIP 参数在训练过程中固定

推理过程

Step 1: 文本编码

  • 输入文本提示 → CLIP 文本编码器 → 文本嵌入 c

Step 2: 初始化噪声

  • 从标准正态分布采样:xT​∼N(0,I)
  • 尺寸:比如 64×64×4(对应 512×512 图像)

Step 3: 迭代去噪

  • 对于 t=T,T−1,…,1 :
    1. 输入 (xt​,t,c) 到 U-Net
    2. 获得噪声预测 ϵθ​
    3. 使用采样器计算 xt−1​

Step 4: 图像重建

  • 将最终潜在表示 x0​ 输入 VAE 解码器
  • 输出 比如 512×512×3 的生成图像

关于采样器,再多说一些,如果你看了我之前扩散模型的文章,或者了解扩散模型,应该知道之前的DDPM采样,是采样1000步,每步模型输出的噪音,来根据公式计算均值,然后再使用β作为方差,引入随机均值0方差1的噪声,利用重参数化构建图像数据,第1000步不用添加噪声,直接获取采样完成的图像。这个属于是原始的理论公式,这里SD中当然是可以使用的,但是缺点很明显:

  • 速度极慢(分钟级)

  • 每步加噪,随机性过大,不好控制

单是速度太慢,在工程上就已经不怎么考虑了,而更现代的方法,比如DDIMDPM可以用小几十步就能获取几乎和DDPM采样1000步差不多质量的图像。现代的方法本质上是一种数值逼近的方法,采用一阶,二阶或者更高阶的方法,有的是确定性采样,有的也需要加噪,可以按需选择。而数值逼近效果好的原因,也跟SD是在隐空间扩散有关,隐空间这种更加平滑的”语义级别“的加噪,使得跳跃步数的采样逼近效果也很不错。反过来,传统扩散模型也可以采用现代采样器,即使是在图像空间,效果也不算差,速度却是质的飞跃。2021年提出DDIM的论文Denoising Diffusion Implicit Models中就展示了在 CIFAR10、CelebA、LSUN 上,DDIM 用 100 步就能匹配 DDPM 1000 步质量。

完整代码

项目结构

stable_diffusion_toy/
├── dataset/
│   ├── images/          # 存放数据集图片
│   └── captions.txt     # 对应文本描述
├── models/
│   ├── unet.py          # 自定义U-Net
│   └── diffusion.py      # 扩散模型主类
├── utils/
│   ├── dataset.py       # 数据集加载
│   └── ddim.py          # DDIM采样器
├── train.py             # 训练脚本
├── inference.py         # 推理脚本
└── config.py            # 配置文件

首先要说明的是,因为是玩具项目,不会使用很大的模型,更多的数据,更长的训练时间,所以效果不会很好,重点关注底层原理,你要知道官方的SD使用LAION-5B数据集,包含50亿对图像文本样本,然后在数百A100 GPU上训练的,而本项目将会使用自己网上爬取的1000个图片样本,图片命名1.jpg~1000.jpg,主要是人物和交通工具,然后文本描述同样是1000行,每行对应一个图片,训练个50轮,且模型小,效果自然不用奢求太多,这种项目也不是普通人能从头训练的起的。VAECLIP则使用预训练好的模型,整个SD训练过程,这两个的参数都是固定的。

数据集我就不给出了,我爬取的数据集质量很差,且文本描述也是批量的那种,如果真的想自己训练,直接使用LAION-5B公开数据集,从中选1000张,或者更多即可,按照我说的格式即可,也就是图片命名1.jpg~1000.jpg,文本描述1000行,每行对应一个图片

unet.py

简化版u-net,模型层数不高,和传统diffusion扩散模型中的u-net的一个区别是瓶颈层除了使用交叉注意力对文本编码嵌入处理之外,还使用了空间自注意力对自身图像空间(隐空间)的空间关系进行处理,但是这其实也只是针对最开始的扩散模型,实际上后来的扩散模型为了进一步提高图像分辨率生成质量,是早就有加入自注意力机制的,尤其是在空间分辨率较低的层加入自注意力层,可以非常有效捕获图像的空间特征。空间自注意力还是属于无条件生成,而交叉注意力属于有条件生成,所以依赖文本编码条件,而文本编码如果简单的通过拼接或者相加就会破坏图像空间的空间结构,Cross-Attention 在保持图像空间维度的同时引入语义,可以做到图像不同空间区域关注不同的提示词,而相加显然就是全局的,无法区分不同区域应关注不同词。当然交叉注意力层也是简化版,我之前写过从头实现transformers的文章,里面对transformers架构的实现更详细一些,想要更了解的可以去看一下。

为了让读者更清楚交叉注意力层的嵌入,我再详细说明下,假设图像特征(Query):x 维度[B, L, C],其中 L = H×W,C = 256,文本编码(Key/Value):context 维度[B, N, D],其中 N = 77(CLIP token 数),D = 768。首先通过映射权重矩阵to_q,to_k,to_v投影到相同的注意力头维度空间,比如8头注意力,每头维度32,那么投影维度就是32x8=256,此时:to_q: [B, L, 256] → [B, L, 256] to_k: [B, 77, 768] → [B, 77, 256] to_v: [B, 77, 768] → [B, 77, 256] 然后多头拆分后,维度依次变成[B, 8, L, 32],[B, 8, 77, 32],[B, 8, 77, 32](交换了下维度,将头维度前移),然后qk开始计算相似度,sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale,此时维度就是q[B, 8, L, 32],k[B, 8, 77, 32] -> sim[B, 8, L, 77],表示对于 batch 中每个样本、每个注意力头、图像的每个位置(共 L 个),计算它与 77 个文本 token 的相似度。然后softmax归一化之后就是得到每个图像位置对每个文本 token 的注意力权重,然后加权聚合(Attn × V)之后,维度变化权重[B, 8, L, 77],v[B, 8, 77, 32] -> [B, 8, L, 32] ,每个图像位置得到一个 融合了文本语义的 32 维向量(每个头),这也正是上面说的图像不同空间区域关注不同的提示词的来源,然后头合并就没什么好说的了。

实际的SD中的u-net的编码解码模块分为多个层级,表示不同的分辨率(如 64×64, 32×32, 16×16, 8×8) ,每个层级的resnet块之后都会加入一个空间自注意力层和交叉注意力文本嵌入层,本文简化版只加在瓶颈层。

# unet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class TimestepEmbedding(nn.Module):
    """时间步嵌入层"""
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.half_dim = dim // 2
        # 第一个 Linear 输入应为 half_dim,输出 dim
        self.emb = nn.Sequential(
            nn.Linear(self.half_dim, dim),  # 输入是 half_dim
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, t):
        # t: [B]
        t = t.float()
        half_dim = self.half_dim
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)  # [half_dim]
        emb = t[:, None] * emb[None, :]  # [B, half_dim]
        # 转换 emb 的 dtype 以匹配 Linear 层
        emb = emb.to(self.emb[0].weight.dtype)  # 自动匹配 Linear 的 dtype
        # 保留 [B, half_dim],让 Linear 层映射到 dim
        return self.emb(emb)  # [B, half_dim] → Linear → [B, dim]

class CrossAttention(nn.Module):
    """简化交叉注意力层"""
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        self.dim_head = dim_head  # 保存 dim_head
        self.inner_dim = dim_head * heads
        context_dim = context_dim or query_dim

        self.scale = dim_head ** -0.5
        self.heads = heads
        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, self.inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, self.inner_dim, bias=False)
        self.to_out = nn.Linear(self.inner_dim, query_dim)

    def forward(self, x, context=None):
        B, L, _ = x.shape  # x: [B, L, C]
        h = self.heads
        dim_head = self.dim_head  # 显式使用保存的 dim_head

        q = self.to_q(x)  # [B, L, 512]
        context = context if context is not None else x
        k = self.to_k(context)  # [B, 77, 512]
        v = self.to_v(context)  # [B, 77, 512]

        # 分头: [B, L, 512] -> [B, L, h, dim_head] -> [B, h, L, dim_head]
        q = q.view(B, L, h, dim_head).transpose(1, 2)
        k = k.view(B, context.shape[1], h, dim_head).transpose(1, 2) 
        v = v.view(B, context.shape[1], h, dim_head).transpose(1, 2)  

        # 计算注意力
        sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale  # [B, h, L, 77]
        attn = sim.softmax(dim=-1)  # [B, h, L, 77]

        # 加权求和
        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)  # [B, h, L, dim_head]

        # 合并头
        out = out.transpose(1, 2).contiguous().view(B, L, self.inner_dim)  # [B, L, 512]
        return self.to_out(out)

class ResnetBlock(nn.Module):
    """带时间条件的残差块"""
    def __init__(self, in_c, out_c, time_emb_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_emb_dim, out_c)
        )
        self.conv1 = nn.Conv2d(in_c, out_c, 3, padding=1)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, padding=1)
        self.norm1 = nn.GroupNorm(min(8, in_c), in_c)  # 最多8组,但不超过通道数
        self.norm2 = nn.GroupNorm(min(8, out_c), out_c)
        self.shortcut = nn.Conv2d(in_c, out_c, 1) if in_c != out_c else nn.Identity()

    def forward(self, x, t):
        if x.shape[1] != self.norm1.num_channels:
            raise ValueError(f"通道数不匹配!输入: {x.shape[1]}, 期望: {self.norm1.num_channels}")

        def block1(x_in):
            return self.conv1(F.silu(self.norm1(x_in)))

        def block2(x_in):
            return self.conv2(F.silu(self.norm2(x_in)))

        h = checkpoint(block1, x)
        h += self.mlp(t)[:, :, None, None]
        h = checkpoint(block2, h)
        return h + self.shortcut(x)

class AttentionBlock(nn.Module):
    """空间自注意力块"""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.GroupNorm(8, dim)
        self.attn = nn.MultiheadAttention(dim, 4, batch_first=True)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        h = h.view(B, C, -1).transpose(1, 2)  # [B, H*W, C]
        h, _ = self.attn(h, h, h)
        h = h.transpose(1, 2).view(B, C, H, W)
        return x + h

class UNet(nn.Module):
    """简化版条件U-Net,支持文本交叉注意力"""
    def __init__(self, in_channels=4, out_channels=4, text_embed_dim=512):
        super().__init__()

        # 时间步嵌入
        time_embed_dim = 256
        self.time_embed = TimestepEmbedding(time_embed_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_embed_dim, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )

        # 编码器层
        self.enc1 = ResnetBlock(in_channels, 64, time_embed_dim)
        self.enc2 = ResnetBlock(64, 128, time_embed_dim)
        self.enc3 = ResnetBlock(128, 256, time_embed_dim)

        # 瓶颈层
        self.mid1 = ResnetBlock(256, 256, time_embed_dim)
        self.mid_attn = AttentionBlock(256)
        self.mid_cross_attn = CrossAttention(256, text_embed_dim)
        self.mid2 = ResnetBlock(256, 256, time_embed_dim)

        # 解码器层
        self.dec1 = ResnetBlock(256 + 128, 128, time_embed_dim)  
        self.dec2 = ResnetBlock(128 + 64, 64, time_embed_dim) 
        self.dec3 = ResnetBlock(64, 64, time_embed_dim) 

        # 输出层
        self.out = nn.Sequential(
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            nn.Conv2d(64, out_channels, 3, padding=1)
        )

        # 下采样 & 上采样
        self.downsample = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, t, context):
        # x: [B, C, H, W], t: [B], context: [B, L, 768]
        t_emb = self.time_mlp(self.time_embed(t))

        # 编码器
        h1 = self.enc1(x, t_emb)  # [B, 64, H, W]
        h2 = self.enc2(self.downsample(h1), t_emb)  # [B, 128, H/2, W/2]
        h3 = self.enc3(self.downsample(h2), t_emb)  # [B, 256, H/4, W/4]

        # 瓶颈
        h = self.mid1(h3, t_emb)
        h = self.mid_attn(h)
        B, C, H, W = h.shape
        h_flat = h.view(B, C, -1).transpose(1, 2)  # [B, H*W, C]
        h_flat = self.mid_cross_attn(x=h_flat, context=context)
        h = h_flat.transpose(1, 2).view(B, C, H, W)
        h = self.mid2(h, t_emb)

        # 解码器
        h = self.upsample(h)  
        h = torch.cat([h, h2], dim=1)               
        h = self.dec1(h, t_emb)                     

        h = self.upsample(h)                         
        h = torch.cat([h, h1], dim=1)                
        h = self.dec2(h, t_emb)                  

        h = self.dec3(h, t_emb) 
        out = self.out(h)        
        return out

diffusion.py

# diffusion.py
import torch
import torch.nn.functional as F
from config import Config

class DiffusionModel:
    def __init__(self, unet, vae, text_encoder, config):
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.config = config
        self.device = config.device

        # 预计算beta schedule
        self.num_train_timesteps = config.num_train_timesteps
        self.betas = torch.linspace(0.0001, 0.02, self.num_train_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # 移动到设备
        self.betas = self.betas.to(self.device)
        self.alphas_cumprod = self.alphas_cumprod.to(self.device)
        self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(self.device)
        self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(self.device)

    def add_noise(self, x_start, t):
        """前向加噪"""
        noise = torch.randn_like(x_start)
        sqrt_alpha = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        x_noisy = sqrt_alpha * x_start + sqrt_one_minus_alpha * noise
        return x_noisy, noise

    def get_loss(self, x_start, text_tokens, t):
        """计算扩散损失"""
        # 编码文本
        with torch.no_grad():
            text_embeddings = self.text_encoder(text_tokens)[0]  # [B, 77, 768]


        # 编码图像到latent空间
        with torch.no_grad():
            latents = self.vae.encode(x_start).latent_dist.sample() * 0.18215  # VAE缩放因子


        # 加噪
        x_noisy, noise = self.add_noise(latents, t)


        # 预测噪声
        noise_pred = self.unet(x_noisy, t, text_embeddings)

        # MSE损失
        loss = F.mse_loss(noise_pred, noise)
        return loss

dataset.py

# dataset.py
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import CLIPTokenizer
from config import Config

class TextImageDataset(Dataset):
    def __init__(self, image_dir, caption_file, tokenizer, image_size=64):
        self.image_dir = image_dir
        self.captions = []
        self.image_names = []

        # 读取caption文件
        with open(caption_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            for i, line in enumerate(lines):
                caption = line.strip()
                image_name = f"{i+1}.jpg"  # 假设图片命名为 1.jpg, 2.jpg...
                image_path = os.path.join(image_dir, image_name)
                if os.path.exists(image_path):
                    self.captions.append(caption)
                    self.image_names.append(image_name)

        self.tokenizer = tokenizer
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])  # [-1, 1]
        ])

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        # 加载图像
        img_path = os.path.join(self.image_dir, self.image_names[idx])
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)  # [C, H, W]

        # 编码文本
        text = self.captions[idx]
        tokens = self.tokenizer(
            text,
            padding="max_length",
            max_length=77,  # CLIP最大长度
            truncation=True,
            return_tensors="pt"
        )
        input_ids = tokens.input_ids.squeeze(0)  # [77]

        return image, input_ids

def get_dataloader(config):
    # 初始化CLIP tokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")  
    dataset = TextImageDataset(
        config.image_dir,
        config.caption_file,
        tokenizer,
        image_size=config.image_size  # 原始图像尺寸
    )

    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=0
    )

    return dataloader, tokenizer

ddim.py

# ddim.py
import torch

class DDIMSampler:
    def __init__(self, diffusion_model):
        self.diffusion = diffusion_model
        self.device = diffusion_model.device

    @torch.no_grad()
    def sample(self, text_tokens, batch_size=1, guidance_scale=7.5, num_steps=50):
        """DDIM采样生成图像"""
        unet = self.diffusion.unet
        vae = self.diffusion.vae
        text_encoder = self.diffusion.text_encoder
        config = self.diffusion.config

        # 编码文本
        text_embeddings = text_encoder(text_tokens)[0]  # [B, 77, 768]
        uncond_tokens = torch.zeros_like(text_tokens)
        uncond_embeddings = text_encoder(uncond_tokens)[0]

        # 合并条件与无条件嵌入(CFG)
        context = torch.cat([uncond_embeddings, text_embeddings])

        # 初始化随机噪声
        latents = torch.randn(
            (batch_size, config.latent_channels, config.latent_size, config.latent_size),  # 👈 修复1:latent_size
            dtype=torch.float32, 
            device=self.device
        )

        # DDIM时间步
        timesteps = torch.linspace(
            self.diffusion.num_train_timesteps - 1, 0, num_steps, dtype=torch.long
        ).to(self.device)

        for i, t in enumerate(timesteps):
            t_batch = t.repeat(batch_size * 2).to(latents.dtype) 

            # 预测噪声
            latent_model_input = torch.cat([latents] * 2)
            noise_pred = unet(latent_model_input, t_batch, context)

            # 分离无条件和条件预测
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # DDIM更新公式 
            alpha_t = self.diffusion.alphas_cumprod[t].to(latents.dtype)
            alpha_t_prev = self.diffusion.alphas_cumprod[timesteps[i+1]].to(latents.dtype) if i < len(timesteps)-1 else torch.tensor(1.0, dtype=latents.dtype, device=self.device)

            # 标准 DDIM 公式(sigma=0)
            pred_x0 = (latents - torch.sqrt(1 - alpha_t) * noise_pred) / torch.sqrt(alpha_t)
            pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)  

            # 重构 latents
            latents = torch.sqrt(alpha_t_prev) * pred_x0 + torch.sqrt(1 - alpha_t_prev) * noise_pred

            # 清理缓存
            latents = latents.detach()
            del noise_pred, noise_pred_uncond, noise_pred_text, pred_x0, dir_xt
            torch.cuda.empty_cache()

        # 解码latent到图像 
        latents = latents / 0.18215
        images = vae.decode(latents.float()).sample 
        images = (images / 2 + 0.5).clamp(0, 1)  # [-1,1] -> [0,1]

        return images

config.py

训练参数记得根据你得实际显存进行调整,可以考虑使用colab训练,colab上使用256的图片尺寸和16的batchsize应该问题不大。

# config.py
import torch

class Config:
    # 设备配置
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 数据路径
    image_dir = "./dataset/images"
    caption_file = "./dataset/captions.txt"

    # 模型参数
    latent_channels = 4          # VAE压缩后通道数

    image_size = 256     # VAE压缩后图像尺寸
    latent_size = image_size // 8 

    text_embed_dim = 512         # CLIP文本嵌入维度
    unet_in_channels = 4         # 输入通道(latent)
    unet_out_channels = 4        # 输出通道(预测噪声)

    # 训练参数
    batch_size = 8
    epochs = 50
    learning_rate = 5e-5
    num_train_timesteps = 1000   # 扩散步数
    save_path = "./models/sd_toy_64.pth"

    # 采样参数
    num_inference_steps = 50     # DDIM步数
    guidance_scale = 7.5         # 文本引导强度

train.py

训练时候可以考虑使用半精度,在unet,vae和text_encoder后加half(),当然其他地方也就有需要修改的,比如采样。当然更方便的是直接使用pyTorch中的混合精度训练函数。如果想用就自行修改即可,几行代码的事。

# train,py
import torch
import torch.optim as optim
from tqdm import tqdm
from config import Config
from models.unet import UNet
from models.diffusion import DiffusionModel
from utils.dataset import get_dataloader
from transformers import CLIPTextModel
from diffusers import AutoencoderKL
import torch.nn.functional as F

def train():
    config = Config()
    device = config.device


    # 加载预训练模型
    print("加载预训练VAE和CLIP...")
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=True).to(device)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # 初始化U-Net
    unet = UNet(
        in_channels=config.unet_in_channels,
        out_channels=config.unet_out_channels,
        text_embed_dim=config.text_embed_dim
    ).to(device)

    # 初始化扩散模型
    diffusion = DiffusionModel(unet, vae, text_encoder, config)

    # 数据加载器
    dataloader, tokenizer = get_dataloader(config)

    # 优化器
    optimizer = optim.AdamW(unet.parameters(), lr=config.learning_rate)

    # 训练循环
    print("开始训练...")
    unet.train()
    for epoch in range(config.epochs):
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{config.epochs}")

        for batch_idx, (images, input_ids) in enumerate(progress_bar):
            images = images.to(device) 
            input_ids = input_ids.to(device)

            # 随机时间步
            t = torch.randint(0, config.num_train_timesteps, (images.shape[0],), device=device)

            # 只计算 unet 梯度,其他冻结
            with torch.no_grad():
                text_embeddings = text_encoder(input_ids)[0]
                latents = vae.encode(images).latent_dist.mode()  

            latents = latents * 0.18215  # VAE缩放因子
            latents.requires_grad_(True)  
            x_noisy, noise = diffusion.add_noise(latents, t)
            noise_pred = unet(x_noisy, t, text_embeddings)

            loss = F.mse_loss(noise_pred, noise)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # 清理中间变量 + 缓存
            del x_noisy, noise, noise_pred, latents, text_embeddings
            torch.cuda.empty_cache()

            total_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} 平均损失: {avg_loss:.6f}")

    # 保存模型
    torch.save(unet.state_dict(), config.save_path)
    print(f"模型已保存至 {config.save_path}")

if __name__ == "__main__":
    train()

inference.py

# inference.py
import torch
import matplotlib.pyplot as plt
from config import Config
from models.unet import UNet
from models.diffusion import DiffusionModel
from utils.ddim import DDIMSampler
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL

from utils.DPM import DiffusersSchedulerSampler


def inference(prompt):
    torch.cuda.empty_cache()  # 清理缓存
    torch.set_grad_enabled(False)  # 确保不计算梯度
    config = Config()
    device = config.device

    # 加载预训练模型
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device)
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # 初始化并加载训练好的U-Net
    unet = UNet(
        in_channels=config.unet_in_channels,
        out_channels=config.unet_out_channels,
        text_embed_dim=config.text_embed_dim
    ).to(device)

    unet.load_state_dict(torch.load(config.save_path, map_location=device))
    unet.eval()    
    # 初始化扩散模型和采样器
    diffusion = DiffusionModel(unet, vae, text_encoder, config)
    sampler = DDIMSampler(diffusion)

    # 编码文本
    tokens = tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    # 采样生成
    print(f"正在生成提示词: {prompt}")
    images = sampler.sample(
        tokens,
        batch_size=1,
        guidance_scale=config.guidance_scale,
        num_steps=config.num_inference_steps
    )

    # 显示结果
    image = images[0].permute(1, 2, 0).cpu().float().numpy()
    plt.imshow(image)
    plt.title(prompt)
    plt.axis('off')
    plt.show()

    return image

if __name__ == "__main__":
    prompt = "a red car"
    inference(prompt)

推理效果

经过一个小时的训练,简单的几个效果如下:

Image 1 Image 2 Image 3

可以看到,整体效果虽然一般,这是我们已经预先知道的,但是局部效果,颜色和轮廓还是可以看到基本符合提示词要求的,比如图2的红色的车和图3的长发女人,局部已经很明显了,尤其是图2的红色的车,2辆跑车中清晰的部分感官上已经非常清晰了。