
Training an NLP Model to Write Poetry


Around Chinese New Year 2018, Tencent's AI couplet generator drew a lot of attention. I happened to get a request to see whether a model could be trained to write in the style of the Shijing (Classic of Poetry). I asked a colleague for help and was simply sent two GitHub links, with the comment:

I checked: the Shijing has about 38,000 characters in total, so it should be possible to train a language model on it. The only worry is that machine-written poems usually have no soul. https://github.com/hjptriplebee/Chinese_poem_generator; https://github.com/xue2han/AncientChinesePoemRNN.

I tried both. The first one didn't run and I had no time to debug it, so I went straight to the second, and the results were impressive. After only 4000 epochs it already gave me a pleasant surprise; the results are shown below:

 

With results this good, it is worth digging into the NLP techniques behind them while the topic is hot.

Let's start with the technology behind AI couplets.

At the broadest level, intelligent couplet writing belongs to NLP (natural language processing), and couplet composition falls under natural language generation. Domestic research on language generation dates back to the 1990s and has explored many approaches, chiefly template-based methods, random generation-and-test, genetic algorithms, case-based reasoning, and statistical machine translation.

This article takes two representative technical approaches as case studies:

1. The first is generation without domain knowledge: the program knows nothing about what the text means; it simply generates and tests at random from an information-entropy point of view, and in the computer's eyes the text is merely "a random data sequence with entropy **". The technical name is LSTM generation without domain knowledge. An LSTM is a kind of RNN (recurrent neural network) well suited to strongly sequential data such as language, and here information entropy is used as the loss function that drives convergence. Language generated this way rarely stands up to scrutiny and lacks imagery and a coherent theme, mainly because the loss function contains no "culture" and over-emphasizes "information entropy".

(Figure: RNN schematic)
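To make the "information entropy only" point concrete, here is a toy illustration (the vocabulary and probabilities are hypothetical, not from the original post): at every step the model is trained purely to predict the next character, and the per-step loss is just the negative log-probability it assigned to the character that actually follows.

import numpy as np

vocab = [u'春', u'风', u'花', u'月']
probs = np.array([0.1, 0.6, 0.2, 0.1])  # model's predicted distribution over the next character
target = vocab.index(u'风')              # the character that actually comes next

loss = -np.log(probs[target])            # cross-entropy for this step, about 0.51
print(loss)

Nothing in this objective rewards meter, rhyme, or meaning, which is why the output often reads as fluent but soulless.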

2. The second is generation with domain knowledge: domain knowledge about regulated verse, such as meter, rhyme, and thematic imagery, is built into the algorithm. The technical name is statistical machine translation generation based on a topic model. Statistical machine translation is essentially a class of models that map a source language to a target language; here the similarity between the generated couplet and a reference set is used as the loss function for convergence. The drawback is that couplet quality is strongly tied to the reference set, so the output easily collapses into a single style and struggles to develop a style truly "of the machine's own".

(Figure: neural machine translation architecture)

I. LSTM generation without domain knowledge

1) Collect about 6,900 couplet pairs of various kinds from the web.

2) Encode the characters as numbers (the encoder step) and split the data into training and test sets.

3) Define the LSTM model.

4) Train the model with a weighted cross-entropy loss; training stops once the loss settles around 1.5.

5) Generate new couplets automatically, converting the numbers back into characters.

Let's look at the code.

Encode the characters as numbers (the encoder step) and split the data into training and test sets:

import collections
import numpy as np

couplet_file = "couplet.txt"

# Read and filter the couplets
couplets = []
with open(couplet_file, 'r') as f:
    for line in f:
        try:
            content = line.replace(' ', '')
            if '_' in content or '(' in content or '（' in content or '《' in content or '[' in content:
                continue
            if len(content) < 5 * 3 or len(content) > 79 * 3:
                continue
            content = '[' + content + ']'      # '[' and ']' mark the beginning and end of a couplet
            # print chardet.detect(content)
            content = content.decode('utf-8')  # Python 2: decode bytes to unicode
            couplets.append(content)
        except Exception as e:
            pass

# Sort by length
couplets = sorted(couplets, key=lambda line: len(line))
print('Total couplets: %d' % (len(couplets)))

# Count how often each character occurs
all_words = []
for couplet in couplets:
    all_words += [word for word in couplet]
counter = collections.Counter(all_words)
count_pairs = sorted(counter.items(), key=lambda x: -x[1])
words, _ = zip(*count_pairs)
words = words[:len(words)] + (' ',)  # append a blank used for padding

# Map each character to a numeric ID
word_num_map = dict(zip(words, range(len(words))))
to_num = lambda word: word_num_map.get(word, len(words))
couplets_vector = [list(map(to_num, couplet)) for couplet in couplets]

# Train on 64 couplets per batch; this parameter can be tuned
batch_size = 64
n_chunk = len(couplets_vector) // batch_size
x_batches = []
y_batches = []
for i in range(n_chunk):
    start_index = i * batch_size           # start position
    end_index = start_index + batch_size   # end position
    batches = couplets_vector[start_index:end_index]
    length = max(map(len, batches))        # length of the longest couplet in this batch
    xdata = np.full((batch_size, length), word_num_map[' '], np.int32)
    for row in range(batch_size):
        xdata[row, :len(batches[row])] = batches[row]
    ydata = np.copy(xdata)
    ydata[:, :-1] = xdata[:, 1:]           # targets are the inputs shifted left by one character
    x_batches.append(xdata)
    y_batches.append(ydata)

Define the LSTM model (each cell has a 128-dimensional hidden state h_t, and MultiRNNCell stacks two LSTM layers):

import tensorflow as tf

# Placeholders for the input characters and the shifted targets
input_data = tf.placeholder(tf.int32, [batch_size, None])
output_targets = tf.placeholder(tf.int32, [batch_size, None])

def neural_network(rnn_size=128, num_layers=2):
    cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
    initial_state = cell.zero_state(batch_size, tf.float32)
    with tf.variable_scope('rnnlm'):
        softmax_w = tf.get_variable("softmax_w", [rnn_size, len(words) + 1])
        softmax_b = tf.get_variable("softmax_b", [len(words) + 1])
        with tf.device("/cpu:0"):
            embedding = tf.get_variable("embedding", [len(words) + 1, rnn_size])
            inputs = tf.nn.embedding_lookup(embedding, input_data)
    outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=initial_state, scope='rnnlm')
    output = tf.reshape(outputs, [-1, rnn_size])
    logits = tf.matmul(output, softmax_w) + softmax_b
    probs = tf.nn.softmax(logits)
    return logits, last_state, probs, cell, initial_state

 

Train the model with a weighted cross-entropy loss; once the loss settles around 1.5, training is done:

def train_neural_network():
    logits, last_state, _, _, _ = neural_network()
    targets = tf.reshape(output_targets, [-1])
    loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
        [logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], len(words))
    cost = tf.reduce_mean(loss)
    learning_rate = tf.Variable(0.0, trainable=False)
    tvars = tf.trainable_variables()
    grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 5)
    optimizer = tf.train.AdamOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(zip(grads, tvars))
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(tf.all_variables())
        for epoch in range(100):
            sess.run(tf.assign(learning_rate, 0.01 * (0.97 ** epoch)))
            n = 0
            for batche in range(n_chunk):
                train_loss, _, _ = sess.run(
                    [cost, last_state, train_op],
                    feed_dict={input_data: x_batches[n], output_targets: y_batches[n]})
                n += 1
                print(epoch, batche, train_loss)
            if epoch % 7 == 0:
                saver.save(sess, './couplet.module', global_step=epoch)

 

Automatically generate new couplets:

saver.restore(sess, 'couplet.module-98')
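The original post only shows the checkpoint restore. A minimal generation sketch (not the author's code) follows the same pattern as training: assuming batch_size is set to 1 and the input_data placeholder is rebuilt accordingly before the graph is constructed, restore the weights, feed the '[' begin marker, and sample characters from probs until the ']' end marker appears. It reuses the neural_network, words, word_num_map and input_data names defined above.

def gen_couplet():
    # Sample one character from the predicted distribution
    def to_word(weights):
        t = np.cumsum(weights)
        s = np.sum(weights)
        return words[int(np.searchsorted(t, np.random.rand(1) * s))]

    # NOTE: batch_size must be 1 when building the graph for sampling
    _, last_state, probs, cell, initial_state = neural_network()
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(sess, 'couplet.module-98')

        state_ = sess.run(cell.zero_state(1, tf.float32))
        x = np.array([[word_num_map['[']]])   # start from the '[' begin marker
        probs_, state_ = sess.run([probs, last_state],
                                  feed_dict={input_data: x, initial_state: state_})
        word = to_word(probs_[-1])
        couplet = ''
        while word != ']':                    # stop at the ']' end marker
            couplet += word
            x = np.array([[word_num_map.get(word, len(words))]])
            probs_, state_ = sess.run([probs, last_state],
                                      feed_dict={input_data: x, initial_state: state_})
            word = to_word(probs_[-1])
        return couplet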

II. Statistical machine translation generation based on a topic model

1) Prepare the training data for the statistical model.

The training corpus of regulated verse was collected from the internet, including works such as 《唐诗》, 《全唐诗》 and 《全台词》, plus regulated poems crawled and filtered from major poetry forums (for example 诗词在线 and the 天涯论坛诗词比兴 board), more than 287,000 poems in total.

2) Set up the topic model; probabilistic latent semantic analysis (PLSA) is used here.

For example, given the theme word "春日" (spring day), its distribution vector in the latent topic space can be used to find semantically related words that lie nearby in that space, such as "玉魄", "红泥" and "燕".
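As a rough sketch of how such a lookup can work (this is not the paper's implementation; the word_topic matrix and vocab list are assumed to come from a trained PLSA model): once every word has a distribution over K latent topics, related words can be retrieved by cosine similarity in that topic space.

import numpy as np

def related_words(query, word_topic, vocab, topn=5):
    """word_topic: |V| x K matrix whose row i approximates P(z | w_i); vocab: list of words."""
    q = word_topic[vocab.index(query)]
    # cosine similarity between the query word and every word in the latent topic space
    sims = word_topic.dot(q) / (np.linalg.norm(word_topic, axis=1) * np.linalg.norm(q) + 1e-12)
    ranked = [vocab[i] for i in np.argsort(-sims) if vocab[i] != query]
    return ranked[:topn]

# e.g. related_words(u'春日', word_topic, vocab) would be expected to surface words such as u'燕' or u'红泥'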

3) Expand the vocabulary using the topic model.

4) Define the algorithm that generates the first line from the theme words.

5) Define the statistical machine translation model that generates the second, third, and fourth lines.

We adopt phrase-based statistical machine translation (PBSMT), currently a mainstream machine translation technique whose strength is accurate word choice inside translated phrases. Because verse generation emphasizes parallel structure between lines and involves no long-distance reordering, it is particularly well suited to a phrase-based machine translation algorithm. A toy sketch of the scoring idea is given below.
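In this sketch the phrase table and probabilities are made up for illustration only; the real system learns them from aligned line pairs and uses a proper decoder. A candidate next line is scored by the sum of log phrase-translation probabilities plus a language-model term, and the highest-scoring candidate wins.

import math

# Hypothetical phrase table: P(target phrase | source phrase) estimated from aligned line pairs
phrase_table = {
    (u'两个黄鹂', u'一行白鹭'): 0.30,
    (u'鸣翠柳', u'上青天'): 0.25,
}

def score_candidate(src_phrases, tgt_phrases, lm_logprob):
    """Combine phrase-translation log-probabilities with a language-model score."""
    score = lm_logprob
    for s, t in zip(src_phrases, tgt_phrases):
        score += math.log(phrase_table.get((s, t), 1e-6))  # smooth unseen phrase pairs
    return score

# The generated second line is the candidate with the highest combined score.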

6) Evaluate with BLEU; save the model once the results converge.

The intuition behind BLEU is that the closer a translation is to the reference answer, the better it is. Correspondingly, the closer a generated next line is to the existing reference next lines, the better we judge the system's output; but because poetry is rich and varied in expression, samples with multiple reference next lines have to be collected into the answer set. BLEU counts the overlap of 1-grams through N-grams between each generated candidate line and the reference lines for the source line, and combines these statistics with the formula below to score the generation quality.
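For reference, the standard BLEU formulation combines the modified n-gram precisions p_n (usually with uniform weights w_n = 1/N) with a brevity penalty BP, where c is the candidate length and r the reference length:

\mathrm{BLEU} = \mathrm{BP} \cdot \exp\Big(\sum_{n=1}^{N} w_n \log p_n\Big), \qquad
\mathrm{BP} = \begin{cases} 1 & c > r \\ e^{\,1 - r/c} & c \le r \end{cases}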

7) Given theme words, generate a new regulated poem.

To download the full paper, reply 20180216 on the WeChat official account.

References

RNN and LSTM fundamentals: http://www.jianshu.com/p/9dc9f41f0b29

Writing Spring Festival couplets with LSTM deep learning: http://blog.csdn.net/leadai/article/details/79015862

Chinese regulated-verse generation based on topic models and statistical machine translation: 蒋锐滢, 崔磊, 何晶, 周明, 潘志庚

 

Now for the code.

 

Following [char-rnn-tensorflow](https://github.com/sherjilozair/char-rnn-tensorflow), a character-level RNN model is used to learn and generate classical Chinese poems.
The data comes from http://www16.zzu.edu.cn/qts/, more than 40,000 Tang poems in total.

  • tensorflow 1.0
  • python2

 

First, the training data, poems.txt. An excerpt:

煌煌道宫，肃肃太清。礼光尊祖，乐备充庭。罄竭诚至，希夷降灵。云凝翠盖，风焰红旌。众真以从，九奏初迎。永惟休v，是锡和平。
种瓜黄台下，瓜熟子离离。一摘使瓜好，再摘使瓜稀。三摘犹自可，摘绝抱蔓归。

 

Notice that titles and authors have been stripped out. This matters: the results I got from training on the Shijing were bizarre, most likely because I had not removed the titles. A minimal cleanup sketch is shown below.
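In this sketch the raw file name and the "title:author:body" layout are assumptions about the raw corpus, not the actual format used by the repository; adjust the split logic to whatever the source file really looks like.

# -*- coding: utf-8 -*-
import codecs

def strip_title_author(raw_line):
    # Keep only the poem body, assuming fields are separated by ':'
    parts = raw_line.strip().split(':')
    return parts[-1] if parts else ''

with codecs.open('poems_raw.txt', 'r', encoding='utf-8') as fin, \
        codecs.open('poems.txt', 'w', encoding='utf-8') as fout:
    for raw in fin:
        body = strip_title_author(raw)
        if body:
            fout.write(body + '\n')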

The core training code, train.py:

from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os, sys
from six.moves import cPickle
from utils import TextLoader
from model import Model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save', help='directory to store checkpointed models')
    parser.add_argument('--rnn_size', type=int, default=128, help='size of RNN hidden state')
    parser.add_argument('--num_layers', type=int, default=2, help='number of layers in the RNN')
    parser.add_argument('--model', type=str, default='lstm', help='rnn, gru, or lstm')
    parser.add_argument('--batch_size', type=int, default=64, help='minibatch size')
    parser.add_argument('--num_epochs', type=int, default=50, help='number of epochs')
    parser.add_argument('--save_every', type=int, default=1000, help='save frequency')
    parser.add_argument('--grad_clip', type=float, default=5., help='clip gradients at this value')
    parser.add_argument('--learning_rate', type=float, default=0.002, help='learning rate')
    parser.add_argument('--decay_rate', type=float, default=0.97, help='decay rate for rmsprop')
    parser.add_argument('--init_from', type=str, default=None,
                        help="""continue training from saved model at this path. Path must contain files saved by previous training process:
                            'config.pkl'      : configuration;
                            'chars_vocab.pkl' : vocabulary definitions;
                            'iterations'      : number of trained iterations;
                            'losses-*'        : train loss;
                            'checkpoint'      : paths to model file(s) (created by tf).
                                                Note: this file contains absolute paths, be careful when moving files around;
                            'model.ckpt-*'    : file(s) with model definition (created by tf)""")
    args = parser.parse_args()
    train(args)


def train(args):
    data_loader = TextLoader(args.batch_size)
    args.vocab_size = data_loader.vocab_size
    # check compatibility if training is continued from previously saved model
    if args.init_from is not None:
        # check if all necessary files exist
        assert os.path.isdir(args.init_from), " %s must be a a path" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "config.pkl")), "config.pkl file does not exist in path %s" % args.init_from
        assert os.path.isfile(os.path.join(args.init_from, "chars_vocab.pkl")), "chars_vocab.pkl file does not exist in path %s" % args.init_from
        ckpt = tf.train.get_checkpoint_state(args.init_from)
        assert ckpt, "No checkpoint found"
        assert ckpt.model_checkpoint_path, "No model path found in checkpoint"
        assert os.path.isfile(os.path.join(args.init_from, "iterations")), "iterations file does not exist in path %s " % args.init_from
        # open old config and check if models are compatible
        with open(os.path.join(args.init_from, 'config.pkl'), 'rb') as f:
            saved_model_args = cPickle.load(f)
        need_be_same = ["model", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_args)[checkme] == vars(args)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        # open saved vocab/dict and check if vocabs/dicts are compatible
        with open(os.path.join(args.init_from, 'chars_vocab.pkl'), 'rb') as f:
            saved_chars, saved_vocab = cPickle.load(f)
        assert saved_chars == data_loader.chars, "Data and loaded model disagree on character set!"
        assert saved_vocab == data_loader.vocab, "Data and loaded model disagree on dictionary mappings!"

    with open(os.path.join(args.save_dir, 'config.pkl'), 'wb') as f:
        cPickle.dump(args, f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'wb') as f:
        cPickle.dump((data_loader.chars, data_loader.vocab), f)

    model = Model(args)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        iterations = 0
        # restore model and number of iterations
        if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
            with open(os.path.join(args.save_dir, 'iterations'), 'rb') as f:
                iterations = cPickle.load(f)
        losses = []
        for e in range(args.num_epochs):
            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
            data_loader.reset_batch_pointer()
            for b in range(data_loader.num_batches):
                iterations += 1
                start = time.time()
                x, y = data_loader.next_batch()
                feed = {model.input_data: x, model.targets: y}
                train_loss, _, _ = sess.run([model.cost, model.final_state, model.train_op], feed)
                end = time.time()
                sys.stdout.write('\r')
                info = "{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(e * data_loader.num_batches + b,
                            args.num_epochs * data_loader.num_batches,
                            e, train_loss, end - start)
                sys.stdout.write(info)
                sys.stdout.flush()
                losses.append(train_loss)
                if (e * data_loader.num_batches + b) % args.save_every == 0 \
                        or (e == args.num_epochs - 1 and b == data_loader.num_batches - 1):  # save for the last result
                    checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=iterations)
                    with open(os.path.join(args.save_dir, "iterations"), 'wb') as f:
                        cPickle.dump(iterations, f)
                    with open(os.path.join(args.save_dir, "losses-" + str(iterations)), 'wb') as f:
                        cPickle.dump(losses, f)
                    losses = []
                    sys.stdout.write('\n')
                    print("model saved to {}".format(checkpoint_path))
            sys.stdout.write('\n')


if __name__ == '__main__':
    main()

Next, model.py:

# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.contrib import legacy_seq2seq
import numpy as np


class Model():
    def __init__(self, args, infer=False):
        self.args = args
        if infer:
            args.batch_size = 1
        if args.model == 'rnn':
            cell_fn = rnn.BasicRNNCell
        elif args.model == 'gru':
            cell_fn = rnn.GRUCell
        elif args.model == 'lstm':
            cell_fn = rnn.BasicLSTMCell
        else:
            raise Exception("model type not supported: {}".format(args.model))
        cell = cell_fn(args.rnn_size, state_is_tuple=False)
        self.cell = cell = rnn.MultiRNNCell([cell] * args.num_layers, state_is_tuple=False)
        self.input_data = tf.placeholder(tf.int32, [args.batch_size, None])  # the length of input sequence is variable.
        self.targets = tf.placeholder(tf.int32, [args.batch_size, None])
        self.initial_state = cell.zero_state(args.batch_size, tf.float32)
        with tf.variable_scope('rnnlm'):
            softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
            softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
            with tf.device("/cpu:0"):
                embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
                inputs = tf.nn.embedding_lookup(embedding, self.input_data)
        outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=self.initial_state, scope='rnnlm')
        output = tf.reshape(outputs, [-1, args.rnn_size])
        self.logits = tf.matmul(output, softmax_w) + softmax_b
        self.probs = tf.nn.softmax(self.logits)
        targets = tf.reshape(self.targets, [-1])
        loss = legacy_seq2seq.sequence_loss_by_example(
            [self.logits], [targets], [tf.ones_like(targets, dtype=tf.float32)], args.vocab_size)
        self.cost = tf.reduce_mean(loss)
        self.final_state = last_state
        self.lr = tf.Variable(0.0, trainable=False)
        tvars = tf.trainable_variables()
        grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars), args.grad_clip)
        optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(grads, tvars))

    def sample(self, sess, chars, vocab, prime=u'', sampling_type=1):
        def pick_char(weights):
            if sampling_type == 0:
                sample = np.argmax(weights)
            else:
                t = np.cumsum(weights)
                s = np.sum(weights)
                sample = int(np.searchsorted(t, np.random.rand(1) * s))
            return chars[sample]

        for char in prime:
            if char not in vocab:
                return u"{} is not in charset!".format(char)

        if not prime:
            # no seed characters: generate a whole poem from the begin marker
            state = self.cell.zero_state(1, tf.float32).eval()
            prime = u'^'
            result = u''
            x = np.array([list(map(vocab.get, prime))])
            [probs, state] = sess.run([self.probs, self.final_state],
                                      {self.input_data: x, self.initial_state: state})
            char = pick_char(probs[-1])
            while char != u'$':
                result += char
                x = np.zeros((1, 1))
                x[0, 0] = vocab[char]
                [probs, state] = sess.run([self.probs, self.final_state],
                                          {self.input_data: x, self.initial_state: state})
                char = pick_char(probs[-1])
            return result
        else:
            # seed characters given: generate an acrostic poem, one clause per seed character
            result = u'^'
            for prime_char in prime:
                result += prime_char
                x = np.array([list(map(vocab.get, result))])
                state = self.cell.zero_state(1, tf.float32).eval()
                [probs, state] = sess.run([self.probs, self.final_state],
                                          {self.input_data: x, self.initial_state: state})
                char = pick_char(probs[-1])
                while char != u'，' and char != u'。':
                    result += char
                    x = np.zeros((1, 1))
                    x[0, 0] = vocab[char]
                    [probs, state] = sess.run([self.probs, self.final_state],
                                              {self.input_data: x, self.initial_state: state})
                    char = pick_char(probs[-1])
                result += char
            return result[1:]

Data preprocessing, utils.py:

# -*- coding: utf-8 -*-
import codecs
import os
import collections
from six.moves import cPickle, reduce, map
import numpy as np

BEGIN_CHAR = '^'
END_CHAR = '$'
UNKNOWN_CHAR = '*'
MAX_LENGTH = 100


class TextLoader():
    def __init__(self, batch_size, max_vocabsize=3000, encoding='utf-8'):
        self.batch_size = batch_size
        self.max_vocabsize = max_vocabsize
        self.encoding = encoding
        data_dir = './data'
        input_file = os.path.join(data_dir, "shijing.txt")
        vocab_file = os.path.join(data_dir, "vocab.pkl")
        tensor_file = os.path.join(data_dir, "data.npy")
        if not (os.path.exists(vocab_file) and os.path.exists(tensor_file)):
            print("reading text file")
            self.preprocess(input_file, vocab_file, tensor_file)
        else:
            print("loading preprocessed files")
            self.load_preprocessed(vocab_file, tensor_file)
        self.create_batches()
        self.reset_batch_pointer()

    def preprocess(self, input_file, vocab_file, tensor_file):
        def handle_poem(line):
            line = line.replace(' ', '')
            if len(line) >= MAX_LENGTH:
                index_end = line.rfind(u'。', 0, MAX_LENGTH)
                index_end = index_end if index_end > 0 else MAX_LENGTH
                line = line[:index_end + 1]
            return BEGIN_CHAR + line + END_CHAR

        with codecs.open(input_file, "r", encoding=self.encoding) as f:
            lines = list(map(handle_poem, f.read().strip().split('\n')))
        counter = collections.Counter(reduce(lambda data, line: line + data, lines, ''))
        count_pairs = sorted(counter.items(), key=lambda x: -x[1])
        chars, _ = zip(*count_pairs)
        self.vocab_size = min(len(chars), self.max_vocabsize - 1) + 1
        self.chars = chars[:self.vocab_size - 1] + (UNKNOWN_CHAR,)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        unknown_char_int = self.vocab.get(UNKNOWN_CHAR)
        with open(vocab_file, 'wb') as f:
            cPickle.dump(self.chars, f)
        get_int = lambda char: self.vocab.get(char, unknown_char_int)
        lines = sorted(lines, key=lambda line: len(line))
        self.tensor = [list(map(get_int, line)) for line in lines]
        with open(tensor_file, 'wb') as f:
            cPickle.dump(self.tensor, f)

    def load_preprocessed(self, vocab_file, tensor_file):
        with open(vocab_file, 'rb') as f:
            self.chars = cPickle.load(f)
        with open(tensor_file, 'rb') as f:
            self.tensor = cPickle.load(f)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))

    def create_batches(self):
        self.num_batches = int(len(self.tensor) / self.batch_size)
        self.tensor = self.tensor[:self.num_batches * self.batch_size]
        unknown_char_int = self.vocab.get(UNKNOWN_CHAR)
        self.x_batches = []
        self.y_batches = []
        for i in range(self.num_batches):
            from_index = i * self.batch_size
            to_index = from_index + self.batch_size
            batches = self.tensor[from_index:to_index]
            seq_length = max(map(len, batches))
            xdata = np.full((self.batch_size, seq_length), unknown_char_int, np.int32)
            for row in range(self.batch_size):
                xdata[row, :len(batches[row])] = batches[row]
            ydata = np.copy(xdata)
            ydata[:, :-1] = xdata[:, 1:]
            self.x_batches.append(xdata)
            self.y_batches.append(ydata)

    def next_batch(self):
        x, y = self.x_batches[self.pointer], self.y_batches[self.pointer]
        self.pointer += 1
        return x, y

    def reset_batch_pointer(self):
        self.pointer = 0

 

The sampling script, sample.py:

# -*- coding: utf-8 -*-
from __future__ import print_function
import numpy as np
import tensorflow as tf
import argparse
import time
import os
from six.moves import cPickle
from utils import TextLoader
from model import Model
from six import text_type


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--save_dir', type=str, default='save', help='model directory to store checkpointed models')
    parser.add_argument('--prime', type=str, default='', help='seed characters for generating an acrostic poem')
    parser.add_argument('--sample', type=int, default=1, help='0 to use max at each timestep, 1 to sample at each timestep')
    args = parser.parse_args()
    sample(args)


def sample(args):
    with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f:
        saved_args = cPickle.load(f)
    with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f:
        chars, vocab = cPickle.load(f)
    model = Model(saved_args, True)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        saver = tf.train.Saver(tf.global_variables())
        ckpt = tf.train.get_checkpoint_state(args.save_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            print(model.sample(sess, chars, vocab, args.prime.decode('utf-8', errors='ignore'), args.sample))


if __name__ == '__main__':
    main()

  • python sample.py — the RNN generates a brand-new poem, for example: "帝以诚求备，堪留百勇杯。教官日与失，共恨五毛宣。鸡唇春疏叶，空衣滴舞衣。丑夫归晚里，此地几何人。"
  • python sample.py --prime <characters> — the RNN builds an acrostic poem from the given characters. For example, python sample.py --prime 如花似月 produces "如尔残回号，花枝误晚声。似君星度上，月满二秋寒。"


Reposted from: https://www.cnblogs.com/Anita9002/p/9106063.html


