热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

从零开始掌握PyTorch:生成对抗网络GAN进阶指南(第九篇)

本文将深入探讨生成对抗网络(GAN)在计算机视觉领域的应用。作为该领域的经典模型,GAN通过生成器和判别器的对抗训练,能够高效地生成高质量的图像。本文不仅回顾了GAN的基本原理,还将介绍一些最新的进展和技术优化方法,帮助读者全面掌握这一重要工具。

开篇

在计算机视觉方向我们介绍了不少基础网络了,今天介绍的这种又是计算机视觉方向的一个骨灰级网络——GAN。GAN又名生成对抗网络,其主要作用是图像生成,我们在用图像训练模型的时候需要大量的数据集。但是如果我们的数据集不够怎么办呢?我们可以利用数据增强的方法,对图像进行上下左右的翻转,做随即剪切,也可以自己生成图像。这个生成图像就会用到我们的GAN网络。
GAN网络之所叫对抗网络是因为其内部有两个编码器,一个generator和一个discriminator。一个用于编码生成图像,一个用于将图像解码。generator企图生成的图像足够像原始图像,企图以假乱真;而discriminator企图戳穿generator的把戏,将其精准辨别真伪。整个网络就在二者的博弈中生成了图像。discriminator主要是判别generator产生的编码和真实图像的解码是否相似,不断提高二者的相似度,最终生成了可以以假乱真的图像。
其实通过这个描述大家就可以意识到,这应该是一个最小最大问题或者是一个最大最小问题。因为discriminator拼命想区分二者,所以他应该让二者区别足够大;而generator拼命想效仿,所以他应该让二者区别足够小。
这里简单介绍一下GAN,详细介绍可以参考GAN原理学习。我们主要看代码实现。

GAN生成对抗网络

库的引入

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image~

设备的配置以及超参数的定义

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'sample'~

如果你了解GAN的话,你应该可以清楚这个latent size,他其实是我们在generator生成图像网络中的隐藏层特征尺寸。
图片的生成地址
我们最终要把生成的图片放到一个文件夹中,所以我们创建一个目录用于存储生成的图片

if not os.path.exists(sample_dir):os.makedirs(sample_dir)~

图像的处理和转换

transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])~

我们将像素点进行归一化,均值为0.5,方差也为0.5。
这里说明一下,由于我们所用的图像要经过灰度转化变为灰度图,所以这里的channel是1维,如果是彩色图,我们则有三个channels,需要对每一个channel都指定均值和方差
数据的引入和加载

mnist = torchvision.datasets.MNIST(root = '../../data/',train = True,transforms = transforms,download = True)
data_loader = torch.utils.data.DataLoader(dataset = mnist,batch_size = batch_size,shuffle = True)~

GAN网络是在对抗的过程中逐步完善逐步提高以假乱真的水平,因此我们不需要分为测试机和训练集,用一份统一的数据就可以了。
Generator和Discrimator的定义

# Discrimator
D = nn.Sequential(nn.Linear(image_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,1),nn.Sigmoid())# Generator
G = nn.Sequential(nn.Linear(latent_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,image_size),nn.Tanh())D = D.to(device)
G = G.to(device)~

D与G的结构很相似,区别在于他们的激活函数,在Discrimator中我们通常使用leakyrelu,最后一层使用sigmoid来生成概率。而generator中我们激活函数是relu,最后一层使用双曲正切函数,因为它只用于解码生成图像,不需要计算概率。
损失函数和优化器的定义

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_optimizer = torch.optim.Adam(G.parameters(),lr = 0.0002)~

这里说明一下,我们使用的BCELoss是二分交叉熵损失函数,这个损失函数具体的形式大家可以看前文的超链接中提到的公式,这里不做展开了。
辅助函数的定义

def denorm(x):out = (x + 1) / 2# 将out限制在0-1return out.clamp(0,1)
# 重置梯度
def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()

训练模型

total_step = len(data_loader)
for epoch in range(num_epochs):for i,(images,_) in enumerate(data_loader):images = images.reshape(batch_size,-1).to(device)# 为计算损失函数生成标签,真是标签是1,虚假标签是0real_labels = torch.ones(batch_size,1).to(device)fake_labels = torch.zeros(batch_size,1).to(device)# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# 这里的损失函数应该都是0,因为我们用真实图片去测试,损失函数一定是0# 我们的目的是将对的分到real中,错的分到fake中,所以要求2个损失outputs = D(images)d_loss_real = criterion(outputs,real_labels)real_score = outputs# Compute BCELoss using fake images# 这里的损失函数应该是1,因为我们用的是虚假图片,且为随机生成的码z = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs,fake_labels)fake_score = outputs# 反向传播优化d_loss = d_loss_fake + d_loss_real# 清空梯度reset_grad()d_loss.backward()d_optimizer.step()# 训练生成器# 用虚假图片计算损失z = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))g_loss = criterion(outputs,real_labels)# 反向传播优化reset_grad()g_loss.backward()g_optimizer.step()if (i+1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),real_score.mean().item(), fake_score.mean().item()))# 保存真实图像if (epoch + 1) == 1:images = images.reshape(images.size(0),1,28,28)save_image(denorm(images),os.path.join(sample_dir,'real_images.png'))# 保存样本图像fake_images = fake_images.reshape(fake_images.size(0),1,28,28)save_image(denorm(fake_images),os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))

