热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

从零开始掌握PyTorch:生成对抗网络GAN进阶指南(第九篇)

本文将深入探讨生成对抗网络(GAN)在计算机视觉领域的应用。作为该领域的经典模型,GAN通过生成器和判别器的对抗训练,能够高效地生成高质量的图像。本文不仅回顾了GAN的基本原理,还将介绍一些最新的进展和技术优化方法,帮助读者全面掌握这一重要工具。

开篇

在计算机视觉方向我们介绍了不少基础网络了,今天介绍的这种又是计算机视觉方向的一个骨灰级网络——GAN。GAN又名生成对抗网络,其主要作用是图像生成,我们在用图像训练模型的时候需要大量的数据集。但是如果我们的数据集不够怎么办呢?我们可以利用数据增强的方法,对图像进行上下左右的翻转,做随即剪切,也可以自己生成图像。这个生成图像就会用到我们的GAN网络。
GAN网络之所叫对抗网络是因为其内部有两个编码器,一个generator和一个discriminator。一个用于编码生成图像,一个用于将图像解码。generator企图生成的图像足够像原始图像,企图以假乱真;而discriminator企图戳穿generator的把戏,将其精准辨别真伪。整个网络就在二者的博弈中生成了图像。discriminator主要是判别generator产生的编码和真实图像的解码是否相似,不断提高二者的相似度,最终生成了可以以假乱真的图像。
其实通过这个描述大家就可以意识到,这应该是一个最小最大问题或者是一个最大最小问题。因为discriminator拼命想区分二者,所以他应该让二者区别足够大;而generator拼命想效仿,所以他应该让二者区别足够小。
这里简单介绍一下GAN,详细介绍可以参考GAN原理学习。我们主要看代码实现。

GAN生成对抗网络

库的引入

import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image~

设备的配置以及超参数的定义

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'sample'~

如果你了解GAN的话,你应该可以清楚这个latent size,他其实是我们在generator生成图像网络中的隐藏层特征尺寸。
图片的生成地址
我们最终要把生成的图片放到一个文件夹中,所以我们创建一个目录用于存储生成的图片

if not os.path.exists(sample_dir):os.makedirs(sample_dir)~

图像的处理和转换

transforms = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5], std=[0.5])])~

我们将像素点进行归一化,均值为0.5,方差也为0.5。
这里说明一下,由于我们所用的图像要经过灰度转化变为灰度图,所以这里的channel是1维,如果是彩色图,我们则有三个channels,需要对每一个channel都指定均值和方差
数据的引入和加载

mnist = torchvision.datasets.MNIST(root = '../../data/',train = True,transforms = transforms,download = True)
data_loader = torch.utils.data.DataLoader(dataset = mnist,batch_size = batch_size,shuffle = True)~

GAN网络是在对抗的过程中逐步完善逐步提高以假乱真的水平,因此我们不需要分为测试机和训练集,用一份统一的数据就可以了。
Generator和Discrimator的定义

# Discrimator
D = nn.Sequential(nn.Linear(image_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size,1),nn.Sigmoid())# Generator
G = nn.Sequential(nn.Linear(latent_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,hidden_size),nn.ReLU(),nn.Linear(hidden_size,image_size),nn.Tanh())D = D.to(device)
G = G.to(device)~

D与G的结构很相似,区别在于他们的激活函数,在Discrimator中我们通常使用leakyrelu,最后一层使用sigmoid来生成概率。而generator中我们激活函数是relu,最后一层使用双曲正切函数,因为它只用于解码生成图像,不需要计算概率。
损失函数和优化器的定义

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr = 0.0002)
g_optimizer = torch.optim.Adam(G.parameters(),lr = 0.0002)~

这里说明一下,我们使用的BCELoss是二分交叉熵损失函数,这个损失函数具体的形式大家可以看前文的超链接中提到的公式,这里不做展开了。
辅助函数的定义

def denorm(x):out = (x + 1) / 2# 将out限制在0-1return out.clamp(0,1)
# 重置梯度
def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()

训练模型

