# Reference:
# https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/autoencoder.py
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data set from /tmp/data/.  Labels are requested one-hot
# encoded, although this script discards them wherever a batch is drawn.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# An autoencoder passes the input through several network layers to obtain
# an intermediate state, then passes that intermediate state through several
# more layers to obtain a reconstructed state.  The error between the
# reconstructed state and the original state is what drives the updates of
# the layer weights.

# Learning rate handed to the RMSProp optimizer.
learning_rate = 0.01
# Number of mini-batch updates performed by the training loop.
num_steps = 20000
# Number of training images drawn for each update.
batch_size = 256
# The loss is printed at step 1 and then every `display_step` steps.
display_step = 1000
# How many test images the post-training preview is meant to show.
examples_to_show = 10
# Width of the first encoder layer; the first decoder layer has the same width.
num_hidden_1 = 256
# Width of the second encoder layer, i.e. the size of the bottleneck code.
num_hidden_2 = 128
# Length of one flattened input vector (the preview code reshapes such a
# vector to 28 x 28).
num_input = 784

# Graph input: a batch of flattened images; the batch dimension is left open.
X = tf.placeholder('float', [None, num_input])

# Data flow through the network:
# input -> encoder_layer1 -> encoder_layer2 -> decoder_layer1 -> decoder_layer2 -> output

# One weight matrix per layer, shaped [inputs, outputs] and initialised from
# tf.random_normal.
weights = {
    'encoder_h1': tf.Variable(tf.random_normal([num_input, num_hidden_1])),
    'encoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_hidden_2])),
    'decoder_h1': tf.Variable(tf.random_normal([num_hidden_2, num_hidden_1])),
    'decoder_h2': tf.Variable(tf.random_normal([num_hidden_1, num_input])),
}

# One bias vector per layer, sized to that layer's output width.
biases = {
    'encoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'encoder_b2': tf.Variable(tf.random_normal([num_hidden_2])),
    'decoder_b1': tf.Variable(tf.random_normal([num_hidden_1])),
    'decoder_b2': tf.Variable(tf.random_normal([num_input])),
}


def encoder(x):
    """Build the encoder half of the network.

    Two fully connected layers with sigmoid activations are applied in
    sequence, using the ``encoder_*`` entries of ``weights`` and ``biases``.

    Args:
        x: Tensor whose last dimension is ``num_input``.

    Returns:
        The output tensor of the second layer, whose last dimension is
        ``num_hidden_2``.
    """
    layer1 = tf.nn.sigmoid(
        tf.add(tf.matmul(x, weights["encoder_h1"]), biases["encoder_b1"])
    )
    layer2 = tf.nn.sigmoid(
        tf.add(tf.matmul(layer1, weights["encoder_h2"]), biases["encoder_b2"])
    )
    return layer2
def decoder(x):
    """Build the decoder half of the network.

    Two fully connected layers with sigmoid activations are applied in
    sequence, using the ``decoder_*`` entries of ``weights`` and ``biases``.

    Args:
        x: Tensor whose last dimension is ``num_hidden_2`` (the encoder
            output).

    Returns:
        The output tensor of the second layer, whose last dimension is
        ``num_input``.
    """
    layer1 = tf.nn.sigmoid(
        tf.add(tf.matmul(x, weights["decoder_h1"]), biases["decoder_b1"])
    )
    layer2 = tf.nn.sigmoid(
        tf.add(tf.matmul(layer1, weights["decoder_h2"]), biases["decoder_b2"])
    )
    return layer2


# Bottleneck code computed from the input placeholder.
encoder_op = encoder(X)
# Full autoencoder: reconstruct the input from its bottleneck code.
decoder_op = decoder(encoder_op)

# The network is trained to reproduce its own input, so the prediction is
# the reconstruction and the target is the input placeholder itself.
y_pred = decoder_op
y_true = X

# Mean of the squared element-wise difference between input and
# reconstruction.
loss = tf.reduce_mean(tf.pow(y_true - y_pred, 2))
# Training op: one RMSProp update that minimises the reconstruction loss.
optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(loss)

# Op that initialises every variable created above.
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)

    # Training: one mini-batch per step.  The labels returned by next_batch
    # are discarded because the input itself is the target.
    for step in range(1, num_steps + 1):
        batch_x, _ = mnist.train.next_batch(batch_size)
        _, batch_loss = sess.run([optimizer, loss], feed_dict={X: batch_x})
        if step % display_step == 0 or step == 1:
            print("loss", batch_loss)

    # Preview: draw one test image at a time, run it through the trained
    # network and show the original followed by its reconstruction.
    n = 1  # side length passed to figsize for every figure
    for _ in range(examples_to_show):
        batch_x, _ = mnist.test.next_batch(1)
        g = sess.run(decoder_op, feed_dict={X: batch_x})
        # The shape is passed positionally so the call does not depend on
        # the name of np.reshape's shape keyword.
        g = np.reshape(g, [28, 28])

        print("Original Images")
        plt.figure(figsize=(n, n))
        plt.imshow(np.reshape(batch_x, [28, 28]), origin="upper", cmap="gray")
        plt.show()

        print("Reconstructed Images")
        plt.figure(figsize=(n, n))
        plt.imshow(g, origin="upper", cmap="gray")
        plt.show()