
Getting Started with TensorFlow: Autoencoders


Reference:
https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/autoencoder.py

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

"""
An autoencoder passes the input through several layers to obtain an
intermediate (compressed) state, then passes that state through further
layers to reconstruct the input. The error between the reconstruction and
the original input is what drives the weight updates.
"""

# Training parameters
learning_rate = 0.01
num_steps = 20000
batch_size = 256
display_step = 1000
examples_to_show = 10

# Network parameters
num_hidden_1 = 256
num_hidden_2 = 128
num_input = 784  # MNIST images are 28*28 = 784 pixels

X = tf.placeholder('float', [None, num_input])

# input -> encoder_layer1 -> encoder_layer2 -> decoder_layer1 -> decoder_layer2 -> output
weights = {
    'encoder_h1': tf.Variable(tf.random_normal([num_input, num_hidden_1])),
    'encoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_hidden_2])),
    'decoder_h1': tf.Variable(tf.random_normal([num_hidden_2, num_hidden_1])),
    'decoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_input])),
}
biases = {
    'encoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'encoder_b2': tf.Variable(tf.random_normal([num_hidden_2])),
    'decoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'decoder_b2': tf.Variable(tf.random_normal([num_input])),
}

def encoder(x):
    layer1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights["encoder_h1"]), biases["encoder_b1"]))
    layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1, weights["encoder_h2"]), biases["encoder_b2"]))
    return layer2

def decoder(x):
    layer1 = tf.nn.sigmoid(tf.add(tf.matmul(x, weights["decoder_h1"]), biases["decoder_b1"]))
    layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1, weights["decoder_h2"]), biases["decoder_b2"]))
    return layer2

encoder_op = encoder(X)
decoder_op = decoder(encoder_op)

y_pred = decoder_op
y_true = X

# Reconstruction error (mean squared error)
loss = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
# Optimizer
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for i in range(1, num_steps + 1):
        batch_x, _ = mnist.train.next_batch(batch_size)
        _, l = sess.run([optimizer, loss], feed_dict={X: batch_x})
        if i % display_step == 0 or i == 1:
            print("loss", l)

    # Show a few original and reconstructed test digits
    n = 1
    for i in range(examples_to_show):
        batch_x, _ = mnist.test.next_batch(1)
        g = sess.run(decoder_op, feed_dict={X: batch_x})
        g = np.reshape(g, newshape=[28, 28])

        print("Original Images")
        plt.figure(figsize=(n, n))
        plt.imshow(np.reshape(batch_x, newshape=[28, 28]), origin="upper", cmap="gray")
        plt.show()

        print("Reconstructed Images")
        plt.figure(figsize=(n, n))
        plt.imshow(g, origin="upper", cmap="gray")
        plt.show()
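The script above relies on the TF 1.x graph API and the tensorflow.examples.tutorials.mnist helper, both of which were removed in TensorFlow 2. As a minimal sketch of the same architecture, assuming TensorFlow 2.x with tf.keras (the layer sizes, sigmoid activations, RMSProp optimizer and MSE loss mirror the script above; names like encoder/decoder/autoencoder are just for illustration):

# Minimal tf.keras sketch of the 784 -> 256 -> 128 -> 256 -> 784 autoencoder (TF 2.x)
import numpy as np
import tensorflow as tf

# Load MNIST and flatten each 28x28 image into a 784-dim vector in [0, 1]
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

encoder = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation="sigmoid", input_shape=(784,)),
    tf.keras.layers.Dense(128, activation="sigmoid"),
])
decoder = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation="sigmoid", input_shape=(128,)),
    tf.keras.layers.Dense(784, activation="sigmoid"),
])
autoencoder = tf.keras.Sequential([encoder, decoder])

# Train the network to reproduce its own input (x is both input and target)
autoencoder.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.01),
                    loss="mse")
autoencoder.fit(x_train, x_train, batch_size=256, epochs=10)

# Reconstruct a few test digits; each row of `recon` is a 784-dim image
recon = autoencoder.predict(x_test[:10])

Here autoencoder.predict plays the role of running decoder_op on test digits; each row of recon can be reshaped to 28x28 and plotted with matplotlib exactly as in the original script.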

