Author: 鱼儿什么都知道丶 | Source: Internet | 2023-09-04 03:25
The values are:
rnn_size: 512
batch_size: 128
rnn_inputs: Tensor("embedding_lookup/Identity_1:0", shape=(?, ?, 128), dtype=float32)
sequence_length: Tensor("inputs_length:0", shape=(?,), dtype=int32)
cell_fw:
cell_bw:
enc_state is obtained from:
enc_output, enc_state = tf.compat.v1.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, rnn_inputs, sequence_length, dtype=tf.float32)
The resulting enc_state is:
enc_state: LSTMStateTuple(c=, h=)
TF1 code:
initial_state = tf.contrib.seq2seq.DynamicAttentionWrapperState(
    enc_state, _zero_state_tensors(rnn_size, batch_size, tf.float32))
Converted to TF2:
initial_state = tfa.seq2seq.AttentionWrapper(enc_state,_zero_state_tensors(rnn_size, batch_size, tf.float32))
This raises the following error:
TypeError Traceback (most recent call last)
in ()
8 threshold)
9 model = build_graph(keep_probability, rnn_size, num_layers, batch_size,
---> 10 learning_rate, embedding_size, direction)
11 train(model, epochs, log_string)
6 frames
/usr/local/lib/python3.7/dist-packages/typeguard/__init__.py in check_type(argname, value, expected_type, memo)
596 raise TypeError(
597 'type of {} must be {}; got {} instead'.
--> 598 format(argname, qualified_name(expected_type), qualified_name(value)))
599 elif isinstance(expected_type, TypeVar):
600 # Only happens on <3.6
TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
Could you also explain the last line of the error, namely:
TypeError: type of argument "cell" must be tensorflow.python.keras.engine.base_layer.Layer; got tensorflow.python.keras.layers.legacy_rnn.rnn_cell_impl.LSTMStateTuple instead
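The last line says that the first argument of `tfa.seq2seq.AttentionWrapper`, named `cell`, is type-checked (by typeguard, per the traceback) to be a Keras `Layer`, that is, an RNN cell. It received `enc_state` instead, an `LSTMStateTuple`, which is a plain namedtuple holding the `c` and `h` state tensors. In the TF1 code, `DynamicAttentionWrapperState` was a state container, so passing `enc_state` first made sense there; `AttentionWrapper` is a cell wrapper, so the same argument no longer fits. The pure-Python sketch below reproduces the check. `Layer`, `LSTMCell`, `attention_wrapper` and the simplified `check_type` are hypothetical stand-ins, not the real TensorFlow or typeguard code, and the real message prints fully qualified class names.

```python
from collections import namedtuple


# Hypothetical stand-ins, NOT the real TensorFlow classes:
# - Layer plays the role of tf.keras.layers.Layer (the expected type);
# - LSTMCell plays the role of an RNN cell, which subclasses Layer;
# - LSTMStateTuple mirrors TensorFlow's namedtuple of (c, h) state tensors.
class Layer:
    pass


class LSTMCell(Layer):
    pass


LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


def check_type(argname, value, expected_type):
    """Simplified isinstance check, worded like typeguard's error message."""
    if not isinstance(value, expected_type):
        raise TypeError(
            "type of {} must be {}; got {} instead".format(
                argname, expected_type.__name__, type(value).__name__
            )
        )


def attention_wrapper(cell, attention_mechanism=None):
    """Hypothetical stand-in for a constructor whose first parameter is
    annotated as a Layer (the wrapped cell), not as a state."""
    check_type('argument "cell"', cell, Layer)
    return cell


# Passing an encoder *state* where a *cell* is expected reproduces the error.
enc_state = LSTMStateTuple(c=[0.0] * 4, h=[0.0] * 4)
message = None
try:
    attention_wrapper(enc_state)
except TypeError as err:
    message = str(err)

print(message)  # type of argument "cell" must be Layer; got LSTMStateTuple instead

# Passing a cell (a Layer subclass) passes the check.
wrapped = attention_wrapper(LSTMCell())
```

In TensorFlow Addons, the usual pattern (worth checking against the documentation for your installed version) is to wrap the decoder cell, as in `AttentionWrapper(decoder_cell, attention_mechanism, ...)`, and then derive the initial state from the wrapper with `get_initial_state(batch_size=batch_size, dtype=tf.float32).clone(cell_state=enc_state)`, rather than passing `enc_state` to the constructor.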