2018年新年,腾讯整出来个ai春联很吸引眼球,刚好有个需求让我看下能不能训出来个model来写出诗经一样的文风,求助了下小伙伴,直接丢过来2个github,原话是:
查了一下诗经一共38000个字,应该是可以训练出一个语言模型的。只是怕机器写出来的诗一般都没灵魂。https://github.com/hjptriplebee/Chinese_poem_generator; https://github.com/xue2han/AncientChinesePoemRNN.
我测试了,第一个没跑通,没有时间去check,所以直接第二个效果赞赞的。只跑了4000个epoch就给出了我惊喜,结果如下:
冲着这个效果这么厉害,一定要趁热扒一下背后的NLP技术。
首先参考AI对联背后的技术。
智能春联的核心技术从大的范畴上属于NLP,自然语言处理技术。创作春联又可以归类为其中的语言生成方向的技术,国内的语言生成研究可以追溯到20世纪90年代,至今已经探索了各种方法,主要有基于模版、随机生成并测试、基于遗传算法、基于实例推理、基于统计机器翻译等各种类型的方法。
本文举两个典型的技术途径作为案例:
1.第一种是没文化生成:即不去了解任何信息的内容,程序根本不知道文字内容是啥,只是从信息熵的角度进行随机生成与测试,在计算机眼里这只是“熵为**的一个随机数据序列”。专业的说法叫做不加领域知识的LSTM生成。LSTM是一种RNN网络(循环神经网络),适用于时序性较强的语言类样本,这里主要使用信息熵作为收敛的损失函数。这种方法生成的语言往往经不住推敲,缺乏意境、主题性等,主要是因为损失函数的定义缺少“文化”,过于强调“信息熵”。
RNN示意图
2.第二种是有文化生成:即在算法中增加了格律诗的领域知识,例如格律押韵、主题意境等。专业的说法叫基于主题模型的统计机器翻译生成。统计机器翻译主要是一类映射源语言与目标语言的模型。主要使用生成对联与参考标准集之间的相似度作为收敛的损失函数。这种方法的缺陷在于春联的质量与参考标准集强相关,容易陷入单一风格化,难以创造真正“属于机器自己的风格”。
神经机器翻译架构
一、不加领域知识的LSTM生成
1)从网上搜集了各式春联共6900对
2)将汉字编码为数字,或者叫做Encoder,并将数据分割为训练集和测试集
3)定义LSTM模型
4)用加权交叉熵损失函数训练模型,LOSS控制在1.5左右,训练结束
5)自动生成新的春联,需要再将数字转为汉字
感受一下代码
将汉字编码为数字,或者叫做Encoder,并将数据分割为训练集和测试集
couplet_file &#61;"couplet.txt"#对联couplets &#61; []with open(couplet_file,&#39;r&#39;) as f: for line in f: try: content &#61; line.replace(&#39; &#39;,&#39;&#39;) if &#39;_&#39; in content or &#39;(&#39; in content or &#39;&#xff08;&#39; in content or &#39;《&#39; in content or &#39;[&#39; in content: continue if len(content) <5*3 or len(content) > 79*3: continue content &#61; &#39;[&#39; &#43; content &#43; &#39;]&#39; # print chardet.detect(content) content &#61; content.decode(&#39;utf-8&#39;) couplets.append(content) except Exception as e: pass# 按字数排序couplets &#61; sorted(couplets,key&#61;lambda line: len(line))print(&#39;对联总数: %d&#39;%(len(couplets)))# 统计每个字出现次数all_words &#61; []for couplet in couplets: all_words &#43;&#61; [word for word in couplet]counter &#61; collections.Counter(all_words)count_pairs &#61; sorted(counter.items(), key&#61;lambda x: -x[1])words, _ &#61; zip(*count_pairs)words &#61; words[:len(words)] &#43; (&#39; &#39;,)# 每个字映射为一个数字IDword_num_map &#61; dict(zip(words, range(len(words))))to_num &#61; lambda word: word_num_map.get(word, len(words))couplets_vector &#61; [ list(map(to_num, couplet)) for couplet in couplets]# 每次取64首对联进行训练, 此参数可以调整batch_size &#61; 64n_chunk &#61; len(couplets_vector) // batch_sizex_batches &#61; []y_batches &#61; []for i in range(n_chunk): start_index &#61; i * batch_size#起始位置 end_index &#61; start_index &#43; batch_size#结束位置 batches &#61; couplets_vector[start_index:end_index] length &#61; max(map(len,batches))#每个batches中句子的最大长度 xdata &#61; np.full((batch_size,length), word_num_map[&#39; &#39;], np.int32) for row in range(batch_size): xdata[row,:len(batches[row])] &#61; batches[row] ydata &#61; np.copy(xdata) ydata[:,:-1] &#61; xdata[:,1:] x_batches.append(xdata) y_batches.append(ydata)
定义LSTM模型&#xff08;定义cell为一个128维的ht的cell。并使用MultiRNNCell 定义为两层的LSTM&#xff09;def neural_network(rnn_size&#61;128, num_layers&#61;2): cell &#61; tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple&#61;True) cell &#61; tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple&#61;True) initial_state &#61; cell.zero_state(batch_size, tf.float32) with tf.variable_scope(&#39;rnnlm&#39;): softmax_w &#61; tf.get_variable("softmax_w", [rnn_size, len(words)&#43;1]) softmax_b &#61; tf.get_variable("softmax_b", [len(words)&#43;1]) with tf.device("/cpu:0"): embedding &#61; tf.get_variable("embedding", [len(words)&#43;1, rnn_size]) inputs &#61; tf.nn.embedding_lookup(embedding, input_data) outputs, last_state &#61; tf.nn.dynamic_rnn(cell, inputs, initial_state&#61;initial_state, scope&#61;&#39;rnnlm&#39;) output &#61; tf.reshape(outputs,[-1, rnn_size]) logits &#61; tf.matmul(output, softmax_w) &#43; softmax_b probs &#61; tf.nn.softmax(logits) return logits, last_state, probs, cell, initial_state
用加权交叉熵损失函数训练模型&#xff0c;LOSS控制在1.5左右&#xff0c;训练结束
def train_neural_network(): logits, last_state, _, _, _ &#61; neural_network() targets &#61; tf.reshape(output_targets, [-1]) loss &#61; tf.contrib.legacy_seq2seq.sequence_loss_by_example([logits], [targets], [tf.ones_like(targets, dtype&#61;tf.float32)], len(words)) cost &#61; tf.reduce_mean(loss) learning_rate &#61; tf.Variable(0.0, trainable&#61;False) tvars &#61; tf.trainable_variables() grads, _ &#61; tf.clip_by_global_norm(tf.gradients(cost, tvars), 5) optimizer &#61; tf.train.AdamOptimizer(learning_rate) train_op &#61; optimizer.apply_gradients(zip(grads, tvars)) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) saver &#61; tf.train.Saver(tf.all_variables()) for epoch in range(100): sess.run(tf.assign(learning_rate, 0.01 * (0.97 ** epoch))) n &#61; 0 for batche in range(n_chunk): train_loss, _ , _ &#61; sess.run([cost, last_state, train_op], feed_dict&#61;{input_data: x_batches[n], output_targets: y_batches[n]}) n &#43;&#61; 1 print(epoch, batche, train_loss) if epoch % 7 &#61;&#61; 0: saver.save(sess, &#39;./couplet.module&#39;, global_step&#61;epoch)
自动生成新的春联
saver.restore(sess, &#39;couplet.module-98&#39;)
二、基于主题模型的统计机器翻译生成
1&#xff09;准备统计模型的训练数据
格律诗训练语料来自互联网&#xff0c;其中包括《 唐诗》、《 全唐诗》、《 全台词》等文献&#xff0c;以及从各大诗词论坛&#xff08;例如诗词在线、天涯论坛诗词比兴等&#xff09;抓取并筛选后的格律诗&#xff0c;总计287000多首。
2&#xff09;设定主题模型&#xff0c;这里使用概率潜在语义分析&#xff0c;PLSA
例如&#xff1a;给定主题词“春日”&#xff0c;根据它在潜在主题空间中的分布向量&#xff0c;可以找出 “玉魄”、“红泥”和 “燕”等空间距离比较近的语义相关词。
3&#xff09;基于主题模型的词汇扩展
4&#xff09;定义算法&#xff1a;依照主题词生成首句的算法
5&#xff09;定义基于统计机器翻译的二、三、四句生成模型
我们采用基于短语的统计机器翻译技术 &#xff0c;PBSMT是目前一种主流的机器翻译技术&#xff0c;它的优势在于短语翻译结果的选词准确&#xff0e; 由于诗词的生成讲求对仗&#xff0c;不涉及远距离语序调整问题&#xff0c;因此&#xff0c;诗词的生成非常适合采用基于短语的机器翻译算法来解决。
6&#xff09;基于BLEU的评测方法&#xff0c;结果收敛后保存模型
BLEU 的直观思想是翻译结果越接近参考答案则翻译质量越好&#xff0e; 相应的&#xff0c;我们认为如果根据给定上句生成的下句能够更贴近已有的参考下句则系统的生成质量越好&#xff0c;但由于诗词在内容表现上丰富多样&#xff0c;所以需要搜集拥有多个参考下句的数据样本加入答案集。BLEU通过对生成候选句与源语句的参考句进行&#xff11;元词到N元词的重合度统计&#xff0c;结合下式衡量生成结果的好坏。
7&#xff09;给定主题词&#xff0c;生成新的格律诗
论文全文下载&#xff0c;请在公众号回复&#xff1a;20180216
参考资料
关于RNN和LSTM原理的说明&#xff1a; http://www.jianshu.com/p/9dc9f41f0b29
LSTM深度学习写春联&#xff1a;http://blog.csdn.net/leadai/article/details/79015862
基于主题模型和统计机器翻译方法的中文格律诗自动生成&#xff1a;蒋锐滢&#xff0c;崔 磊&#xff0c;何 晶&#xff0c;周 明&#xff0c;潘志庚
接下来看代码
参照[char-rnn-tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow)&#xff0c;使用RNN的字符模型&#xff0c;学习并生成古诗。
数据来自于http://www16.zzu.edu.cn/qts/ ,总共4万多首唐诗。
- tensorflow 1.0
- python2
先看训练数据&#xff0c;poems.txt.截取片段
煌煌道宫&#xff0c;肃肃太清。礼光尊祖&#xff0c;乐备充庭。罄竭诚至&#xff0c;希夷降灵。云凝翠盖&#xff0c;风焰红旌。众真以从&#xff0c;九奏初迎。永惟休v&#xff0c;是锡和平。
种瓜黄台下&#xff0c;瓜熟子离离。一摘使瓜好&#xff0c;再摘使瓜稀。三摘犹自可&#xff0c;摘绝抱蔓归。
看出来去掉了标题和作者的干扰。这点很重要&#xff0c;我诗经训练出来的结果很奇葩&#xff0c;估计就是我标题没有去。
核心训练代码。train.py
from __future__ import print_function
import numpy as np
import tensorflow as tfimport argparse
import time
import os,sys
from six.moves import cPicklefrom utils import TextLoader
from model import Modeldef main():parser &#61; argparse.ArgumentParser()parser.add_argument(&#39;--save_dir&#39;, type&#61;str, default&#61;&#39;save&#39;,help&#61;&#39;directory to store checkpointed models&#39;)parser.add_argument(&#39;--rnn_size&#39;, type&#61;int, default&#61;128,help&#61;&#39;size of RNN hidden state&#39;)parser.add_argument(&#39;--num_layers&#39;, type&#61;int, default&#61;2,help&#61;&#39;number of layers in the RNN&#39;)parser.add_argument(&#39;--model&#39;, type&#61;str, default&#61;&#39;lstm&#39;,help&#61;&#39;rnn, gru, or lstm&#39;)parser.add_argument(&#39;--batch_size&#39;, type&#61;int, default&#61;64,help&#61;&#39;minibatch size&#39;)parser.add_argument(&#39;--num_epochs&#39;, type&#61;int, default&#61;50,help&#61;&#39;number of epochs&#39;)parser.add_argument(&#39;--save_every&#39;, type&#61;int, default&#61;1000,help&#61;&#39;save frequency&#39;)parser.add_argument(&#39;--grad_clip&#39;, type&#61;float, default&#61;5.,help&#61;&#39;clip gradients at this value&#39;)parser.add_argument(&#39;--learning_rate&#39;, type&#61;float, default&#61;0.002,help&#61;&#39;learning rate&#39;)parser.add_argument(&#39;--decay_rate&#39;, type&#61;float, default&#61;0.97,help&#61;&#39;decay rate for rmsprop&#39;)parser.add_argument(&#39;--init_from&#39;, type&#61;str, default&#61;None,help&#61;"""continue training from saved model at this path. Path must contain files saved by previous training process:&#39;config.pkl&#39; : configuration;&#39;chars_vocab.pkl&#39; : vocabulary definitions;&#39;iterations&#39; : number of trained iterations;&#39;losses-*&#39; : train loss;&#39;checkpoint&#39; : paths to model file(s) (created by tf).Note: this file contains absolute paths, be careful when moving files around;&#39;model.ckpt-*&#39; : file(s) with model definition (created by tf)""")args &#61; parser.parse_args()train(args)def train(args):data_loader &#61; TextLoader(args.batch_size)args.vocab_size &#61; data_loader.vocab_size# check compatibility if training is continued from previously saved modelif args.init_from is not None:# check if all necessary files existassert os.path.isdir(args.init_from)," %s must be a a path" % args.init_fromassert os.path.isfile(os.path.join(args.init_from,"config.pkl")),"config.pkl file does not exist in path %s"%args.init_fromassert os.path.isfile(os.path.join(args.init_from,"chars_vocab.pkl")),"chars_vocab.pkl.pkl file does not exist in path %s" % args.init_fromckpt &#61; tf.train.get_checkpoint_state(args.init_from)assert ckpt,"No checkpoint found"assert ckpt.model_checkpoint_path,"No model path found in checkpoint"assert os.path.isfile(os.path.join(args.init_from,"iterations")),"iterations file does not exist in path %s " % args.init_from# open old config and check if models are compatiblewith open(os.path.join(args.init_from, &#39;config.pkl&#39;),&#39;rb&#39;) as f:saved_model_args &#61; cPickle.load(f)need_be_same&#61;["model","rnn_size","num_layers"]for checkme in need_be_same:assert vars(saved_model_args)[checkme]&#61;&#61;vars(args)[checkme],"Command line argument and saved model disagree on &#39;%s&#39; "%checkme# open saved vocab/dict and check if vocabs/dicts are compatiblewith open(os.path.join(args.init_from, &#39;chars_vocab.pkl&#39;),&#39;rb&#39;) as f:saved_chars, saved_vocab &#61; cPickle.load(f)assert saved_chars&#61;&#61;data_loader.chars, "Data and loaded model disagree on character set!"assert saved_vocab&#61;&#61;data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"with open(os.path.join(args.save_dir, &#39;config.pkl&#39;), &#39;wb&#39;) as f:cPickle.dump(args, f)with open(os.path.join(args.save_dir, &#39;chars_vocab.pkl&#39;), &#39;wb&#39;) as f:cPickle.dump((data_loader.chars, data_loader.vocab), f)model &#61; Model(args)with tf.Session() as sess:tf.global_variables_initializer().run()saver &#61; tf.train.Saver(tf.global_variables())iterations &#61; 0# restore model and number of iterationsif args.init_from is not None:saver.restore(sess, ckpt.model_checkpoint_path)with open(os.path.join(args.save_dir, &#39;iterations&#39;),&#39;rb&#39;) as f:iterations &#61; cPickle.load(f)losses &#61; []for e in range(args.num_epochs):sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))data_loader.reset_batch_pointer()for b in range(data_loader.num_batches):iterations &#43;&#61; 1start &#61; time.time() x, y &#61; data_loader.next_batch() feed &#61; {model.input_data: x, model.targets: y}train_loss, _ , _ &#61; sess.run([model.cost, model.final_state, model.train_op], feed)end &#61; time.time()sys.stdout.write(&#39;\r&#39;)info &#61; "{}/{} (epoch {}), train_loss &#61; {:.3f}, time/batch &#61; {:.3f}" \.format(e * data_loader.num_batches &#43; b,args.num_epochs * data_loader.num_batches,e, train_loss, end - start)sys.stdout.write(info)sys.stdout.flush()losses.append(train_loss)if (e * data_loader.num_batches &#43; b) % args.save_every &#61;&#61; 0\or (e&#61;&#61;args.num_epochs-1 and b &#61;&#61; data_loader.num_batches-1): # save for the last resultcheckpoint_path &#61; os.path.join(args.save_dir, &#39;model.ckpt&#39;)saver.save(sess, checkpoint_path, global_step &#61; iterations)with open(os.path.join(args.save_dir,"iterations"),&#39;wb&#39;) as f:cPickle.dump(iterations,f)with open(os.path.join(args.save_dir,"losses-"&#43;str(iterations)),&#39;wb&#39;) as f:cPickle.dump(losses,f)losses &#61; []sys.stdout.write(&#39;\n&#39;)print("model saved to {}".format(checkpoint_path))sys.stdout.write(&#39;\n&#39;)if __name__ &#61;&#61; &#39;__main__&#39;:main()
再看下model.py
#-*- coding:utf-8 -*-import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.contrib import legacy_seq2seq
import numpy as npclass Model():def __init__(self, args,infer&#61;False):self.args &#61; argsif infer:args.batch_size &#61; 1if args.model &#61;&#61; &#39;rnn&#39;:cell_fn &#61; rnn.BasicRNNCellelif args.model &#61;&#61; &#39;gru&#39;:cell_fn &#61; rnn.GRUCellelif args.model &#61;&#61; &#39;lstm&#39;:cell_fn &#61; rnn.BasicLSTMCellelse:raise Exception("model type not supported: {}".format(args.model))cell &#61; cell_fn(args.rnn_size,state_is_tuple&#61;False)self.cell &#61; cell &#61; rnn.MultiRNNCell([cell] * args.num_layers,state_is_tuple&#61;False)self.input_data &#61; tf.placeholder(tf.int32, [args.batch_size, None])# the length of input sequence is variable.self.targets &#61; tf.placeholder(tf.int32, [args.batch_size, None])self.initial_state &#61; cell.zero_state(args.batch_size, tf.float32)with tf.variable_scope(&#39;rnnlm&#39;):softmax_w &#61; tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])softmax_b &#61; tf.get_variable("softmax_b", [args.vocab_size])with tf.device("/cpu:0"):embedding &#61; tf.get_variable("embedding", [args.vocab_size, args.rnn_size])inputs &#61; tf.nn.embedding_lookup(embedding, self.input_data)outputs, last_state &#61; tf.nn.dynamic_rnn(cell,inputs,initial_state&#61;self.initial_state,scope&#61;&#39;rnnlm&#39;)output &#61; tf.reshape(outputs,[-1, args.rnn_size])self.logits &#61; tf.matmul(output, softmax_w) &#43; softmax_bself.probs &#61; tf.nn.softmax(self.logits)targets &#61; tf.reshape(self.targets, [-1])loss &#61; legacy_seq2seq.sequence_loss_by_example([self.logits],[targets],[tf.ones_like(targets,dtype&#61;tf.float32)],args.vocab_size)self.cost &#61; tf.reduce_mean(loss)self.final_state &#61; last_stateself.lr &#61; tf.Variable(0.0, trainable&#61;False)tvars &#61; tf.trainable_variables()grads, _ &#61; tf.clip_by_global_norm(tf.gradients(self.cost, tvars),args.grad_clip)optimizer &#61; tf.train.AdamOptimizer(self.lr)self.train_op &#61; optimizer.apply_gradients(zip(grads, tvars))def sample(self, sess, chars, vocab, prime&#61;u&#39;&#39;, sampling_type&#61;1):def pick_char(weights):if sampling_type &#61;&#61; 0:sample &#61; np.argmax(weights)else:t &#61; np.cumsum(weights)s &#61; np.sum(weights)sample &#61; int(np.searchsorted(t, np.random.rand(1)*s))return chars[sample]for char in prime:if char not in vocab:return u"{} is not in charset!".format(char)if not prime:state &#61; self.cell.zero_state(1, tf.float32).eval()prime &#61; u&#39;^&#39;result &#61; u&#39;&#39;x &#61; np.array([list(map(vocab.get,prime))])[probs,state] &#61; sess.run([self.probs,self.final_state],{self.input_data: x,self.initial_state: state})char &#61; pick_char(probs[-1])while char !&#61; u&#39;$&#39;:result &#43;&#61; charx &#61; np.zeros((1,1))x[0,0] &#61; vocab[char][probs,state] &#61; sess.run([self.probs,self.final_state],{self.input_data: x,self.initial_state: state})char &#61; pick_char(probs[-1])return resultelse:result &#61; u&#39;^&#39;for prime_char in prime:result &#43;&#61; prime_charx &#61; np.array([list(map(vocab.get,result))])state &#61; self.cell.zero_state(1, tf.float32).eval()[probs,state] &#61; sess.run([self.probs,self.final_state],{self.input_data: x,self.initial_state: state})char &#61; pick_char(probs[-1])while char !&#61; u&#39;&#xff0c;&#39; and char !&#61; u&#39;。&#39;:result &#43;&#61; charx &#61; np.zeros((1,1))x[0,0] &#61; vocab[char][probs,state] &#61; sess.run([self.probs,self.final_state],{self.input_data: x,self.initial_state: state})char &#61; pick_char(probs[-1])result &#43;&#61; charreturn result[1:]
数据预处理utils.py
#-*- coding:utf-8 -*-import codecs
import os
import collections
from six.moves import cPickle,reduce,map
import numpy as npBEGIN_CHAR &#61; &#39;^&#39;
END_CHAR &#61; &#39;$&#39;
UNKNOWN_CHAR &#61; &#39;*&#39;
MAX_LENGTH &#61; 100class TextLoader():def __init__(self, batch_size, max_vocabsize&#61;3000, encoding&#61;&#39;utf-8&#39;):self.batch_size &#61; batch_sizeself.max_vocabsize &#61; max_vocabsizeself.encoding &#61; encodingdata_dir &#61; &#39;./data&#39;input_file &#61; os.path.join(data_dir, "shijing.txt")vocab_file &#61; os.path.join(data_dir, "vocab.pkl")tensor_file &#61; os.path.join(data_dir, "data.npy")if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):print("reading text file")self.preprocess(input_file, vocab_file, tensor_file)else:print("loading preprocessed files")self.load_preprocessed(vocab_file, tensor_file)self.create_batches()self.reset_batch_pointer()def preprocess(self, input_file, vocab_file, tensor_file):def handle_poem(line):line &#61; line.replace(&#39; &#39;,&#39;&#39;)if len(line) >&#61; MAX_LENGTH:index_end &#61; line.rfind(u&#39;。&#39;,0,MAX_LENGTH)index_end &#61; index_end if index_end > 0 else MAX_LENGTHline &#61; line[:index_end&#43;1]return BEGIN_CHAR&#43;line&#43;END_CHARwith codecs.open(input_file, "r", encoding&#61;self.encoding) as f:lines &#61; list(map(handle_poem,f.read().strip().split(&#39;\n&#39;)))counter &#61; collections.Counter(reduce(lambda data,line: line&#43;data,lines,&#39;&#39;))count_pairs &#61; sorted(counter.items(), key&#61;lambda x: -x[1])chars, _ &#61; zip(*count_pairs)self.vocab_size &#61; min(len(chars),self.max_vocabsize - 1) &#43; 1self.chars &#61; chars[:self.vocab_size-1] &#43; (UNKNOWN_CHAR,)self.vocab &#61; dict(zip(self.chars, range(len(self.chars))))unknown_char_int &#61; self.vocab.get(UNKNOWN_CHAR)with open(vocab_file, &#39;wb&#39;) as f:cPickle.dump(self.chars, f)get_int &#61; lambda char: self.vocab.get(char,unknown_char_int)lines &#61; sorted(lines,key&#61;lambda line: len(line))self.tensor &#61; [ list(map(get_int,line)) for line in lines ]with open(tensor_file,&#39;wb&#39;) as f:cPickle.dump(self.tensor,f)def load_preprocessed(self, vocab_file, tensor_file):with open(vocab_file, &#39;rb&#39;) as f:self.chars &#61; cPickle.load(f)with open(tensor_file,&#39;rb&#39;) as f:self.tensor &#61; cPickle.load(f)self.vocab_size &#61; len(self.chars)self.vocab &#61; dict(zip(self.chars, range(len(self.chars))))def create_batches(self):self.num_batches &#61; int(len(self.tensor) / self.batch_size)self.tensor &#61; self.tensor[:self.num_batches * self.batch_size]unknown_char_int &#61; self.vocab.get(UNKNOWN_CHAR)self.x_batches &#61; []self.y_batches &#61; []for i in range(self.num_batches):from_index &#61; i * self.batch_sizeto_index &#61; from_index &#43; self.batch_sizebatches &#61; self.tensor[from_index:to_index]seq_length &#61; max(map(len,batches))xdata &#61; np.full((self.batch_size,seq_length),unknown_char_int,np.int32)for row in range(self.batch_size):xdata[row,:len(batches[row])] &#61; batches[row]ydata &#61; np.copy(xdata)ydata[:,:-1] &#61; xdata[:,1:]self.x_batches.append(xdata)self.y_batches.append(ydata)def next_batch(self):x, y &#61; self.x_batches[self.pointer], self.y_batches[self.pointer]self.pointer &#43;&#61; 1return x, ydef reset_batch_pointer(self):self.pointer &#61; 0
测试案例sample.py
#-*- coding:utf-8 -*-from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os
from six.moves import cPicklefrom utils import TextLoader
from model import Modelfrom six import text_typedef main():parser &#61; argparse.ArgumentParser()parser.add_argument(&#39;--save_dir&#39;, type&#61;str, default&#61;&#39;save&#39;,help&#61;&#39;model directory to store checkpointed models&#39;)parser.add_argument(&#39;--prime&#39;, type&#61;str, default&#61;&#39;&#39;,help&#61;u&#39;输入指定文字生成藏头诗&#39;)parser.add_argument(&#39;--sample&#39;, type&#61;int, default&#61;1,help&#61;&#39;0 to use max at each timestep, 1 to sample at each timestep&#39;)args &#61; parser.parse_args()sample(args)def sample(args):with open(os.path.join(args.save_dir, &#39;config.pkl&#39;), &#39;rb&#39;) as f:saved_args &#61; cPickle.load(f)with open(os.path.join(args.save_dir, &#39;chars_vocab.pkl&#39;), &#39;rb&#39;) as f:chars, vocab &#61; cPickle.load(f)model &#61; Model(saved_args, True)with tf.Session() as sess:tf.global_variables_initializer().run()saver &#61; tf.train.Saver(tf.global_variables())ckpt &#61; tf.train.get_checkpoint_state(args.save_dir)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path)print(model.sample(sess, chars, vocab, args.prime.decode(&#39;utf-8&#39;,errors&#61;&#39;ignore&#39;), args.sample))if __name__ &#61;&#61; &#39;__main__&#39;:main()
python sample.py
rnn神经网络会生成一首全新的古诗。例如&#xff1a; ”帝以诚求备&#xff0c;堪留百勇杯。教官日与失&#xff0c;共恨五毛宣。鸡唇春疏叶&#xff0c;空衣滴舞衣。丑夫归晚里&#xff0c;此地几何人。”python sample.py --prime <这里输入指定汉字>
rnn神经网络会利用输入的汉字生成一首藏头诗。例如&#xff1a;python sample.py --prime 如花似月
会得到 “如尔残回号&#xff0c;花枝误晚声。似君星度上&#xff0c;月满二秋寒。”