热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

将tf1中的代码转换为tf2时出错

值在哪里rnn_size:512batch_size:128rnn_inputs:Tensor(embedding_lookup/Iden

值在哪里

rnn_size: 512
batch_size: 128
rnn_inputs: Tensor("embedding_lookup/Identity_1:0", shape=(?, ?, 128), dtype=float32)
sequence_length: Tensor("inputs_length:0", shape=(?,), dtype=int32)
cell_fw:
cell_bw:

获取 enc_state 值

enc_output, enc_state = tf.compat.v1.nn.bidirectional_dynamic_rnn(cell_fw,
cell_bw,
rnn_inputs,
sequence_length,
dtype=tf.float32)

enc_state 值在哪里

enc_state: LSTMStateTuple(c=, h=)

TF1代码:

initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(enc_state,
_zero_state_tensors(rnn_size,
batch_size,
tf.float32))

转换为 TF2

initial_state = tfa.seq2seq.AttentionWrapper(enc_state,_zero_state_tensors(rnn_size, batch_size, tf.float32))

获取错误:


TypeError Traceback (most recent call last)
in ()
8 threshold)
9 model = build_graph(keep_probability, rnn_size, num_layers, batch_size,
---> 10 learning_rate, embedding_size, direction)
11 train(model, epochs, log_string)
6 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname, value, expected_type, memo)
596 raise TypeError(
597 'type of {} must be {}; got {} instead'.
--> 598 format(argname, qualified_name(expected_type), qualified_name(value)))
599 elif isinstance(expected_type, TypeVar):
600 # Only happens on <3.6
TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead

您还可以解释错误的最后一行,即

TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead


推荐阅读
author-avatar
鱼儿什么都知道丶
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有