热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

从零开始掌握PyTorch:生成对抗网络GAN进阶指南(第九篇)

本文将深入探讨生成对抗网络(GAN)在计算机视觉领域的应用。作为该领域的经典模型,GAN通过生成器和判别器的对抗训练,能够高效地生成高质量的图像。本文不仅回顾了GAN的基本原理,还将介绍一些最新的进展和技术优化方法,帮助读者全面掌握这一重要工具。

开篇

在计算机视觉方向我们介绍了不少基础网络了,今天介绍的这种又是计算机视觉方向的一个骨灰级网络——GAN。GAN又名生成对抗网络,其主要作用是图像生成,我们在用图像训练模型的时候需要大量的数据集。但是如果我们的数据集不够怎么办呢?我们可以利用数据增强的方法,对图像进行上下左右的翻转,做随即剪切,也可以自己生成图像。这个生成图像就会用到我们的GAN网络。
GAN网络之所叫对抗网络是因为其内部有两个编码器,一个generator和一个discriminator。一个用于编码生成图像,一个用于将图像解码。generator企图生成的图像足够像原始图像,企图以假乱真;而discriminator企图戳穿generator的把戏,将其精准辨别真伪。整个网络就在二者的博弈中生成了图像。discriminator主要是判别generator产生的编码和真实图像的解码是否相似,不断提高二者的相似度,最终生成了可以以假乱真的图像。
其实通过这个描述大家就可以意识到,这应该是一个最小最大问题或者是一个最大最小问题。因为discriminator拼命想区分二者,所以他应该让二者区别足够大;而generator拼命想效仿,所以他应该让二者区别足够小。
这里简单介绍一下GAN,详细介绍可以参考GAN原理学习。我们主要看代码实现。

GAN生成对抗网络

库的引入

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image~

设备的配置以及超参数的定义

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'sample'~

如果你了解GAN的话,你应该可以清楚这个latent size,他其实是我们在generator生成图像网络中的隐藏层特征尺寸。
图片的生成地址
我们最终要把生成的图片放到一个文件夹中,所以我们创建一个目录用于存储生成的图片

if not os.path.exists(sample_dir):os.makedirs(sample_dir)~

图像的处理和转换

transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])~

我们将像素点进行归一化,均值为0.5,方差也为0.5。
这里说明一下,由于我们所用的图像要经过灰度转化变为灰度图,所以这里的channel是1维,如果是彩色图,我们则有三个channels,需要对每一个channel都指定均值和方差
数据的引入和加载

mnist = torchvision.datasets.MNIST(root = '../../data/',train = True,transforms = transforms,download = True)
data_loader = torch.utils.data.DataLoader(dataset = mnist,batch_size = batch_size,shuffle = True)~

GAN网络是在对抗的过程中逐步完善逐步提高以假乱真的水平,因此我们不需要分为测试机和训练集,用一份统一的数据就可以了。
Generator和Discrimator的定义

# Discrimator
D = nn.Sequential(nn.Linear(image_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,1),nn.Sigmoid())# Generator
G = nn.Sequential(nn.Linear(latent_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,image_size),nn.Tanh())D = D.to(device)
G = G.to(device)~

D与G的结构很相似,区别在于他们的激活函数,在Discrimator中我们通常使用leakyrelu,最后一层使用sigmoid来生成概率。而generator中我们激活函数是relu,最后一层使用双曲正切函数,因为它只用于解码生成图像,不需要计算概率。
损失函数和优化器的定义

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_optimizer = torch.optim.Adam(G.parameters(),lr = 0.0002)~

这里说明一下,我们使用的BCELoss是二分交叉熵损失函数,这个损失函数具体的形式大家可以看前文的超链接中提到的公式,这里不做展开了。
辅助函数的定义

def denorm(x):out = (x + 1) / 2# 将out限制在0-1return out.clamp(0,1)
# 重置梯度
def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()

训练模型

total_step = len(data_loader)
for epoch in range(num_epochs):for i,(images,_) in enumerate(data_loader):images = images.reshape(batch_size,-1).to(device)# 为计算损失函数生成标签,真是标签是1,虚假标签是0real_labels = torch.ones(batch_size,1).to(device)fake_labels = torch.zeros(batch_size,1).to(device)# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# 这里的损失函数应该都是0,因为我们用真实图片去测试,损失函数一定是0# 我们的目的是将对的分到real中,错的分到fake中,所以要求2个损失outputs = D(images)d_loss_real = criterion(outputs,real_labels)real_score = outputs# Compute BCELoss using fake images# 这里的损失函数应该是1,因为我们用的是虚假图片,且为随机生成的码z = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs,fake_labels)fake_score = outputs# 反向传播优化d_loss = d_loss_fake + d_loss_real# 清空梯度reset_grad()d_loss.backward()d_optimizer.step()# 训练生成器# 用虚假图片计算损失z = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))g_loss = criterion(outputs,real_labels)# 反向传播优化reset_grad()g_loss.backward()g_optimizer.step()if (i+1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),real_score.mean().item(), fake_score.mean().item()))# 保存真实图像if (epoch + 1) == 1:images = images.reshape(images.size(0),1,28,28)save_image(denorm(images),os.path.join(sample_dir,'real_images.png'))# 保存样本图像fake_images = fake_images.reshape(fake_images.size(0),1,28,28)save_image(denorm(fake_images),os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))

