热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

对抗神经网络(一)——GAN

   对抗神经网络其实是两个网络的组合,可以理解为一个网络生成模拟数据,另一个网络判断生成的数据是真实的还是模拟的。生成模拟数据的网络要不断优化自己让判别的网络判断不出来,判别的网

      对抗神经网络其实是两个网络的组合,可以理解为一个网络生成模拟数据,另一个网络判断生成的数据是真实的还是模拟的。生成模拟数据的网络要不断优化自己让判别的网络判断不出来,判别的网络也要不断优化自己让判断的更加精确。两者的关系形成对抗,因此叫对抗神经网络。

GAN由generator(生成模型)和discriminator(判别式模型)两部分构成。

generator:主要是从训练数据中产生相同分布的samples,对于输入x,类别标签y,在生成模型中估计其联合概率分布。

discriminator:判断输入的是真实数据还是generator生成的数据,即估计样本属于某类的条件概率分布。它采用传统的监督学习的方法。

      二者结合之后,经过大量次数的迭代训练会使generator尽可能模拟出以假乱真的样本,而discrimator会有更精确的鉴别真伪数据的能力,最终整个GAN会达到所谓的纳什均衡,即discriminator对于generator的数据鉴别结果为正确率和错误率各占50%。

GAN的实现,本例以mnist数据为例,直接代码

      进行训练

# 定义load_data()函数以读取数据
def load_data(data_path):
    '''
    函数功能:导出MNIST数据
    输入: data_path   传入数据所在路径(解压后的数据)
    输出: train_data  输出data,形状为(60000, 28, 28, 1)
         train_label  输出label,形状为(60000, 1)
    '''

    f_data = open(os.path.join(data_path, 'train-images.idx3-ubyte'))
    loaded_data = np.fromfile(file=f_data, dtype=np.uint8)
    # 前16个字符为说明符,需要跳过
    train_data = loaded_data[16:].reshape((-1, 784)).astype(np.float)

    f_label = open(os.path.join(data_path, 'train-labels.idx1-ubyte'))
    loaded_label = np.fromfile(file=f_label, dtype=np.uint8)
    # 前8个字符为说明符,需要跳过
    train_label = loaded_label[8:].reshape((-1)).astype(np.float)

    return train_data, train_label


# 导入需要的包
import os  # 读取路径下文件
import shutil  # 递归删除文件
import tensorflow as tf  # 编写神经网络
import numpy as np  # 矩阵运算操作
from skimage.io import imsave  # 保存影像
from tensorflow.examples.tutorials.mnist import input_data  # 第一次下载数据时用

# 图像的size为(28, 28, 1)
image_height = 28
image_width = 28
image_size = image_height * image_width

# 是否训练和存储设置
train = True
restore = False  # 是否存储训练结果
output_path = "./output/"  # 存储文件的路径

# 实验所需的超参数
max_epoch = 500
batch_size = 256
h1_size = 256  # 第一隐藏层的size,即特征数
h2_size = 512  # 第二隐藏层的size,即特征数
z_size = 128  # 生成器的传入参数

# 导入tensorflow
import tensorflow as tf


