热门标签 | HotTags
当前位置:  开发笔记 > 开发工具 > 正文

pytorch:实现简单的GAN示例(MNIST数据集)

今天小编就为大家分享一篇pytorch:实现简单的GAN示例(MNIST数据集),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧

我就废话不多说了,直接上代码吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 设置画图的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定义画图工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定义一个取样的函数
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可视化图片效果
show_images(imgs)
 
#判别网络
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成网络
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判别器的 loss 就是将真实数据的得分判断为 1,假的数据的得分判断为 0,而生成器的 loss 就是将假的数据判断为 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵损失函数
 
def discriminator_loss(logits_real, logits_fake): # 判别器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 来进行训练,学习率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判别网络
      real_data = Variable(x).view(bs, -1) # 真实数据
      logits_real = D_net(real_data) # 判别网络得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均匀分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
      logits_fake = D_net(fake_images) # 判别网络得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判别器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 优化判别网络
      
      # 生成网络
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的数据
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成网络的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 优化生成网络
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)      

以上这篇pytorch:实现简单的GAN示例(MNIST数据集)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。


推荐阅读
  • 尤洋:夸父AI系统——大规模并行训练的深度学习解决方案
    自从AlexNet等模型在计算机视觉领域取得突破以来,深度学习技术迅速发展。近年来,随着BERT等大型模型的广泛应用,AI模型的规模持续扩大,对硬件提出了更高的要求。本文介绍了新加坡国立大学尤洋教授团队开发的夸父AI系统,旨在解决大规模模型训练中的并行计算挑战。 ... [详细]
  • 目录预备知识导包构建数据集神经网络结构训练测试精度可视化计算模型精度损失可视化输出网络结构信息训练神经网络定义参数载入数据载入神经网络结构、损失及优化训练及测试损失、精度可视化qu ... [详细]
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • 不用蘑菇,不拾金币,我通过强化学习成功通关29关马里奥,创造全新纪录
    《超级马里奥兄弟》由任天堂于1985年首次发布,是一款经典的横版过关游戏,至今已在多个平台上售出超过5亿套。该游戏不仅勾起了许多玩家的童年回忆,也成为强化学习领域的热门研究对象。近日,通过先进的强化学习技术,研究人员成功让AI通关了29关,创造了新的纪录。这一成就不仅展示了强化学习在游戏领域的潜力,也为未来的人工智能应用提供了宝贵的经验。 ... [详细]
  • 本文详细介绍了如何利用 Python 语言从文本文件中读取数据,并将其存储为字典格式,涵盖多种实用技巧和示例代码。 ... [详细]
  • 深入解析轻量级数据库 SQL Server Express LocalDB
    本文详细介绍了 SQL Server Express LocalDB,这是一种轻量级的本地 T-SQL 数据库解决方案,特别适合开发环境使用。文章还探讨了 LocalDB 与其他轻量级数据库的对比,并提供了安装和连接 LocalDB 的步骤。 ... [详细]
  • 本文详细介绍了如何在本地环境中安装配置Frida及其服务器组件,以及如何通过Frida进行基本的应用程序动态分析,包括获取应用版本和加载的类信息。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 在 PyTorch 的 `CrossEntropyLoss` 函数中,当目标标签 `target` 为类别 ID 时,实际上会进行 one-hot 编码处理。例如,假设总共有三个类别,其中一个类别的 ID 为 2,则该标签会被转换为 `[0, 0, 1]`。这一过程简化了多分类任务中的损失计算,使得模型能够更高效地进行训练和评估。此外,`CrossEntropyLoss` 还结合了 softmax 激活函数和负对数似然损失,进一步提高了模型的性能和稳定性。 ... [详细]
  • 本文探讨了BERT模型在自然语言处理领域的应用与实践。详细介绍了Transformers库(曾用名pytorch-transformers和pytorch-pretrained-bert)的使用方法,涵盖了从模型加载到微调的各个环节。此外,还分析了BERT在文本分类、情感分析和命名实体识别等任务中的性能表现,并讨论了其在实际项目中的优势和局限性。 ... [详细]
  • 在Windows环境下离线安装PyTorch GPU版时,首先需确认系统配置,例如本文作者使用的是Win8、CUDA 8.0和Python 3.6.5。用户应根据自身Python和CUDA版本,在PyTorch官网查找并下载相应的.whl文件。此外,建议检查系统环境变量设置,确保CUDA路径正确配置,以避免安装过程中可能出现的兼容性问题。 ... [详细]
  • 本文将深入探讨生成对抗网络(GAN)在计算机视觉领域的应用。作为该领域的经典模型,GAN通过生成器和判别器的对抗训练,能够高效地生成高质量的图像。本文不仅回顾了GAN的基本原理,还将介绍一些最新的进展和技术优化方法,帮助读者全面掌握这一重要工具。 ... [详细]
  • PyTorch 使用问题:解决导入 torch 后 torch.cuda.is_available() 返回 False 的方法
    在配置 PyTorch 时,遇到 `torch.cuda.is_available()` 返回 `False` 的问题。本文总结了多种解决方案,并分享了个人在 PyCharm、Python 和 Anaconda3 环境下成功配置 CUDA 的经验,以帮助读者避免常见错误并顺利使用 GPU 加速。 ... [详细]
  • 本文深入解析了PyTorch框架中的`Parameter()`类和`register_parameter()`方法。首先,通过官方文档介绍了`Parameter()`类的基本功能及其在模型参数管理中的作用。接着,详细探讨了`register_parameter()`方法如何将自定义参数添加到模型中,并确保这些参数能够被优化器识别和更新。最后,对比分析了两者的主要差异,帮助读者理解在不同场景下选择合适的方法来管理和优化模型参数。 ... [详细]
  • 2019年斯坦福大学CS224n课程笔记:深度学习在自然语言处理中的应用——Word2Vec与GloVe模型解析
    本文详细解析了2019年斯坦福大学CS224n课程中关于深度学习在自然语言处理(NLP)领域的应用,重点探讨了Word2Vec和GloVe两种词嵌入模型的原理与实现方法。通过具体案例分析,深入阐述了这两种模型在提升NLP任务性能方面的优势与应用场景。 ... [详细]
author-avatar
书友59082326
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有