保存模型

torch.save(G.state_dict(),'G.cpkt')
torch.save(D.state_dict(),'D.cpkt')

总结

GAN是一种比较常用的生成图像或者是判断两个图像间差异的网络,应用较多而且还有很多变体,比如DCGAN或者是CGAN,大家如果感兴趣可以精读一下相关论文。好啦GAN就介绍到这里啦,下次我们说VAE变分自编码器。


推荐阅读
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • 毕业设计:基于机器学习与深度学习的垃圾邮件(短信)分类算法实现
    本文详细介绍了如何使用机器学习和深度学习技术对垃圾邮件和短信进行分类。内容涵盖从数据集介绍、预处理、特征提取到模型训练与评估的完整流程,并提供了具体的代码示例和实验结果。 ... [详细]
  • 本文探讨了图像标签的多种分类场景及其在以图搜图技术中的应用,涵盖了从基础理论到实际项目实施的全面解析。 ... [详细]
  • MySQL索引详解与优化
    本文深入探讨了MySQL中的索引机制,包括索引的基本概念、优势与劣势、分类及其实现原理,并详细介绍了索引的使用场景和优化技巧。通过具体示例,帮助读者更好地理解和应用索引以提升数据库性能。 ... [详细]
  • 深入浅出TensorFlow数据读写机制
    本文详细介绍TensorFlow中的数据读写操作,包括TFRecord文件的创建与读取,以及数据集(dataset)的相关概念和使用方法。 ... [详细]
  • 基于2-channelnetwork的图片相似度判别一、相关理论本篇博文主要讲解2015年CVPR的一篇关于图像相似度计算的文章:《LearningtoCompar ... [详细]
  • 如何用GPU服务器运行Python
    如何用GPU服务器运行Python-目录前言一、服务器登录1.1下载安装putty1.2putty远程登录 1.3查看GPU、显卡常用命令1.4Linux常用命令二、 ... [详细]
  • 图神经网络模型综述
    本文综述了图神经网络(Graph Neural Networks, GNN)的发展,从传统的数据存储模型转向图和动态模型,探讨了模型中的显性和隐性结构,并详细介绍了GNN的关键组件及其应用。 ... [详细]
  • 本文详细介绍如何使用Python进行配置文件的读写操作,涵盖常见的配置文件格式(如INI、JSON、TOML和YAML),并提供具体的代码示例。 ... [详细]
  • 深入理解Tornado模板系统
    本文详细介绍了Tornado框架中模板系统的使用方法。Tornado自带的轻量级、高效且灵活的模板语言位于tornado.template模块,支持嵌入Python代码片段,帮助开发者快速构建动态网页。 ... [详细]
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 本文介绍了如何利用TensorFlow框架构建一个简单的非线性回归模型。通过生成200个随机数据点进行训练,模型能够学习并预测这些数据点的非线性关系。 ... [详细]
  • 利用Java与Tesseract-OCR实现数字识别
    本文深入探讨了如何利用Java语言结合Tesseract-OCR技术来实现图像中的数字识别功能,旨在为开发者提供详细的指导和实践案例。 ... [详细]
  • 在Ubuntu 16.04中使用Anaconda安装TensorFlow
    本文详细介绍了如何在Ubuntu 16.04系统上通过Anaconda环境管理工具安装TensorFlow。首先,需要下载并安装Anaconda,然后配置环境变量以确保系统能够识别Anaconda命令。接着,创建一个特定的Python环境用于安装TensorFlow,并通过指定的镜像源加速安装过程。最后,通过一个简单的线性回归示例验证TensorFlow的安装是否成功。 ... [详细]
  • 本文详细记录了作者从7月份的提前批到9、10月份正式批的秋招经历,包括各公司的面试流程、技术问题及HR面的常见问题。通过这次秋招,作者深刻体会到了技术积累和面试准备的重要性。 ... [详细]
author-avatar
王碧婷568473
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有