
Training an NLP model to write poetry


For the 2018 New Year, Tencent rolled out an AI Spring Festival couplet generator that drew a lot of attention. Right around then I got a request to see whether I could train a model to write in the style of the Shijing (the Classic of Poetry). I asked a friend for help, and he simply sent back two GitHub links; his exact words were:

"I checked — the Shijing has about 38,000 characters in total, so it should be possible to train a language model on it. I'm just afraid machine-written poems usually have no soul. https://github.com/hjptriplebee/Chinese_poem_generator; https://github.com/xue2han/AncientChinesePoemRNN"

I tried both. The first one didn't run for me and I had no time to investigate, so I went straight to the second, and the results were great. After only 4,000 training epochs it already gave me a pleasant surprise; the output looked like this:

 

Results this good deserve a closer look at the NLP techniques behind them while the topic is hot.

First, a look at the technology behind AI couplets.

At a high level, intelligent couplet generation belongs to NLP (natural language processing), and more specifically to language generation. Research on language generation in China dates back to the 1990s, and a range of approaches has been explored since: template-based methods, random generate-and-test, genetic algorithms, case-based reasoning, and statistical machine translation, among others.

This post walks through two representative technical approaches:

1. The first is "generation without culture": the program understands nothing about the meaning of the text. It simply generates and tests sequences from an information-entropy point of view; in the computer's eyes the text is just "a random data sequence with entropy **". The technical name is LSTM generation without domain knowledge. An LSTM is a kind of RNN (recurrent neural network) well suited to strongly sequential data such as language; here information entropy (cross-entropy) is used as the loss function driving convergence. Text generated this way often does not hold up to scrutiny — it lacks imagery and a coherent theme — mainly because the loss function contains no "culture" and over-emphasizes "information entropy".

(Figure: RNN schematic)

2. The second is "generation with culture": domain knowledge about regulated verse — meter and rhyme, theme and imagery — is built into the algorithm. The technical name is statistical machine translation generation based on a topic model. Statistical machine translation is a class of models that map a source language to a target language; here the similarity between the generated couplet and a reference set serves as the loss function for convergence. The drawback is that couplet quality is strongly tied to the reference set, so the system easily settles into a single style and struggles to create a style that is truly "the machine's own".

(Figure: neural machine translation architecture)

I. LSTM generation without domain knowledge

1) Collect about 6,900 couplet pairs of various kinds from the web

2) Encode the Chinese characters as integers (the "encoder" step) and split the data into training and test sets

3) Define the LSTM model

4) Train the model with a weighted cross-entropy loss; stop training once the loss settles around 1.5

5) Automatically generate new couplets, mapping the integer IDs back to characters

Now for the code.

Encode the characters as integers and split the data into training batches:

import collections
import numpy as np

couplet_file = "couplet.txt"   # couplet corpus, one couplet per line
couplets = []
with open(couplet_file, 'r') as f:
    for line in f:
        try:
            content = line.replace(' ', '')
            # skip lines containing annotations or bracket characters
            if '_' in content or '(' in content or '（' in content or '《' in content or '[' in content:
                continue
            if len(content) < 5 * 3 or len(content) > 79 * 3:
                continue
            content = '[' + content + ']'        # '[' and ']' mark the start and end of a couplet
            content = content.decode('utf-8')    # Python 2: decode bytes to unicode
            couplets.append(content)
        except Exception as e:
            pass

# sort by length
couplets = sorted(couplets, key=lambda line: len(line))
print('对联总数: %d' % (len(couplets)))

# count how often each character appears
all_words = []
for couplet in couplets:
    all_words += [word for word in couplet]
counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
words = words[:len(words)] + (' ',)

# map every character to an integer ID
word_num_map = dict(zip(words, range(len(words))))
to_num = lambda word: word_num_map.get(word, len(words))
couplets_vector = [list(map(to_num, couplet)) for couplet in couplets]

# train on 64 couplets per batch; this parameter can be tuned
batch_size = 64
n_chunk = len(couplets_vector) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
    start_index = i * batch_size            # start of the batch
    end_index = start_index + batch_size    # end of the batch
    batches = couplets_vector[start_index:end_index]
    length = max(map(len, batches))         # longest couplet in this batch
    xdata = np.full((batch_size, length), word_num_map[' '], np.int32)
    for row in range(batch_size):
        xdata[row, :len(batches[row])] = batches[row]
    ydata = np.copy(xdata)
    ydata[:, :-1] = xdata[:, 1:]            # targets are the inputs shifted left by one character
    x_batches.append(xdata)
    y_batches.append(ydata)

Define the LSTM model (each cell has a 128-dimensional hidden state h_t, and MultiRNNCell stacks it into a two-layer LSTM):

def neural_network(rnn_size=128, num_layers=2):
    cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
    initial_state = cell.zero_state(batch_size, tf.float32)
    with tf.variable_scope('rnnlm'):
        softmax_w = tf.get_variable("softmax_w", [rnn_size, len(words) + 1])
        softmax_b = tf.get_variable("softmax_b", [len(words) + 1])
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [len(words) + 1, rnn_size])
            inputs = tf.nn.embedding_lookup(embedding, input_data)
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, scope='rnnlm')
    output = tf.reshape(outputs, [-1, rnn_size])
    logits = tf.matmul(output, softmax_w) + softmax_b
    probs = tf.nn.softmax(logits)
    return logits, last_state, probs, cell, initial_state

 

Train the model with the weighted cross-entropy loss; stop once the loss settles around 1.5:

def train_neural_network():
    logits, last_state, _, _, _ = neural_network()
    targets = tf.reshape(output_targets, [-1])
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
        [logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
    cost = tf.reduce_mean(loss)
    learning_rate = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(tf.all_variables())
        for epoch in range(100):
            sess.run(tf.assign(learning_rate, 0.01 * (0.97 ** epoch)))
            n = 0
            for batche in range(n_chunk):
                train_loss, _, _ = sess.run(
                    [cost, last_state, train_op],
                    feed_dict={input_data: x_batches[n], output_targets: y_batches[n]})
                n += 1
                print(epoch, batche, train_loss)
            if epoch % 7 == 0:
                saver.save(sess, './couplet.module', global_step=epoch)

 

Automatically generate new couplets:

saver.restore(sess, 'couplet.module-98')
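The original post stops at restoring the checkpoint. As a rough sketch of the missing generation step (not code from the original), the loop below re-feeds the growing character sequence at every step and samples the next character from the softmax output; the batch_size = 1 rebuild, the input_data / words / word_num_map globals and the 50-character safety cap are assumptions carried over from the preprocessing code above.

batch_size = 1   # assumption: rebuild the graph with batch size 1 for generation
input_data = tf.placeholder(tf.int32, [batch_size, None])

def gen_couplet():
    _, last_state, probs, cell, initial_state = neural_network()
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(sess, 'couplet.module-98')

        def to_word(weights):
            # sample one character id from the output distribution
            t = np.cumsum(weights)
            idx = int(np.searchsorted(t, np.random.rand(1) * np.sum(weights)))
            return words[idx]

        poem = '['                                   # '[' marks the start of a couplet
        while True:
            x = np.array([[word_num_map[w] for w in poem]])
            probs_ = sess.run(probs, feed_dict={input_data: x})
            word = to_word(probs_[-1])               # distribution at the last time step
            if word == ']' or len(poem) > 50:        # ']' marks the end; 50 is an arbitrary cap
                break
            poem += word
        return poem[1:]                              # drop the leading '['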

II. Statistical machine translation generation based on a topic model

1) Prepare the training data for the statistical models

The regulated-verse training corpus comes from the internet, including collections such as 《唐诗》, 《全唐诗》 and 《全台词》, plus regulated poems crawled and filtered from major poetry forums (e.g. 诗词在线 and the Tianya 诗词比兴 board) — more than 287,000 poems in total.

2) Set up the topic model — here, probabilistic latent semantic analysis (PLSA)

For example, given the topic word 春日 ("spring day"), its distribution vector in the latent topic space can be used to find semantically related words that are close in that space, such as 玉魄, 红泥 and 燕.
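As a toy sketch of that nearest-neighbour lookup (not the paper's implementation): assume PLSA training has already produced a word-by-topic matrix whose rows are P(topic | word); related words are then simply the ones whose rows are closest, e.g. by cosine similarity. The vocabulary and numbers below are made up.

import numpy as np

def related_words(keyword, vocab, word_topic, topk=3):
    # cosine similarity between the keyword's topic vector and every other word's
    idx = vocab.index(keyword)
    v = word_topic[idx]
    norms = np.linalg.norm(word_topic, axis=1) * np.linalg.norm(v) + 1e-12
    sims = np.dot(word_topic, v) / norms
    order = np.argsort(-sims)
    return [vocab[i] for i in order if i != idx][:topk]

# toy data: 4 words, 3 latent topics (rows are P(topic | word))
vocab = [u'春日', u'玉魄', u'红泥', u'燕']
word_topic = np.array([[0.70, 0.20, 0.10],
                       [0.60, 0.30, 0.10],
                       [0.50, 0.30, 0.20],
                       [0.65, 0.25, 0.10]])
print(u'、'.join(related_words(u'春日', vocab, word_topic)))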

3) Expand the lexicon with the topic model

4) Define the algorithm that generates the first line from the topic words

5) Define the statistical-machine-translation model that generates the second, third and fourth lines

Phrase-based statistical machine translation (PBSMT) is used here. It is currently a mainstream machine translation technique whose strength is accurate word choice when translating phrases. Because poem generation emphasizes parallel structure and involves no long-distance reordering, it is a natural fit for a phrase-based machine translation algorithm.
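As a toy illustration of the decoding idea (not the paper's system): each candidate next line is scored by a log-linear combination of a phrase-translation score and a language-model score, and the highest-scoring candidate wins. The phrase table, language-model scores and weights below are invented for the example.

import math

# toy phrase table: probability of a target phrase given the corresponding source phrase
phrase_table = {
    (u'白日', u'黄河'): 0.4, (u'白日', u'青山'): 0.3,
    (u'依山尽', u'入海流'): 0.5, (u'依山尽', u'绕郭流'): 0.2,
}
# toy language-model log-probabilities for whole candidate lines
lm_logprob = {u'黄河入海流': -2.1, u'青山绕郭流': -3.0}

def score(source_phrases, target_phrases, w_tm=1.0, w_lm=0.7):
    # log-linear model: weighted sum of translation-model and language-model features
    tm = sum(math.log(phrase_table.get((s, t), 1e-6))
             for s, t in zip(source_phrases, target_phrases))
    lm = lm_logprob.get(u''.join(target_phrases), -10.0)
    return w_tm * tm + w_lm * lm

source = [u'白日', u'依山尽']
candidates = [[u'黄河', u'入海流'], [u'青山', u'绕郭流']]
best = max(candidates, key=lambda c: score(source, c))
print(u''.join(best))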

6) Evaluate with BLEU and save the model once the results converge

The intuition behind BLEU is that the closer a translation is to the reference answers, the better it is. Correspondingly, we assume the closer a generated next line is to the existing reference next lines, the better the system's output. Because poetry is rich and varied in expression, however, data samples with multiple reference next lines have to be collected into the answer set. BLEU counts the overlap of 1-gram up to N-gram matches between the generated candidate and the reference lines for the source line, and combines them with the formula below to measure output quality.
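The formula the post refers to did not survive the copy; the standard BLEU definition it is describing is (with $p_n$ the modified n-gram precision, $w_n$ the n-gram weights, usually $1/N$, and $c$, $r$ the candidate and reference lengths):

$$\mathrm{BLEU} = \mathrm{BP}\cdot\exp\Big(\sum_{n=1}^{N} w_n \log p_n\Big),\qquad \mathrm{BP}=\begin{cases}1 & c>r\\ e^{\,1-r/c} & c\le r\end{cases}$$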

7) Given topic words, generate a new regulated poem

To download the full paper, reply 20180216 on the official WeChat account.

References

RNN and LSTM fundamentals: http://www.jianshu.com/p/9dc9f41f0b29

Writing Spring Festival couplets with deep LSTMs: http://blog.csdn.net/leadai/article/details/79015862

Automatic generation of Chinese regulated verse based on topic models and statistical machine translation — 蒋锐滢, 崔磊, 何晶, 周明, 潘志庚

 

Now on to the code I actually used.

 

Following [char-rnn-tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow), a character-level RNN model is used to learn and generate classical poems.
The data comes from http://www16.zzu.edu.cn/qts/, more than 40,000 Tang poems in total.

  • tensorflow 1.0
  • python2

 

First, the training data, poems.txt — an excerpt:

煌煌道宫，肃肃太清。礼光尊祖，乐备充庭。罄竭诚至，希夷降灵。云凝翠盖，风焰红旌。众真以从，九奏初迎。永惟休v，是锡和平。
种瓜黄台下，瓜熟子离离。一摘使瓜好，再摘使瓜稀。三摘犹自可，摘绝抱蔓归。

 

Notice that titles and authors have been stripped out. This matters: the output of my own Shijing-trained model was bizarre, most likely because I had not removed the titles.
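As a sketch of that kind of cleanup — purely illustrative, since the layout of the raw Shijing file is an assumption here (adjust the split to whatever your source actually contains):

# -*- coding:utf-8 -*-
import codecs

# assumed raw layout: one poem per line, "title:author:body"; keep only the body
with codecs.open('raw_shijing.txt', 'r', encoding='utf-8') as fin, \
     codecs.open('shijing.txt', 'w', encoding='utf-8') as fout:
    for line in fin:
        line = line.strip()
        if not line:
            continue
        body = line.split(u':')[-1]
        fout.write(body + u'\n')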

The core training code, train.py:

from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os, sys
from six.moves import cPickle

from utils import TextLoader
from model import Model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save',
                        help='directory to store checkpointed models')
    parser.add_argument('--rnn_size', type=int, default=128,
                        help='size of RNN hidden state')
    parser.add_argument('--num_layers', type=int, default=2,
                        help='number of layers in the RNN')
    parser.add_argument('--model', type=str, default='lstm',
                        help='rnn, gru, or lstm')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='minibatch size')
    parser.add_argument('--num_epochs', type=int, default=50,
                        help='number of epochs')
    parser.add_argument('--save_every', type=int, default=1000,
                        help='save frequency')
    parser.add_argument('--grad_clip', type=float, default=5.,
                        help='clip gradients at this value')
    parser.add_argument('--learning_rate', type=float, default=0.002,
                        help='learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.97,
                        help='decay rate for rmsprop')
    parser.add_argument('--init_from', type=str, default=None,
                        help="""continue training from saved model at this path. Path must contain files saved by previous training process:
                            'config.pkl'      : configuration;
                            'chars_vocab.pkl' : vocabulary definitions;
                            'iterations'      : number of trained iterations;
                            'losses-*'        : train loss;
                            'checkpoint'      : paths to model file(s) (created by tf).
                                                Note: this file contains absolute paths, be careful when moving files around;
                            'model.ckpt-*'    : file(s) with model definition (created by tf)""")
    args = parser.parse_args()
    train(args)


def train(args):
    data_loader = TextLoader(args.batch_size)
    args.vocab_size = data_loader.vocab_size

    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from), " %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "chars_vocab.pkl")), "chars_vocab.pkl.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"
        assert os.path.isfile(os.path.join(args.init_from, "iterations")), "iterations file does not exist in path %s " % args.init_from

        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(args)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        iterations = 0
        # restore model and number of iterations
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
            with open(os.path.join(args.save_dir, 'iterations'), 'rb') as f:
                iterations = cPickle.load(f)
        losses = []
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            for b in range(data_loader.num_batches):
                iterations += 1
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                train_loss, _, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start)
                sys.stdout.write(info)
                sys.stdout.flush()
                losses.append(train_loss)
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e == args.num_epochs - 1 and b == data_loader.num_batches - 1):
                    # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=iterations)
                    with open(os.path.join(args.save_dir, "iterations"), 'wb') as f:
                        cPickle.dump(iterations, f)
                    with open(os.path.join(args.save_dir, "losses-" + str(iterations)), 'wb') as f:
                        cPickle.dump(losses, f)
                    losses = []
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
        sys.stdout.write('\n')


if __name__ == '__main__':
    main()

Next, model.py:

# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.contrib import legacy_seq2seq
import numpy as np


class Model():
    def __init__(self, args, infer=False):
        self.args = args
        if infer:
            args.batch_size = 1

        if args.model == 'rnn':
            cell_fn = rnn.BasicRNNCell
        elif args.model == 'gru':
            cell_fn = rnn.GRUCell
        elif args.model == 'lstm':
            cell_fn = rnn.BasicLSTMCell
        else:
            raise Exception("model type not supported: {}".format(args.model))

        cell = cell_fn(args.rnn_size, state_is_tuple=False)
        self.cell = cell = rnn.MultiRNNCell([cell] * args.num_layers, state_is_tuple=False)

        # the length of the input sequence is variable
        self.input_data = tf.placeholder(tf.int32, [args.batch_size, None])
        self.targets = tf.placeholder(tf.int32, [args.batch_size, None])
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)

        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                inputs = tf.nn.embedding_lookup(embedding, self.input_data)

        outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=self.initial_state, scope='rnnlm')
        output = tf.reshape(outputs, [-1, args.rnn_size])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)

        targets = tf.reshape(self.targets, [-1])
        loss = legacy_seq2seq.sequence_loss_by_example([self.logits],
                                                       [targets],
                                                       [tf.ones_like(targets, dtype=tf.float32)],
                                                       args.vocab_size)
        self.cost = tf.reduce_mean(loss)
        self.final_state = last_state
        self.lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), args.grad_clip)
        optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

    def sample(self, sess, chars, vocab, prime=u'', sampling_type=1):
        def pick_char(weights):
            if sampling_type == 0:
                sample = np.argmax(weights)
            else:
                t = np.cumsum(weights)
                s = np.sum(weights)
                sample = int(np.searchsorted(t, np.random.rand(1) * s))
            return chars[sample]

        for char in prime:
            if char not in vocab:
                return u"{} is not in charset!".format(char)

        if not prime:
            # no prime given: start from the begin marker and sample until the end marker
            state = self.cell.zero_state(1, tf.float32).eval()
            prime = u'^'
            result = u''
            x = np.array([list(map(vocab.get, prime))])
            [probs, state] = sess.run([self.probs, self.final_state],
                                      {self.input_data: x, self.initial_state: state})
            char = pick_char(probs[-1])
            while char != u'$':
                result += char
                x = np.zeros((1, 1))
                x[0, 0] = vocab[char]
                [probs, state] = sess.run([self.probs, self.final_state],
                                          {self.input_data: x, self.initial_state: state})
                char = pick_char(probs[-1])
            return result
        else:
            # acrostic mode: each prime character starts a new line
            result = u'^'
            for prime_char in prime:
                result += prime_char
                x = np.array([list(map(vocab.get, result))])
                state = self.cell.zero_state(1, tf.float32).eval()
                [probs, state] = sess.run([self.probs, self.final_state],
                                          {self.input_data: x, self.initial_state: state})
                char = pick_char(probs[-1])
                while char != u'，' and char != u'。':
                    result += char
                    x = np.zeros((1, 1))
                    x[0, 0] = vocab[char]
                    [probs, state] = sess.run([self.probs, self.final_state],
                                              {self.input_data: x, self.initial_state: state})
                    char = pick_char(probs[-1])
                result += char
            return result[1:]

Data preprocessing, utils.py:

# -*- coding:utf-8 -*-
import codecs
import os
import collections
from six.moves import cPickle, reduce, map
import numpy as np

BEGIN_CHAR = '^'
END_CHAR = '$'
UNKNOWN_CHAR = '*'
MAX_LENGTH = 100


class TextLoader():
    def __init__(self, batch_size, max_vocabsize=3000, encoding='utf-8'):
        self.batch_size = batch_size
        self.max_vocabsize = max_vocabsize
        self.encoding = encoding

        data_dir = './data'
        input_file = os.path.join(data_dir, "shijing.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")

        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)
        self.create_batches()
        self.reset_batch_pointer()

    def preprocess(self, input_file, vocab_file, tensor_file):
        def handle_poem(line):
            line = line.replace(' ', '')
            if len(line) >= MAX_LENGTH:
                # truncate at the last full stop before MAX_LENGTH
                index_end = line.rfind(u'。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH
                line = line[:index_end + 1]
            return BEGIN_CHAR + line + END_CHAR

        with codecs.open(input_file, "r", encoding=self.encoding) as f:
            lines = list(map(handle_poem, f.read().strip().split('\n')))

        counter = collections.Counter(reduce(lambda data, line: line + data, lines, ''))
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        chars, _ = zip(*count_pairs)
        self.vocab_size = min(len(chars), self.max_vocabsize - 1) + 1
        self.chars = chars[:self.vocab_size - 1] + (UNKNOWN_CHAR,)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        unknown_char_int = self.vocab.get(UNKNOWN_CHAR)
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        get_int = lambda char: self.vocab.get(char, unknown_char_int)
        lines = sorted(lines, key=lambda line: len(line))
        self.tensor = [list(map(get_int, line)) for line in lines]
        with open(tensor_file, 'wb') as f:
            cPickle.dump(self.tensor, f)

    def load_preprocessed(self, vocab_file, tensor_file):
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        with open(tensor_file, 'rb') as f:
            self.tensor = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))

    def create_batches(self):
        self.num_batches = int(len(self.tensor) / self.batch_size)
        self.tensor = self.tensor[:self.num_batches * self.batch_size]
        unknown_char_int = self.vocab.get(UNKNOWN_CHAR)
        self.x_batches = []
        self.y_batches = []
        for i in range(self.num_batches):
            from_index = i * self.batch_size
            to_index = from_index + self.batch_size
            batches = self.tensor[from_index:to_index]
            seq_length = max(map(len, batches))
            # pad every poem in the batch to the same length with the unknown-char id
            xdata = np.full((self.batch_size, seq_length), unknown_char_int, np.int32)
            for row in range(self.batch_size):
                xdata[row, :len(batches[row])] = batches[row]
            ydata = np.copy(xdata)
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)

    def next_batch(self):
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
        self.pointer = 0

 

Sampling script, sample.py:

# -*- coding:utf-8 -*-
from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os
from six.moves import cPickle
from six import text_type

from utils import TextLoader
from model import Model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save',
                        help='model directory to store checkpointed models')
    parser.add_argument('--prime', type=str, default='',
                        help=u'输入指定文字生成藏头诗')
    parser.add_argument('--sample', type=int, default=1,
                        help='0 to use max at each timestep, 1 to sample at each timestep')
    args = parser.parse_args()
    sample(args)


def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print(model.sample(sess, chars, vocab,
                               args.prime.decode('utf-8', errors='ignore'), args.sample))


if __name__ == '__main__':
    main()

  • python sample.py — the RNN generates a brand-new classical poem, for example: "帝以诚求备，堪留百勇杯。教官日与失，共恨五毛宣。鸡唇春疏叶，空衣滴舞衣。丑夫归晚里，此地几何人。"
  • python sample.py --prime <characters of your choice> — the RNN builds an acrostic poem from the given characters; for example, python sample.py --prime 如花似月 yields "如尔残回号，花枝误晚声。似君星度上，月满二秋寒。"

Reposted from: https://www.cnblogs.com/Anita9002/p/9106063.html


