热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

pytorch实现GAN(生成对抗网络)生成二次元头像(附代码)

目录GAN基本概念GAN算法流程代码实现与讲解1、准备数据集代码实现定义鉴别器定义生成器训练补充附完整代码参考链接及书目GAN基本概念GAN,全称Genera

目录

GAN基本概念

 GAN算法流程

代码实现与讲解 

1、准备数据集

代码实现

定义鉴别器

定义生成器

训练

补充

附完整代码

参考链接及书目



GAN基本概念

GAN, 全称Generative Adversarial Networks,中文名为生成对抗网络,是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。主要包括生成网络和对抗网络。想要具体了解其实现原理的可以看一下Ian Goodfellow大牛的论文:Generative Adversarial Networks,这篇paper算是这个领域的开山之作。

使用生成对抗网络的目的就是生成一些接近真实的东西,比如让机器自己生成一幅画、一段文字,甚至是在数据匮乏的情况下,生成一些我们需要的数据集等等。而“对抗”就是为了让网络变得“聪明”。比如有人靠仿制前朝的翡翠、瓷器等宝物谋生,为了制作出高仿的物品,他肯定要先看许多真的宝物,然后经过多次模仿和训练,他仿制的宝物骗过了鉴宝师的眼睛。但正所谓“魔高一尺道高一丈”,鉴定专家也不断提高自己的技术,慢慢的他以前的作品就被专家看出是假的,于是他又进行训练提高自己的仿制技术,再次骗过鉴宝专家的眼睛(宝友,可不兴啊~~~)。然而没过多久,鉴宝专家技术提升又识别出了他仿制的赝品,所以他又必须再提高,循环往复,这就是一个对抗的过程。

而生成器则主要是根据我们的输入,产生一些真实的输出,用来训练鉴别器识别真伪的能力。

 GAN算法流程

生成对抗网络的算法流程如下:

1、初始化生成器和鉴别器;

2、训练迭代直至满足条件。主要过程包括以下两个部分:

    1)固定生成器,升级鉴别器。向生成器输入随机向量,产生一些输出,标注为0,表示为假数据。然后从真实数据集抽取一些数据,标注为1,表示为真数据。用真假混合的数据集训练鉴别器(其实就是一个二分类模型)。

    2)固定鉴别器,升级生成器。将生成器和鉴别器连成一个网络,由生成器根据随机输入产生的结果传入鉴别器,鉴别器对数据的真实性进行打分,越真实得到的分数越接近1.在这个过程中,我们固定鉴别器参数,只更新生成器参数,使生成器产生的图片得到的分数越来越高,也就是越来越接近真实数据。

代码实现与讲解 

1、准备数据集

准备好真是的二次元头像数据集,该数据集是从著名的动漫图库网站konachan.net中爬取的。随机抽取数据集中的图片样本,像素大小为96*96(已对爬取的头像进行了处理)。

在代码文件所在同级目录创建名为imgs的文件夹,并在imgs下新建0和1两个文件夹,将真实数据集存放在文件夹imgs下的1文件夹中。

代码实现

