热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

开发笔记:InfoGAN的简易实现

本文由编程笔记#小编为大家整理,主要介绍了InfoGAN的简易实现相关的知识,希望对你有一定的参考价值。这里求最大化互信息没有共享D网络,直接使用了一个简单的mlp神经网络Q
本文由编程笔记#小编为大家整理,主要介绍了InfoGAN的简易实现相关的知识,希望对你有一定的参考价值。


这里在最大化互信息时没有与判别器 D 共享网络,而是直接使用了一个独立的简单 MLP 神经网络 Q 来估计 Q(c|X)。


import os, sys
sys.path.append(
"/home/hxj/anaconda3/lib/python3.6/site-packages")
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot labels through the TensorFlow tutorial helper.
mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)
mb_size = 32  # minibatch size
Z_dim = 16  # dimensionality of the unstructured noise z
X_dim = mnist.train.images.shape[1]  # 784 flattened pixels
y_dim = mnist.train.labels.shape[1]  # 10 classes
h_dim = 128  # hidden-layer width used by G, D and Q
cnt = 0  # index of the next sample grid saved under out/
lr = 1e-3  # Adam learning rate
def xavier_init(size):
    """Return a trainable weight tensor drawn from a scaled normal distribution.

    The standard deviation is ``1 / sqrt(size[0] / 2)``, i.e. ``sqrt(2 / fan_in)``.
    NOTE: despite the name, that is He/Kaiming-style scaling rather than
    Glorot/Xavier (which would use fan_in + fan_out).

    Args:
        size: Shape of the weight; ``size[0]`` is treated as the fan-in.

    Returns:
        A ``Variable`` of shape ``size`` with ``requires_grad=True``.
    """
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)


# ==================== GENERATOR ========================
# Generator weights: (Z_dim + 10) -> h_dim -> X_dim.
# The extra 10 input units receive the one-hot latent code c.
Wzh = xavier_init(size=[Z_dim + 10, h_dim])  # (26, 128)
bzh = Variable(torch.zeros(h_dim), requires_grad=True)
Whx = xavier_init(size=[h_dim, X_dim])  # (128, 784)
bhx = Variable(torch.zeros(X_dim), requires_grad=True)
def G(z, c):
    """Generate flattened images from noise ``z`` and latent code ``c``.

    Args:
        z: Noise batch, expected shape (batch, Z_dim).
        c: One-hot code batch, expected shape (batch, 10).

    Returns:
        Pixel intensities in (0, 1), shape (batch, X_dim).
    """
    inputs = torch.cat([z, c], 1)
    # Bias vectors broadcast across the batch dimension.
    h = nn.relu(inputs @ Wzh + bzh)
    X = torch.sigmoid(h @ Whx + bhx)
    return X


# ==================== DISCRIMINATOR ========================
# Discriminator weights: X_dim -> h_dim -> 1.
Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)
Why = xavier_init(size=[h_dim, 1])
bhy = Variable(torch.zeros(1), requires_grad=True)
def D(X):
    """Return the discriminator's probability that each row of ``X`` is real.

    Args:
        X: Image batch, expected shape (batch, X_dim).

    Returns:
        Values in (0, 1), shape (batch, 1).
    """
    # Bias vectors broadcast across the batch dimension.
    h = nn.relu(X @ Wxh + bxh)
    y = torch.sigmoid(h @ Why + bhy)
    return y


# ====================== Q(c|X) ==========================
# Weights of the auxiliary network Q(c|X): X_dim -> h_dim -> 10.
# Q is a standalone MLP; it does not share layers with D.
Wqxh = xavier_init(size=[X_dim, h_dim])
bqxh = Variable(torch.zeros(h_dim), requires_grad=True)
Whc = xavier_init(size=[h_dim, 10])
bhc = Variable(torch.zeros(10), requires_grad=True)
def Q(X):
    """Return Q's categorical distribution over the 10 code values for each row of ``X``.

    Args:
        X: Image batch, expected shape (batch, X_dim).

    Returns:
        Row-wise softmax probabilities, shape (batch, 10).
    """
    # Bias vectors broadcast across the batch dimension.
    h = nn.relu(X @ Wqxh + bqxh)
    # Explicit dim=1: normalise over the 10 classes of each sample.
    c = nn.softmax(h @ Whc + bhc, dim=1)
    return c