# 定义GAN的生成器
def generator(z_prior):
    '''
    函数功能:生成影像,参与训练过程
    输入:z_prior,       #输入tf格式,size为(batch_size, z_size)的数据
    输出:x_generate,    #生成图像
         g_params,      #生成图像的所有参数
    '''
    # 第一个链接层
    # 以2倍标准差stddev的截断的正态分布中生成大小为[z_size, h1_size]的随机值,权值weight初始化。
    w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], stddev=0.1), name="g_w1", dtype=tf.float32)
    # 生成大小为[h1_size]的0值矩阵,偏置bias初始化
    b1 = tf.Variable(tf.zeros([h1_size]), name="g_b1", dtype=tf.float32)
    # 通过矩阵运算,将输入z_prior传入隐含层h1。**函数为relu
    h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)

    # 第二个链接层
    # 以2倍标准差stddev的截断的正态分布中生成大小为[h1_size, h2_size]的随机值,权值weight初始化。
    w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], stddev=0.1), name="g_w2", dtype=tf.float32)
    # 生成大小为[h2_size]的0值矩阵,偏置bias初始化
    b2 = tf.Variable(tf.zeros([h2_size]), name="g_b2", dtype=tf.float32)
    # 通过矩阵运算,将h1传入隐含层h2。**函数为relu
    h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)

    # 第三个链接层
    # 以2倍标准差stddev的截断的正态分布中生成大小为[h2_size, image_size]的随机值,权值weight初始化。
    w3 = tf.Variable(tf.truncated_normal([h2_size, image_size], stddev=0.1), name="g_w3", dtype=tf.float32)
    # 生成大小为[image_size]的0值矩阵,偏置bias初始化
    b3 = tf.Variable(tf.zeros([image_size]), name="g_b3", dtype=tf.float32)
    # 通过矩阵运算,将h2传入隐含层h3。
    h3 = tf.matmul(h2, w3) + b3
    # 利用tanh**函数,将h3传入输出层
    x_generate = tf.nn.tanh(h3)

    # 将所有参数合并到一起
    g_params = [w1, b1, w2, b2, w3, b3]

    return x_generate, g_params


# 定义GAN的判别器
def discriminator(x_data, x_generated, keep_prob):
    '''
    函数功能:对输入数据进行判断,并保存其参数
    输入:x_data,        #输入的真实数据
        x_generated,     #生成器生成的虚假数据
        keep_prob,      #dropout率,防止过拟合
    输出:y_data,        #判别器对batch个数据的处理结果
        y_generated,     #判别器对余下数据的处理结果
        d_params,       #判别器的参数
    '''

    # 合并输入数据,包括真实数据x_data和生成器生成的假数据x_generated
    x_in = tf.concat([x_data, x_generated], 0)

    # 第一个链接层
    # 以2倍标准差stddev的截断的正态分布中生成大小为[image_size, h2_size]的随机值,权值weight初始化。
    w1 = tf.Variable(tf.truncated_normal([image_size, h2_size], stddev=0.1), name="d_w1", dtype=tf.float32)
    # 生成大小为[h2_size]的0值矩阵,偏置bias初始化
    b1 = tf.Variable(tf.zeros([h2_size]), name="d_b1", dtype=tf.float32)
    # 通过矩阵运算,将输入x_in传入隐含层h1.同时以一定的dropout率舍弃节点,防止过拟合
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1), keep_prob)

    # 第二个链接层
    # 以2倍标准差stddev的截断的正态分布中生成大小为[h2_size, h1_size]的随机值,权值weight初始化。
    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], stddev=0.1), name="d_w2", dtype=tf.float32)
    # 生成大小为[h1_size]的0值矩阵,偏置bias初始化
    b2 = tf.Variable(tf.zeros([h1_size]), name="d_b2", dtype=tf.float32)
    # 通过矩阵运算,将h1传入隐含层h2.同时以一定的dropout率舍弃节点,防止过拟合
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2), keep_prob)

    # 第三个链接层
    # 以2倍标准差stddev的截断的正态分布中生成大小为[h1_size, 1]的随机值,权值weight初始化。
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev=0.1), name="d_w3", dtype=tf.float32)
    # 生成0值,偏置bias初始化
    b3 = tf.Variable(tf.zeros([1]), name="d_b3", dtype=tf.float32)
    # 通过矩阵运算,将h2传入隐含层h3
    h3 = tf.matmul(h2, w3) + b3

    # 从h3中切出batch_size张图像
    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1], name=None))
    # 从h3中切除余下的图像
    y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1], name=None))

    # 判别器的所有参数
    d_params = [w1, b1, w2, b2, w3, b3]

    return y_data, y_generated, d_params