保存模型

torch.save(G.state_dict(),'G.cpkt')
torch.save(D.state_dict(),'D.cpkt')

总结

GAN是一种比较常用的生成图像或者是判断两个图像间差异的网络,应用较多而且还有很多变体,比如DCGAN或者是CGAN,大家如果感兴趣可以精读一下相关论文。好啦GAN就介绍到这里啦,下次我们说VAE变分自编码器。


推荐阅读
  • 本文探讨了图像标签的多种分类场景及其在以图搜图技术中的应用,涵盖了从基础理论到实际项目实施的全面解析。 ... [详细]
  • 为了解决不同服务器间共享图片的需求,我们最初考虑建立一个FTP图片服务器。然而,考虑到项目是一个简单的CMS系统,为了简化流程,团队决定探索七牛云存储的解决方案。本文将详细介绍使用七牛云存储的过程和心得。 ... [详细]
  • 深入理解Java多线程并发处理:基础与实践
    本文探讨了Java中的多线程并发处理机制,从基本概念到实际应用,帮助读者全面理解并掌握多线程编程技巧。通过实例解析和理论阐述,确保初学者也能轻松入门。 ... [详细]
  • 深入浅出TensorFlow数据读写机制
    本文详细介绍TensorFlow中的数据读写操作,包括TFRecord文件的创建与读取,以及数据集(dataset)的相关概念和使用方法。 ... [详细]
  • 基于2-channelnetwork的图片相似度判别一、相关理论本篇博文主要讲解2015年CVPR的一篇关于图像相似度计算的文章:《LearningtoCompar ... [详细]
  • 利用Java与Tesseract-OCR实现数字识别
    本文深入探讨了如何利用Java语言结合Tesseract-OCR技术来实现图像中的数字识别功能,旨在为开发者提供详细的指导和实践案例。 ... [详细]
  • 本文详细介绍了 TensorFlow 的入门实践,特别是使用 MNIST 数据集进行数字识别的项目。文章首先解析了项目文件结构,并解释了各部分的作用,随后逐步讲解了如何通过 TensorFlow 实现基本的神经网络模型。 ... [详细]
  • Python + Pytest 接口自动化测试中 Token 关联登录的实现方法
    本文将深入探讨 Python 和 Pytest 在接口自动化测试中如何实现 Token 关联登录,内容详尽、逻辑清晰,旨在帮助读者掌握这一关键技能。 ... [详细]
  • 回顾与学习是进步的阶梯。再次审视卷积神经网络(CNNs),我对之前不甚明了的概念有了更深的理解。本文旨在分享这些新的见解,并探讨CNNs在图像识别和自然语言处理等领域中的实际应用。 ... [详细]
  • 大数据时代的机器学习:人工特征工程与线性模型的局限
    本文探讨了在大数据背景下,人工特征工程与线性模型的应用及其局限性。随着数据量的激增和技术的进步,传统的特征工程方法面临挑战,文章提出了未来发展的可能方向。 ... [详细]
  • 如何用GPU服务器运行Python
    如何用GPU服务器运行Python-目录前言一、服务器登录1.1下载安装putty1.2putty远程登录 1.3查看GPU、显卡常用命令1.4Linux常用命令二、 ... [详细]
  • 本文详细介绍了如何在Python和PyTorch环境中实现Tensor与NumPy数组之间的转换,以及PIL图像对象与NumPy数组之间的相互转换。内容包括具体的转换函数及其使用示例。 ... [详细]
  • 本文介绍了一个使用Keras框架构建的卷积神经网络(CNN)实例,主要利用了Keras提供的MNIST数据集以及相关的层,如Dense、Dropout、Activation等,构建了一个具有两层卷积和两层全连接层的CNN模型。 ... [详细]
  • 本文介绍了如何通过十折交叉验证方法评估回归模型的性能。我们将使用PyTorch框架,详细展示数据处理、模型定义、训练及评估的完整流程。 ... [详细]
  • 吴裕雄探讨混合神经网络模型在深度学习中的应用:结合RNN与CNN优化网络性能
    本文由吴裕雄撰写,深入探讨了如何利用Python、Keras及TensorFlow构建混合神经网络模型,特别是通过结合递归神经网络(RNN)和卷积神经网络(CNN),实现对网络运行效率的有效提升。 ... [详细]
author-avatar
王碧婷568473
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有