# Parameter groups, one per network, so each optimizer updates only its own tensors.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
Q_params = [Wqxh, bqxh, Whc, bhc]
# Every trainable tensor; reset_grad() clears the gradients of all of them.
params = G_params + D_params + Q_params

# ===================== TRAINING ========================
def reset_grad():
    """Zero the accumulated gradient of every tensor in ``params``.

    Tensors that have not received a gradient yet (``grad is None``) are skipped.
    """
    for p in params:
        if p.grad is not None:
            # Detach from any graph and clear in place instead of allocating a new tensor.
            p.grad.detach_()
            p.grad.zero_()
# One Adam optimizer per objective, all using the shared learning rate `lr`.
# Q_solver also updates G_params, because the mutual-information loss is
# computed on G's samples and should shape the generator too.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)
Q_solver = optim.Adam(G_params + Q_params, lr=lr)
def sample_c(size):
    """Sample one-hot latent codes from a uniform categorical over 10 classes.

    Args:
        size: Number of codes (rows) to draw.

    Returns:
        A float32 ``Variable`` of shape (size, 10) whose rows each contain a single 1.
    """
    c = np.random.multinomial(1, 10 * [0.1], size=size)
    # dtype must be a string (or np.float32); the bare name float32 is undefined.
    return Variable(torch.from_numpy(c.astype("float32")))
for it in range(100000):
    # Sample a minibatch of real images (labels unused) plus fresh z and c.
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))  # numpy array -> tensor, (mb_size, 784)
    z = Variable(torch.randn(mb_size, Z_dim))  # Gaussian noise, (mb_size, Z_dim)
    c = sample_c(mb_size)  # random one-hot codes, (mb_size, 10)

    # Discriminator forward-loss-backward-update
    G_sample = G(z, c)
    D_real = D(X)
    D_fake = D(G_sample)
    D_loss = -torch.mean(torch.log(D_real + 1e-8) + torch.log(1 - D_fake + 1e-8))
    D_loss.backward()
    D_solver.step()
    # D_loss also back-propagated into G's weights via G_sample; clear everything.
    reset_grad()

    # Generator forward-loss-backward-update
    G_sample = G(z, c)
    D_fake = D(G_sample)
    G_loss = -torch.mean(torch.log(D_fake + 1e-8))
    G_loss.backward()
    G_solver.step()
    reset_grad()

    # Q forward-loss-backward-update: samples generated from code c are passed
    # through Q; minimising the cross-entropy between c and Q(c|x) maximises
    # the mutual information between the code and the generated sample.
    G_sample = G(z, c)
    Q_c_given_x = Q(G_sample)
    crossent_loss = torch.mean(-torch.sum(c * torch.log(Q_c_given_x + 1e-8), dim=1))
    mi_loss = crossent_loss
    mi_loss.backward()
    Q_solver.step()
    reset_grad()

    # Every 1000 iterations: log the losses and save a 4x4 grid of samples
    # that all share one randomly chosen code value.
    if it % 1000 == 0:
        idx = np.random.randint(0, 10)
        c = np.zeros([mb_size, 10])
        c[:, idx] = 1
        c = Variable(torch.from_numpy(c.astype("float32")))
        samples = G(z, c).data.numpy()[:16]

        print("Iter-{}; D_loss: {}; G_loss: {}; Idx: {}"
              .format(it, D_loss.data.numpy(), G_loss.data.numpy(), idx))

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)
        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis("off")
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect("equal")
            plt.imshow(sample.reshape(28, 28), cmap="Greys_r")

        os.makedirs("out/", exist_ok=True)
        plt.savefig("out/{}.png".format(str(cnt).zfill(3)), bbox_inches="tight")
        cnt += 1
        plt.close(fig)

 


推荐阅读
author-avatar
XsiaoHan
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有