# 显示结果的函数
def show_result(batch_res, fname, grid_size=(8, 8), grid_pad=5):
    '''
    函数功能:输入相关参数,将运行结果以图片的形式保存到当前路径下
    输入:batch_res,       #输入数据
        fname,             #输入路径
        grid_size=(8, 8),  #默认输出图像为8*8张
        grid_pad=5,       #默认图像的边缘留白为5像素
    输出:无
    '''

    # 将batch_res进行值[0, 1]归一化,同时将其reshape成(batch_size, image_height, image_width)
    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], image_height, image_width)) + 0.5
    # 重构显示图像格网的参数
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
    grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res):
        if i >= grid_size[0] * grid_size[1]:
            break
        img = (res) * 255.
        img = img.astype(np.uint8)
        row = (i // grid_size[0]) * (img_h + grid_pad)
        col = (i % grid_size[1]) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    # 保存图像
    imsave(fname, img_grid)


# 定义训练过程
def train():
    '''
    函数功能:训练整个GAN网络,并随机生成手写数字
    输入:无
    输出:sess.saver()
    '''

    # 加载数据
    train_data, train_label = load_data("MNIST_data")
    size = train_data.shape[0]

    # 构建模型---------------------------------------------------------------------
    # 定义GAN网络的输入,其中x_data为[batch_size, image_size], z_prior为[batch_size, z_size]
    x_data = tf.placeholder(tf.float32, [batch_size, image_size], name="x_data")  # (batch_size, image_size)
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name="z_prior")  # (batch_size, z_size)
    # 定义dropout率
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # 利用生成器生成数据x_generated和参数g_params
    x_generated, g_params = generator(z_prior)
    # 利用判别器判别生成器的结果
    y_data, y_generated, d_params = discriminator(x_data, x_generated, keep_prob)

    # 定义判别器和生成器的loss函数
    d_loss = - (tf.log(y_data) + tf.log(1 - y_generated))
    g_loss = - tf.log(y_generated)

    # 设置学习率为0.0001,用AdamOptimizer进行优化
    optimizer = tf.train.AdamOptimizer(0.0001)

    # 判别器discriminator 和生成器 generator 对损失函数进行最小化处理
    d_trainer = optimizer.minimize(d_loss, var_list=d_params)
    g_trainer = optimizer.minimize(g_loss, var_list=g_params)
    # 模型构建完毕--------------------------------------------------------------------

    # 全局变量初始化
    init = tf.global_variables_initializer()

    # 启动会话sess
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(init)

    # 判断是否需要存储
    if restore:
        # 若是,将最近一次的checkpoint点存到outpath下
        chkpt_fname = tf.train.latest_checkpoint(output_path)
        saver.restore(sess, chkpt_fname)
    else:
        # 若否,判断目录是存在,如果目录存在,则递归的删除目录下的所有内容,并重新建立目录
        if os.path.exists(output_path):
            shutil.rmtree(output_path)
        os.mkdir(output_path)

    # 利用随机正态分布产生噪声影像,尺寸为(batch_size, z_size)
    z_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)

    # 逐个epoch内训练
    for i in range(sess.run(global_step), max_epoch):
        # 图像每个epoch内可以放(size // batch_size)个size
        for j in range(size // batch_size):
            if j % 20 == 0:
                print("epoch:%s, iter:%s" % (i, j))

            # 训练一个batch的数据
            batch_end = j * batch_size + batch_size
            if batch_end >= size:
                batch_end = size - 1
            x_value = train_data[j * batch_size: batch_end]
            # 将数据归一化到[-1, 1]
            x_value = x_value / 255.
            x_value = 2 * x_value - 1

            # 以正太分布的形式产生随机噪声
            z_value = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
            # 每个batch下,输入数据运行GAN,训练判别器
            sess.run(d_trainer,
                     feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
            # 每个batch下,输入数据运行GAN,训练生成器
            if j % 1 == 0:
                sess.run(g_trainer,
                         feed_dict={x_data: x_value, z_prior: z_value, keep_prob: np.sum(0.7).astype(np.float32)})
        # 每一个epoch中的所有batch训练完后,利用z_sample测试训练后的生成器
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_sample_val})
        # 每一个epoch中的所有batch训练完后,显示生成器的结果,并打印生成结果的值
        show_result(x_gen_val, os.path.join(output_path, "sample%s.jpg" % i))
        print(x_gen_val)
        # 每一个epoch中,生成随机分布以重置z_random_sample_val
        z_random_sample_val = np.random.normal(0, 1, size=(batch_size, z_size)).astype(np.float32)
        # 每一个epoch中,利用z_random_sample_val生成手写数字图像,并显示结果
        x_gen_val = sess.run(x_generated, feed_dict={z_prior: z_random_sample_val})
        show_result(x_gen_val, os.path.join(output_path, "random_sample%s.jpg" % i))
        # 保存会话
        sess.run(tf.assign(global_step, i + 1))
        saver.save(sess, os.path.join(output_path, "model"), global_step=global_step)

