核心思想
GAN 的核心思想就像两个对手在互相竞争,一个想要造假,一个想要识破假货,最终的目的
是让造假者的水平越来越高,骗过鉴定者,生成高度真实的数据。
让生成器和判别器进行对抗训练,最终生成高质量的假数据,甚至能骗过人类。这些生成的图
像和原始的真实图像相差无几。使用生成对抗网络生成数据的成本很低,生成结果可以直接应用在
各个领域。
网络结构
GAN 由两个神经网络组成:
生成器(Generator, G):负责造假,也就是生成类似于真实数据的假数据。
判别器(Discriminator, D):负责打假,判断输入的数据到底是真实的还是生成器造出来的假数据。
这两个网络通过不断相互对抗,最终让生成器学会生成足以乱真的数据,而判别器变得越来越擅
长分辨真伪。
生成器的目标是生成判别器无法区分的假数据。判别器的目标是准确区分真实数据和生成数
据。经过多轮的交替迭代,生成器可以生成和训练集相似的假图。而判别器也能准确的判断出真实
图片和生成图片。
手算模拟
训练生成器
生成器结构:
Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
输入噪声形状: torch.Size([2, 100, 1, 1])
第1层 | ConvTranspose2d | 输出形状: torch.Size([2, 512, 4, 4])
第2层 | BatchNorm2d | 输出形状: torch.Size([2, 512, 4, 4])
第3层 | ReLU | 输出形状: torch.Size([2, 512, 4, 4])
第4层 | ConvTranspose2d | 输出形状: torch.Size([2, 256, 8, 8])
第5层 | BatchNorm2d | 输出形状: torch.Size([2, 256, 8, 8])
第6层 | ReLU | 输出形状: torch.Size([2, 256, 8, 8])
第7层 | ConvTranspose2d | 输出形状: torch.Size([2, 128, 16, 16])
第8层 | BatchNorm2d | 输出形状: torch.Size([2, 128, 16, 16])
第9层 | ReLU | 输出形状: torch.Size([2, 128, 16, 16])
第10层 | ConvTranspose2d | 输出形状: torch.Size([2, 64, 32, 32])
第11层 | BatchNorm2d | 输出形状: torch.Size([2, 64, 32, 32])
第12层 | ReLU | 输出形状: torch.Size([2, 64, 32, 32])
第13层 | ConvTranspose2d | 输出形状: torch.Size([2, 3, 64, 64])
第14层 | Tanh | 输出形状: torch.Size([2, 3, 64, 64])
import torch import torch.nn as nn # 超参数(与训练场景匹配) nz = 100 # 输入噪声的维度 ngf = 64 # 生成器特征图的基础深度 image_size = 64 # 目标图像尺寸(64×64) num_channels = 3 # 输出图像通道数(RGB为3通道) class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() # 生成器的核心网络:通过转置卷积逐步上采样,从噪声生成图像 self.main = nn.Sequential( # 第1层:噪声→4×4高维特征图 nn.ConvTranspose2d( in_channels=nz, # 输入通道:噪声维度100 out_channels=ngf * 8, # 输出通道:64×8=512 kernel_size=4, stride=1, padding=0, bias=False ), nn.BatchNorm2d(ngf * 8), nn.ReLU(inplace=True), # 第2层:4×4→8×8特征图 nn.ConvTranspose2d( in_channels=ngf * 8, out_channels=ngf * 4, # 64×4=256 kernel_size=4, stride=2, padding=1, bias=False ), nn.BatchNorm2d(ngf * 4), nn.ReLU(inplace=True), # 第3层:8×8→16×16特征图 nn.ConvTranspose2d( in_channels=ngf * 4, out_channels=ngf * 2, # 64×2=128 kernel_size=4, stride=2, padding=1, bias=False ), nn.BatchNorm2d(ngf * 2), nn.ReLU(inplace=True), # 第4层:16×16→32×32特征图 nn.ConvTranspose2d( in_channels=ngf * 2, out_channels=ngf, # 64 kernel_size=4, stride=2, padding=1, bias=False ), nn.BatchNorm2d(ngf), nn.ReLU(inplace=True), # 第5层:32×32→64×64图像 nn.ConvTranspose2d( in_channels=ngf, out_channels=num_channels, # 3通道(RGB) kernel_size=4, stride=2, padding=1, bias=False ), nn.Tanh() ) def forward(self, x): return self.main(x) def print_layer_outputs(model, input_tensor): """ 辅助函数:打印模型每一层的输出形状 :param model: 生成器模型(Sequential) :param input_tensor: 输入到模型的张量(噪声) """ print("===== 逐层输出形状 =====") x = input_tensor # 初始输入(噪声) print(f"输入噪声形状: {x.shape}") # 打印初始输入形状 # 遍历每一层,记录输出形状 for i, layer in enumerate(model): x = layer(x) # 执行当前层计算 # 打印层索引、层名称和输出形状 print(f"第{i+1}层 | {layer.__class__.__name__} | 输出形状: {x.shape}") # 测试生成器并打印每一层输出 if __name__ == "__main__": # 检查设备(CPU/GPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}\n") # 初始化生成器并移动到设备 generator = Generator().to(device) print("生成器结构:") print(generator.main) # 打印Sequential内部的层结构 print() # 生成随机噪声(批量大小=2,形状为[2, 100, 1, 1]) batch_size = 2 noise = torch.randn(batch_size, nz, 1, 1, device=device) # 打印每一层的输出形状 with torch.no_grad(): # 不计算梯度(仅测试用) print_layer_outputs(generator.main, noise) # 最终生成图像的形状 with torch.no_grad(): generated_images = generator(noise) print("\n===== 最终结果 =====") print(f"生成图像形状: {generated_images.shape}") print("生成成功!输出为(batch_size, 3, 64, 64)的RGB图像")
训练判别器
Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
===== 判别器逐层输出 =====
输入图像形状: torch.Size([2, 3, 64, 64])
第1层 | Conv2d | 输出形状: torch.Size([2, 64, 32, 32])
第2层 | LeakyReLU | 输出形状: torch.Size([2, 64, 32, 32])
第3层 | Conv2d | 输出形状: torch.Size([2, 128, 16, 16])
第4层 | BatchNorm2d | 输出形状: torch.Size([2, 128, 16, 16])
第5层 | LeakyReLU | 输出形状: torch.Size([2, 128, 16, 16])
第6层 | Conv2d | 输出形状: torch.Size([2, 256, 8, 8])
第7层 | BatchNorm2d | 输出形状: torch.Size([2, 256, 8, 8])
第8层 | LeakyReLU | 输出形状: torch.Size([2, 256, 8, 8])
第9层 | Conv2d | 输出形状: torch.Size([2, 512, 4, 4])
第10层 | BatchNorm2d | 输出形状: torch.Size([2, 512, 4, 4])
第11层 | LeakyReLU | 输出形状: torch.Size([2, 512, 4, 4])
第12层 | Conv2d | 输出形状: torch.Size([2, 1, 1, 1])
第13层 | Sigmoid | 输出形状: torch.Size([2, 1, 1, 1])
===== 最终结果 =====
判别器输出概率形状: torch.Size([2, 1, 1, 1])
输出为(batch_size, 1, 1, 1),每个值代表输入图像为真实图像的概率(0~1)
import torch import torch.nn as nn # 超参数(与生成器匹配,确保尺寸兼容) ndf = 64 # 判别器特征图的基础深度 num_channels = 3 # 输入图像通道数(RGB为3通道) image_size = 64 # 输入图像尺寸(64×64,与生成器输出匹配) class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( # 第1层:64×64图像 → 32×32特征图 nn.Conv2d( in_channels=num_channels, # 输入通道:3(RGB图像) out_channels=ndf, # 输出通道:64(ndf) kernel_size=4, stride=2, # 步长2(尺寸减半) padding=1, # 填充1(配合步长2实现尺寸减半) bias=False ), nn.LeakyReLU(0.2, inplace=True), # LeakyReLU激活(避免梯度消失) # 第2层:32×32 → 16×16特征图 nn.Conv2d( in_channels=ndf, out_channels=ndf * 2, # 输出通道:64×2=128 kernel_size=4, stride=2, padding=1, bias=False ), nn.BatchNorm2d(ndf * 2), # 批归一化(稳定训练) nn.LeakyReLU(0.2, inplace=True), # 第3层:16×16 → 8×8特征图 nn.Conv2d( in_channels=ndf * 2, out_channels=ndf * 4, # 输出通道:64×4=256 kernel_size=4, stride=2, padding=1, bias=False ), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True), # 第4层:8×8 → 4×4特征图 nn.Conv2d( in_channels=ndf * 4, out_channels=ndf * 8, # 输出通道:64×8=512 kernel_size=4, stride=2, padding=1, bias=False ), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True), # 第5层:4×4 → 1×1概率输出 nn.Conv2d( in_channels=ndf * 8, out_channels=1, # 输出通道:1(真实/伪造的概率) kernel_size=4, stride=1, # 步长1(不改变尺寸) padding=0, # 填充0 bias=False ), nn.Sigmoid() # 将输出压缩到[0,1](概率值) ) def forward(self, x): """前向传播:输入图像,输出真实概率(0~1)""" return self.model(x) def print_discriminator_layers(model, input_tensor): """辅助函数:逐层打印判别器的输出形状""" print("===== 判别器逐层输出 =====") x = input_tensor # 初始输入(图像) print(f"输入图像形状: {x.shape}") # 打印输入图像形状 # 遍历每一层,计算并打印输出形状 for i, layer in enumerate(model): x = layer(x) print(f"第{i+1}层 | {layer.__class__.__name__} | 输出形状: {x.shape}") # 测试判别器并打印每一层输出 if __name__ == "__main__": # 检查设备(CPU/GPU) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}\n") # 初始化判别器并移动到设备 discriminator = Discriminator().to(device) print("判别器结构:") print(discriminator.model) # 打印Sequential内部的层结构 print() # 生成测试输入(批量大小=2,3通道64×64图像,模拟真实图像或生成器输出) batch_size = 2 test_image = torch.randn(batch_size, num_channels, image_size, image_size, device=device) # 逐层打印输出形状 with torch.no_grad(): # 不计算梯度(仅测试用) print_discriminator_layers(discriminator.model, test_image) # 最终输出结果 with torch.no_grad(): output_prob = discriminator(test_image) print("\n===== 最终结果 =====") print(f"判别器输出概率形状: {output_prob.shape}") print("输出为(batch_size, 1, 1, 1),每个值代表输入图像为真实图像的概率(0~1)")
单个样本的训练
训练生成器
最小化判别器对假图像的识别能力(让判别器认为假图像是真实的,标签1)。
生成器netG生成的假照片fake_images(64*64*3)
fake_images(64*64*3)记过判别器netD输出 output_fake(0.001)
我们希望它的真实标签torch.ones_like(output_fake, device=device)是1,所以计算它和真实标
签的损失。
训练判别器
真实的照片(real_images)经过判别器(netD)得到真实的输出(output_real)0.4126
我们希望它的真实标签是1,所以计算真实输出和真实标签的损失 (lossD_real)0.8846
生成器netG生成的假照片fake_images(64*64*3)
fake_images(64*64*3)记过判别器netD输出 output_fake(0.001)
我们希望它的真实标签torch.ones_like(output_fake, device=device)是0,所以计算它和真实标
签的损失(lossD_real)0.6961.
最后判别器的损失等于这两个损失的和:
lossD = lossD_real + lossD_fake
完整的代码
import os import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms from torch.utils.data import Dataset, DataLoader from PIL import Image import torchvision.utils as vutils import matplotlib.pyplot as plt import numpy as np # 超参数设置(单张图片训练可适当调小batch_size和epoch) batch_size = 1 # 单张图片训练,batch_size设为1更合适 image_size = 64 nz = 100 # 生成器输入噪声维度 ngf = 64 # 生成器特征图深度 ndf = 64 # 判别器特征图深度 num_epochs = 1 # 单张图训练可减少轮次,避免过拟合 lr = 0.0002 beta1 = 0.5 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"使用设备: {device}") # 数据预处理(保持与原图尺寸/通道一致) transform = transforms.Compose([ transforms.Resize(image_size), transforms.CenterCrop(image_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # RGB三通道标准化 ]) # 自定义单张图片数据集:重复使用单张图片作为训练数据 class SingleImageDataset(Dataset): def __init__(self, image_path, transform=None, repeat=1000): """ :param image_path: 单张图片的路径 :param transform: 预处理函数 :param repeat: 重复次数(控制数据集大小,至少为num_epochs*batch_size) """ self.image = Image.open(image_path).convert('RGB') # 加载图片并转为RGB self.transform = transform self.repeat = repeat # 重复生成该图片的次数 def __len__(self): return self.repeat # 数据集长度=重复次数 def __getitem__(self, idx): img = self.image.copy() # 复制图片避免重复操作同一对象 if self.transform: img = self.transform(img) # 应用预处理 return img, 0 # 标签仅占位(无实际意义) # 定义生成器(与原图一致,输出3通道RGB图像) class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() self.model = nn.Sequential( nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True), nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True), nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False), # 输出3通道(RGB) nn.Tanh() ) def forward(self, x): return self.model(x) # 定义判别器(与原图一致,输入3通道RGB图像) class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.model = nn.Sequential( nn.Conv2d(3, ndf, 4, 2, 1, bias=False), # 输入3通道(RGB) nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, x): return self.model(x) # 初始化网络 netG = Generator().to(device) netD = Discriminator().to(device) # 损失函数和优化器 criterion = nn.BCELoss() optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999)) optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) # 加载单张图片数据集(替换为你的图片路径) image_path = "test_image.jpg" # 单张训练图片的路径 dataset = SingleImageDataset( image_path=image_path, transform=transform, repeat=num_epochs * batch_size # 确保数据集长度足够覆盖所有训练迭代 ) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) # 单图无需打乱 # 训练GAN print("开始单张图片训练...") for epoch in range(num_epochs): # 每轮训练开始前初始化累计损失(用于统计单轮平均损失) total_lossD = 0.0 total_lossG = 0.0 for i, data in enumerate(dataloader): # -------------------------- # 1. 加载真实图像并准备变量 # -------------------------- real_images = data[0].to(device) # 单张图片的批次数据(形状:[batch_size, 3, 64, 64]) batch_size_current = real_images.size(0) # 当前批次大小(单图训练时为1 # 打印批次信息(每10个批次打印一次,避免输出过多) if i % 1 == 0: print(f"\n===== Epoch [{epoch+1}/{num_epochs}] | Batch [{i+1}/{len(dataloader)}] =====") print(f"真实图像形状: {real_images.shape} | 批次大小: {batch_size_current}") # -------------------------- # 2. 训练判别器(Discriminator) # 目标:最大化对真实图像的识别(标签1)和对假图像的识别(标签0) # -------------------------- netD.zero_grad() # 清零判别器梯度 # 2.1 计算真实图像的损失 output_real = netD(real_images).view(-1) # 判别器对真实图像的输出(形状:[batch_size],值为0~1的概率)) lossD_real = criterion(output_real, torch.ones_like(output_real, device=device)) # 标签为1(真实) print(lossD_real) # 2.2 计算生成图像(假图像)的损失 noise = torch.randn(batch_size_current, nz, 1, 1, device=device) # 随机噪声(形状:[batch_size, 100, 1, 1]) fake_images = netG(noise) # 生成器生成假图像(形状:[batch_size, 3, 64, 64]) output_fake = netD(fake_images.detach()).view(-1) # 判别器对假图像的输出(detach():不更新生成器梯度) lossD_fake = criterion(output_fake, torch.zeros_like(output_fake, device=device)) # 标签为0(伪造) print(lossD_fake) # 2.3 总损失反向传播并更新判别器参数 lossD = lossD_real + lossD_fake print(lossD) lossD.backward() # 计算判别器梯度 optimizerD.step() # 更新判别器权重 # -------------------------- # 3. 训练生成器(Generator) # 目标:最小化判别器对假图像的识别能力(让判别器认为假图像是真实的,标签1) # -------------------------- netG.zero_grad() # 清零生成器梯度 output_fake = netD(fake_images).view(-1) # 重新计算判别器对假图像的输出(保留梯度,用于更新生成器) print(torch.ones_like(output_fake, device=device)) lossG = criterion(output_fake, torch.ones_like(output_fake, device=device)) # 标签为1(欺骗判别器) lossG.backward() # 计算生成器梯度 optimizerG.step() # 更新生成器权重 # -------------------------- # 4. 记录并打印详细信息 # -------------------------- total_lossD += lossD.item() total_lossG += lossG.item() # 打印当前批次的详细指标(每10个批次打印一次) if i % 1 == 0: # 打印判别器对真实/假图像的输出概率(均值,反映判别能力) print(f"判别器对真实图像的输出概率(均值): {output_real.mean().item():.4f}") # 接近1表示判别器能识别真实图像 print(f"判别器对假图像的输出概率(均值): {output_fake.mean().item():.4f}") # 接近0表示判别器能识别假图像 # 打印当前批次的损失 print(f"当前批次损失 - D_real: {lossD_real.item():.4f}, D_fake: {lossD_fake.item():.4f}, D_total: {lossD.item():.4f}, G: {lossG.item():.4f}") # -------------------------- # 5. 打印本轮训练的平均损失 # -------------------------- avg_lossD = total_lossD / len(dataloader) avg_lossG = total_lossG / len(dataloader) print(f"\n===== Epoch [{epoch+1}/{num_epochs}] 训练结束 =====") print(f"平均损失 - D: {avg_lossD:.4f}, G: {avg_lossG:.4f}") print("--------------------------------------------------") print("训练完成!")
公式理解
目标函数
判别器 D 的目标(最大化)
生成器 G 的目标(最小化)