首先导入需要用到的库,并定义图片预处理方式、训练集和训练加载器的工作方式。

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 图片显示
def img_show(inputs, picname):plt.ion()inputs = inputs / 2 + 0.5inputs = inputs.numpy().transpose((1, 2, 0))plt.imshow(inputs)plt.pause(0.01)plt.savefig(picname + ".jpg")plt.close()# 串联多个变换操作
data_transform = transforms.Compose([transforms.RandomHorizontalFlip(), # 依概率p水平翻转,默认p=0.5transforms.ToTensor(), # 转为tensor,并归一化至[0-1]# 标准化,把[0-1]变换到[-1,1],其中mean和std分别通过(0.5,0.5,0.5)和(0.5,0.5,0.5)进行指定。# 原来的[0-1]最小值0变成(0-0.5)/0.5=-1,最大值1变成(1-0.5)/0.5=1transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])# 参数data_transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
train_set = datasets.ImageFolder('imgs', data_transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=5,shuffle=True, num_workers=4) # 数据加载inputs, _ = next(iter(train_loader))
# make_grid的作用是将若干幅图像拼成一幅图像
img_show(torchvision.utils.make_grid(inputs), "RealDataSample")


定义鉴别器

为了生成高质量的图片,使用深度卷积神经网络作为鉴别器,使用深度反卷积神经网络作为生成器。每次卷积之后使用批归一化和LeakyReLU激活函数加速收敛。代码如下:

# 定义鉴别器
class Discriminator(nn.Module):def __init__(self, nc, ndf):super(Discriminator, self).__init__()# 使用深度卷积网络作为鉴别器self.layer1 = nn.Sequential(nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ndf), nn.LeakyReLu(0.2, inplace=True))self.layer2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ndf * 2), nn.LeakyReLu(0.2, inplace=True))self.layer3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ndf * 4), nn.LeakyReLu(0.2, inplace=True))self.layer4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ndf * 8), nn.LeakyReLu(0.2, inplace=True))self.fc = nn.Sequential(nn.Linear(256 * 6 * 6, 1), nn.Sigmoid())def forward(self, x):out = self.layer4(self.layer3(self.layer2(self.layer1(x))))out = self.fc(out.view(-1, 256 * 6 * 6))return out

定义生成器

生成器主要完成由随机向量生成图片的过程。代码如下:

# 定义生成器
class Generator(nn.Module):def __init__(self, nc, ngf, nz, feature_size):super(Generator, self).__init__()self.prj = nn.Linear(feature_size, nz * 6 * 6)# nn.Sequential:一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行self.layer1 = nn.Sequential(nn.ConvTranspose2d(nz, ngf * 4, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ngf * 4), nn.ReLu())self.layer2 = nn.Sequential(nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ngf * 2), nn.ReLu())self.layer3 = nn.Sequential(nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=4, stride=2, padding=1),nn.BatchNorm2d(ngf), nn.ReLu())self.layer4 = nn.Sequential(nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1),nn.Tanh())def forward(self, x):out = self.prj(x).view(-1, 1024, 6, 6)out = self.layer4(self.layer3(self.layer2(self.layer1(out))))return out

训练

在构建了鉴别器和生成器之后,先初始化鉴别器和生成器。

# 初始化鉴别器和生成器
d = Discriminator(3, 32)
g = Generator(3, 128, 1024, 100)

然后设置损失函数和优化器。

criterion = nn.BCELoss() # 损失函数
lr = 0.0003 # 学习率
d_optimizer = torch.optim.Adam(d.parameters(), lr=lr) # 定义鉴别器的优化器
g_optimizer = torch.optim.Adam(g.parameters(), lr=lr) # 定义生成器的优化器

定义训练函数。

# 训练过程
def train(d, g, criterion, d_optimizer, g_optimizer, epochs=1, show_every=1000, print_every=10):iter_count = 0for epoch in range(epochs):for inputs, _ in train_loader:real_inputs = inputs # 真实样本fake_inputs = g(torch.randn(5, 100)) # 伪造样本real_labels = torch.ones(real_inputs.size(0)) # 真实标签fake_labels = torch.zeros(5) # 伪造标签real_outputs = d(real_inputs)d_loss_real = criterion(real_outputs, real_labels)fake_outputs = d(fake_inputs)d_loss_fake = criterion(fake_outputs, fake_labels)d_loss = d_loss_real + d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()fake_inputs = g(torch.randn(5, 100))outputs = d(fake_inputs)real_labels = torch.ones(outputs.size(0))g_loss = criterion(outputs, real_labels)g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (iter_count % show_every == 0):print('Epoch:{}, Iter:{}, D:{.4}, G:{.4}'.format(epoch,iter_count,d_loss.item(),g_loss.item()))picname = "Epoch_" + str(epoch) + "Iter_" + str(iter_count)img_show(torchvision.utils.make_grid(fake_inputs.data), picname)if (iter_count % print_every == 0):print('Epoch:{}, Iter:{}, D:{.4}, G:{.4}'.format(epoch,iter_count,d_loss.item(),g_loss.item()))iter_count += 1print('Finished Training!')

开始训练。

# 训练
train(d, g, criterion, d_optimizer, g_optimizer, epochs=300)

补充:

代码中生成器使用了LeakyReLU激活函数,鉴别器使用了ReLU激活函数。两者区别如下:

LeakyReLU激活函数是ReLU的变体,其表达式为:

y_{i}&#61;\left\{\begin{matrix} x_{i} , x_{i} \geqslant 0& \\ \frac{x_{i}}{a_{i}} , x_{i}<0 & \end{matrix}\right.def train(d, g, criterion, d_optimizer, g_optimizer, epochs&#61;1, show_every&#61;1000, print_every&#61;10):iter_count &#61; 0for epoch in range(epochs):for inputs, _ in train_loader:real_inputs &#61; inputs # 真实样本fake_inputs &#61; g(torch.randn(5, 100)) # 伪造样本real_labels &#61; torch.ones(real_inputs.size(0)) # 真实标签fake_labels &#61; torch.zeros(5) # 伪造标签real_outputs &#61; d(real_inputs)d_loss_real &#61; criterion(real_outputs, real_labels)fake_outputs &#61; d(fake_inputs)d_loss_fake &#61; criterion(fake_outputs, fake_labels)d_loss &#61; d_loss_real &#43; d_loss_faked_optimizer.zero_grad()d_loss.backward()d_optimizer.step()fake_inputs &#61; g(torch.randn(5, 100))outputs &#61; d(fake_inputs)real_labels &#61; torch.ones(outputs.size(0))g_loss &#61; criterion(outputs, real_labels)g_optimizer.zero_grad()g_loss.backward()g_optimizer.step()if (iter_count % show_every &#61;&#61; 0):print(&#39;Epoch:{}, Iter:{}, D:{.4}, G:{.4}&#39;.format(epoch,iter_count,d_loss.item(),g_loss.item()))picname &#61; "Epoch_" &#43; str(epoch) &#43; "Iter_" &#43; str(iter_count)img_show(torchvision.utils.make_grid(fake_inputs.data), picname)if (iter_count % print_every &#61;&#61; 0):print(&#39;Epoch:{}, Iter:{}, D:{.4}, G:{.4}&#39;.format(epoch,iter_count,d_loss.item(),g_loss.item()))iter_count &#43;&#61; 1print(&#39;Finished Training&#xff01;&#39;)# 主程序
if __name__ &#61;&#61; &#39;__main__&#39;:# 串联多个变换操作data_transform &#61; transforms.Compose([transforms.RandomHorizontalFlip(), # 依概率p水平翻转&#xff0c;默认p&#61;0.5transforms.ToTensor(), # 转为tensor&#xff0c;并归一化至[0-1]# 标准化&#xff0c;把[0-1]变换到[-1,1]&#xff0c;其中mean和std分别通过(0.5,0.5,0.5)和(0.5,0.5,0.5)进行指定。# 原来的[0-1]最小值0变成(0-0.5)/0.5&#61;-1&#xff0c;最大值1变成(1-0.5)/0.5&#61;1transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])# 参数data_transform&#xff1a;对图片进行预处理的操作&#xff08;函数&#xff09;&#xff0c;原始图片作为输入&#xff0c;返回一个转换后的图片。train_set &#61; datasets.ImageFolder(&#39;imgs&#39;, data_transform)train_loader &#61; torch.utils.data.DataLoader(train_set, batch_size&#61;5,shuffle&#61;True, num_workers&#61;4) # 数据加载inputs, _ &#61; next(iter(train_loader))# make_grid的作用是将若干幅图像拼成一幅图像img_show(torchvision.utils.make_grid(inputs), "RealDataSample")# 初始化鉴别器和生成器d &#61; Discriminator(3, 32)g &#61; Generator(3, 128, 1024, 100)criterion &#61; nn.BCELoss() # 损失函数lr &#61; 0.0003 # 学习率d_optimizer &#61; torch.optim.Adam(d.parameters(), lr&#61;lr) # 定义鉴别器的优化器g_optimizer &#61; torch.optim.Adam(g.parameters(), lr&#61;lr) # 定义生成器的优化器# 训练train(d, g, criterion, d_optimizer, g_optimizer, epochs&#61;300)

参考链接及书目

Pytorch深度学习入门--曾芃壹

https://zhuanlan.zhihu.com/p/24767059

https://blog.csdn.net/qq_38410428/article/details/94719553

https://www.baidu.com/link?url&#61;WFl0YU3KyqRVxEK6sEclXW5Rrj7mEWaJ3hJR4VPKbB1RYP8R1My2a41FcxAEiBNW2D1mftNaXPEWM0_jDToXIW2usQVDbT60Jxs3kwWBYk7&wd&#61;&eqid&#61;a730217300106c0100000006610b3d6c


推荐阅读
author-avatar
繁華落盡灬熙
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有