total_step = len(data_loader)
for epoch in range(num_epochs):for i,(images,_) in enumerate(data_loader):images = images.reshape(batch_size,-1).to(device)# 为计算损失函数生成标签,真是标签是1,虚假标签是0real_labels = torch.ones(batch_size,1).to(device)fake_labels = torch.zeros(batch_size,1).to(device)# Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# 这里的损失函数应该都是0,因为我们用真实图片去测试,损失函数一定是0# 我们的目的是将对的分到real中,错的分到fake中,所以要求2个损失outputs = D(images)d_loss_real = criterion(outputs,real_labels)real_score = outputs# Compute BCELoss using fake images# 这里的损失函数应该是1,因为我们用的是虚假图片,且为随机生成的码z = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs,fake_labels)fake_score = outputs# 反向传播优化d_loss = d_loss_fake + d_loss_real# 清空梯度reset_grad()d_loss.backward()d_optimizer.step()# 训练生成器# 用虚假图片计算损失z = torch.randn(batch_size,latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))g_loss = criterion(outputs,real_labels)# 反向传播优化reset_grad()g_loss.backward()g_optimizer.step()if (i+1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),real_score.mean().item(), fake_score.mean().item()))# 保存真实图像if (epoch + 1) == 1:images = images.reshape(images.size(0),1,28,28)save_image(denorm(images),os.path.join(sample_dir,'real_images.png'))# 保存样本图像fake_images = fake_images.reshape(fake_images.size(0),1,28,28)save_image(denorm(fake_images),os.path.join(sample_dir,'fake_images-{}.png'.format(epoch+1)))

保存模型

torch.save(G.state_dict(),'G.cpkt')
torch.save(D.state_dict(),'D.cpkt')

总结

GAN是一种比较常用的生成图像或者是判断两个图像间差异的网络,应用较多而且还有很多变体,比如DCGAN或者是CGAN,大家如果感兴趣可以精读一下相关论文。好啦GAN就介绍到这里啦,下次我们说VAE变分自编码器。


