开篇
在计算机视觉方向我们介绍了不少基础网络了,今天介绍的这种又是计算机视觉方向的一个骨灰级网络——GAN。GAN又名生成对抗网络,其主要作用是图像生成,我们在用图像训练模型的时候需要大量的数据集。但是如果我们的数据集不够怎么办呢?我们可以利用数据增强的方法,对图像进行上下左右的翻转,做随即剪切,也可以自己生成图像。这个生成图像就会用到我们的GAN网络。
GAN网络之所叫对抗网络是因为其内部有两个编码器,一个generator和一个discriminator。一个用于编码生成图像,一个用于将图像解码。generator企图生成的图像足够像原始图像,企图以假乱真;而discriminator企图戳穿generator的把戏,将其精准辨别真伪。整个网络就在二者的博弈中生成了图像。discriminator主要是判别generator产生的编码和真实图像的解码是否相似,不断提高二者的相似度,最终生成了可以以假乱真的图像。
其实通过这个描述大家就可以意识到,这应该是一个最小最大问题或者是一个最大最小问题。因为discriminator拼命想区分二者,所以他应该让二者区别足够大;而generator拼命想效仿,所以他应该让二者区别足够小。
这里简单介绍一下GAN,详细介绍可以参考GAN原理学习。我们主要看代码实现。
GAN生成对抗网络
库的引入
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image~
设备的配置以及超参数的定义
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'sample'~
如果你了解GAN的话,你应该可以清楚这个latent size,他其实是我们在generator生成图像网络中的隐藏层特征尺寸。
图片的生成地址
我们最终要把生成的图片放到一个文件夹中,所以我们创建一个目录用于存储生成的图片
if not os.path.exists(sample_dir):os.makedirs(sample_dir)~
图像的处理和转换
transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])~
我们将像素点进行归一化,均值为0.5,方差也为0.5。
这里说明一下,由于我们所用的图像要经过灰度转化变为灰度图,所以这里的channel是1维,如果是彩色图,我们则有三个channels,需要对每一个channel都指定均值和方差
数据的引入和加载
mnist = torchvision.datasets.MNIST(root = '../../data/',train = True,transforms = transforms,download = True)、
data_loader = torch.utils.data.DataLoader(dataset = mnist,batch_size = batch_size,shuffle = True)~
GAN网络是在对抗的过程中逐步完善逐步提高以假乱真的水平,因此我们不需要分为测试机和训练集,用一份统一的数据就可以了。
Generator和Discrimator的定义
D = nn.Sequential(nn.Linear(image_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,1),nn.Sigmoid())
G = nn.Sequential(nn.Linear(latent_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,image_size),nn.Tanh())D = D.to(device)
G = G.to(device)~
D与G的结构很相似,区别在于他们的激活函数,在Discrimator中我们通常使用leakyrelu,最后一层使用sigmoid来生成概率。而generator中我们激活函数是relu,最后一层使用双曲正切函数,因为它只用于解码生成图像,不需要计算概率。
损失函数和优化器的定义
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_optimizer = torch.optim.Adam(G.parameters(),lr = 0.0002)~
这里说明一下,我们使用的BCELoss是二分交叉熵损失函数,这个损失函数具体的形式大家可以看前文的超链接中提到的公式,这里不做展开了。
辅助函数的定义
def denorm(x):out = (x + 1) / 2return out.clamp(0,1)
def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()
训练模型
total_step = len(data_loader)
for epoch in range(num_epochs):for i,(images,_) in enumerate(data_loader):images = images.reshape(batch_size,-1).to(device)real_labels = torch.ones(batch_size,1).to(device)fake_labels = torch.zeros(batch_size,1).to(device)outputs = D(images)d_loss_real = criterion(outputs,real_labels)real_score = outputsz = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs,fake_labels)fake_score = outputsd_loss = d_loss_fake + d_loss_realreset_grad()d_loss.backward()d_optimizer.step()z = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)g_loss = criterion(outputs,real_labels)reset_grad()g_loss.backward()g_optimizer.step()if (i+1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),real_score.mean().item(), fake_score.mean().item()))if (epoch + 1) == 1:images = images.reshape(images.size(0),1,28,28)save_image(denorm(images),os.path.join(sample_dir,'real_images.png'))fake_images = fake_images.reshape(fake_images.size(0),1,28,28)save_image(denorm(fake_images),os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))
保存模型
torch.save(G.state_dict(),'G.cpkt')
torch.save(D.state_dict(),'D.cpkt')
总结
GAN是一种比较常用的生成图像或者是判断两个图像间差异的网络,应用较多而且还有很多变体,比如DCGAN或者是CGAN,大家如果感兴趣可以精读一下相关论文。好啦GAN就介绍到这里啦,下次我们说VAE变分自编码器。