if __name__ == '__main__':
    if train:
        train()

      训练完成后,如下

对抗神经网络(一)——GAN

 

      训练epoch为300次的实验结果:

对抗神经网络(一)——GAN

      生成的和原图像基本一样。

 

 

参考:https://blog.csdn.net/z704630835/article/details/82017892


推荐阅读
  • 通过使用CIFAR-10数据集,本文详细介绍了如何快速掌握Mixup数据增强技术,并展示了该方法在图像分类任务中的显著效果。实验结果表明,Mixup能够有效提高模型的泛化能力和分类精度,为图像识别领域的研究提供了有价值的参考。 ... [详细]
  • Python 实战:异步爬虫(协程技术)与分布式爬虫(多进程应用)深入解析
    本文将深入探讨 Python 异步爬虫和分布式爬虫的技术细节,重点介绍协程技术和多进程应用在爬虫开发中的实际应用。通过对比多进程和协程的工作原理,帮助读者理解两者在性能和资源利用上的差异,从而在实际项目中做出更合适的选择。文章还将结合具体案例,展示如何高效地实现异步和分布式爬虫,以提升数据抓取的效率和稳定性。 ... [详细]
  • 探索聚类分析中的K-Means与DBSCAN算法及其应用
    聚类分析是一种用于解决样本或特征分类问题的统计分析方法,也是数据挖掘领域的重要算法之一。本文主要探讨了K-Means和DBSCAN两种聚类算法的原理及其应用场景。K-Means算法通过迭代优化簇中心来实现数据点的划分,适用于球形分布的数据集;而DBSCAN算法则基于密度进行聚类,能够有效识别任意形状的簇,并且对噪声数据具有较好的鲁棒性。通过对这两种算法的对比分析,本文旨在为实际应用中选择合适的聚类方法提供参考。 ... [详细]
  • 本文探讨了基于点集估算图像区域的Alpha形状算法在Python中的应用。通过改进传统的Delaunay三角剖分方法,该算法能够生成更加灵活和精确的形状轮廓,避免了单纯使用Delaunay三角剖分时可能出现的过大三角形问题。这种“模糊Delaunay三角剖分”技术不仅提高了形状的准确性,还增强了对复杂图像区域的适应能力。 ... [详细]
  • 使用多项式拟合分析淘宝双11销售趋势
    根据天猫官方数据,2019年双11成交额达到2684亿元,再次刷新历史记录。本文通过多项式拟合方法,分析并预测未来几年的销售趋势。 ... [详细]
  • Python 序列图分割与可视化编程入门教程
    本文介绍了如何使用 Python 进行序列图的快速分割与可视化。通过一个实际案例,详细展示了从需求分析到代码实现的全过程。具体包括如何读取序列图数据、应用分割算法以及利用可视化库生成直观的图表,帮助非编程背景的用户也能轻松上手。 ... [详细]
  • 在Java项目中,当两个文件进行互相调用时出现了函数错误。具体问题出现在 `MainFrame.java` 文件中,该文件位于 `cn.javass.bookmgr` 包下,并且导入了 `java.awt.BorderLayout` 和 `java.awt.Event` 等相关类。为了确保项目的正常运行,请求提供专业的解决方案,以解决函数调用中的错误。建议从类路径、依赖关系和方法签名等方面入手,进行全面排查和调试。 ... [详细]
  • 【图像分类实战】利用DenseNet在PyTorch中实现秃头识别
    本文详细介绍了如何使用DenseNet模型在PyTorch框架下实现秃头识别。首先,文章概述了项目所需的库和全局参数设置。接着,对图像进行预处理并读取数据集。随后,构建并配置DenseNet模型,设置训练和验证流程。最后,通过测试阶段验证模型性能,并提供了完整的代码实现。本文不仅涵盖了技术细节,还提供了实用的操作指南,适合初学者和有经验的研究人员参考。 ... [详细]
  • 利用 Python 中的 Altair 库实现数据抖动的水平剥离分析 ... [详细]
  • 本题库精选了Java核心知识点的练习题,旨在帮助学习者巩固和检验对Java理论基础的掌握。其中,选择题部分涵盖了访问控制权限等关键概念,例如,Java语言中仅允许子类或同一包内的类访问的访问权限为protected。此外,题库还包括其他重要知识点,如异常处理、多线程、集合框架等,全面覆盖Java编程的核心内容。 ... [详细]
  • 如何使用 `org.opencb.opencga.core.results.VariantQueryResult.getSource()` 方法及其代码示例详解 ... [详细]
  • PTArchiver工作原理详解与应用分析
    PTArchiver工作原理及其应用分析本文详细解析了PTArchiver的工作机制,探讨了其在数据归档和管理中的应用。PTArchiver通过高效的压缩算法和灵活的存储策略,实现了对大规模数据的高效管理和长期保存。文章还介绍了其在企业级数据备份、历史数据迁移等场景中的实际应用案例,为用户提供了实用的操作建议和技术支持。 ... [详细]
  • 优化后的标题:深入探讨网关安全:将微服务升级为OAuth2资源服务器的最佳实践
    本文深入探讨了如何将微服务升级为OAuth2资源服务器,以订单服务为例,详细介绍了在POM文件中添加 `spring-cloud-starter-oauth2` 依赖,并配置Spring Security以实现对微服务的保护。通过这一过程,不仅增强了系统的安全性,还提高了资源访问的可控性和灵活性。文章还讨论了最佳实践,包括如何配置OAuth2客户端和资源服务器,以及如何处理常见的安全问题和错误。 ... [详细]
  • 本文详细介绍了使用 Python 进行 MySQL 和 Redis 数据库操作的实战技巧。首先,针对 MySQL 数据库,通过 `pymysql` 模块展示了如何连接和操作数据库,包括建立连接、执行查询和更新等常见操作。接着,文章深入探讨了 Redis 的基本命令和高级功能,如键值存储、列表操作和事务处理。此外,还提供了多个实际案例,帮助读者更好地理解和应用这些技术。 ... [详细]
  • 本项目在Java Maven框架下,利用POI库实现了Excel数据的高效导入与导出功能。通过优化数据处理流程,提升了数据操作的性能和稳定性。项目已发布至GitHub,当前最新版本为0.0.5。该项目不仅适用于小型应用,也可扩展用于大型企业级系统,提供了灵活的数据管理解决方案。GitHub地址:https://github.com/83945105/holygrail,Maven坐标:`com.github.83945105:holygrail:0.0.5`。 ... [详细]
author-avatar
姿萱俊达俊宏
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有