推荐阅读
  • 点互信息在自然语言处理中的应用与优化
    点互信息(Pointwise Mutual Information, PMI)是一种用于评估两个事件之间关联强度的统计量,在自然语言处理领域具有广泛应用。本文探讨了 PMI 在词共现分析、语义关系提取和情感分析等任务中的具体应用,并提出了几种优化方法,以提高其在大规模数据集上的计算效率和准确性。通过实验验证,这些优化策略显著提升了模型的性能。 ... [详细]
  • 深入解析 ELF 文件格式与静态链接技术
    本文详细探讨了ELF文件格式及其在静态链接过程中的应用。在C/C++代码转化为可执行文件的过程中,需经过预处理、编译、汇编和链接等关键步骤。最终生成的可执行文件不仅包含系统可识别的机器码,还遵循了严格的文件结构规范,以确保其在操作系统中的正确加载和执行。 ... [详细]
  • 【并发编程】全面解析 Java 内存模型,一篇文章带你彻底掌握
    本文深入解析了 Java 内存模型(JMM),从基础概念到高级特性进行全面讲解,帮助读者彻底掌握 JMM 的核心原理和应用技巧。通过详细分析内存可见性、原子性和有序性等问题,结合实际代码示例,使开发者能够更好地理解和优化多线程并发程序。 ... [详细]
  • 深入解析 UIImageView 与 UIImage 的关键细节与应用技巧
    本文深入探讨了 UIImageView 和 UIImage 的核心特性及应用技巧。首先,详细介绍了如何在 UIImageView 中实现动画效果,包括创建和配置 UIImageView 实例的具体步骤。此外,还探讨了 UIImage 的加载方式及其对性能的影响,提供了优化图像显示和内存管理的有效方法。通过实例代码和实际应用场景,帮助开发者更好地理解和掌握这两个重要类的使用技巧。 ... [详细]
  • 本文深入探讨了 Python Watchdog 库的使用方法和应用场景。通过详细的代码示例,展示了如何利用 Watchdog 监控文件系统的变化,包括文件的创建、修改和删除等操作。文章不仅介绍了 Watchdog 的基本功能,还探讨了其在实际项目中的高级应用,如日志监控和自动化任务触发。读者将能够全面了解 Watchdog 的工作原理及其在不同场景下的应用技巧。 ... [详细]
  • MySQL:不仅仅是数据库那么简单
    MySQL不仅是一款高效、可靠的数据库管理系统,它还具备丰富的功能和扩展性,支持多种存储引擎,适用于各种应用场景。从简单的网站开发到复杂的企业级应用,MySQL都能提供强大的数据管理和优化能力,满足不同用户的需求。其开源特性也促进了社区的活跃发展,为技术进步提供了持续动力。 ... [详细]
  • 在探讨C语言编程文本编辑器的最佳选择与专业推荐时,本文将引导读者构建一个基础的文本编辑器程序。该程序不仅能够打开并显示文本文件的内容及其路径,还集成了菜单和工具栏功能,为用户提供更加便捷的操作体验。通过本案例的学习,读者可以深入了解文本编辑器的核心实现机制。 ... [详细]
  • 2021年7月22日上午9点至中午12点,我专注于Java的学习,重点补充了之前在视频中遗漏的多线程知识。首先,我了解了进程的概念,即程序在内存中运行时形成的一个独立执行单元。其次,学习了线程作为进程的组成部分,是进程中可并发执行的最小单位,负责处理具体的任务。此外,我还深入研究了Runnable接口的使用方法及其在多线程编程中的重要作用。 ... [详细]
  • 在Python编程中,探讨了并发与并行的概念及其区别。并发指的是系统同时处理多个任务的能力,而并行则指在同一时间点上并行执行多个任务。文章详细解析了阻塞与非阻塞操作、同步与异步编程模型,以及IO多路复用技术的应用。通过模拟socket发送HTTP请求的过程,展示了如何创建连接、发送数据和接收响应,并强调了默认情况下socket的阻塞特性。此外,还介绍了如何利用这些技术优化网络通信性能和提高程序效率。 ... [详细]
  • 本书详细介绍了在最新Linux 4.0内核环境下进行Java与Linux设备驱动开发的全面指南。内容涵盖设备驱动的基本概念、开发环境的搭建、操作系统对设备驱动的影响以及具体开发步骤和技巧。通过丰富的实例和深入的技术解析,帮助读者掌握设备驱动开发的核心技术和最佳实践。 ... [详细]
  • Java服务问题快速定位与解决策略全面指南 ... [详细]
  • 从零起步:使用IntelliJ IDEA搭建Spring Boot应用的详细指南
    从零起步:使用IntelliJ IDEA搭建Spring Boot应用的详细指南 ... [详细]
  • 在Spring Boot项目中,通过YAML配置文件为静态变量设置值的方法与实践涉及以下几个步骤:首先,创建一个新的配置类。需要注意的是,自动生成的setter方法默认是非静态的,因此需要手动将其修改为静态方法,以确保静态变量能够正确初始化。此外,建议使用`@Value`注解或`@ConfigurationProperties`注解来注入配置属性,以提高代码的可读性和维护性。 ... [详细]
  • 本文深入探讨了 MXOTDLL.dll 在 C# 环境中的应用与优化策略。针对近期公司从某生物技术供应商采购的指纹识别设备,该设备提供的 DLL 文件是用 C 语言编写的。为了更好地集成到现有的 C# 系统中,我们对原生的 C 语言 DLL 进行了封装,并利用 C# 的互操作性功能实现了高效调用。此外,文章还详细分析了在实际应用中可能遇到的性能瓶颈,并提出了一系列优化措施,以确保系统的稳定性和高效运行。 ... [详细]
  • 在Hive中合理配置Map和Reduce任务的数量对于优化不同场景下的性能至关重要。本文探讨了如何控制Hive任务中的Map数量,分析了当输入数据超过128MB时是否会自动拆分,以及Map数量是否越多越好的问题。通过实际案例和实验数据,本文提供了具体的配置建议,帮助用户在不同场景下实现最佳性能。 ... [详细]
author-avatar
王碧婷568473
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有