using rnn_cell inside of tf.while getting ValueError: The two structures don't have the same number of elements
using rnn_cell inside of tf.while getting ValueError: The two structures don't have the same number of elements
给定 data = tf.placeholder(tf.float32, [2, None, 3])
(batch_size * time_step * feature_size),理想情况下我想 tf.unstack(data, axis = 1)
得到许多张量,每个张量都有[2,3]
的形状,所以稍后将它们送入带有 for 循环的 rnn,例如
for rnn_input in rnn_inputs:
state = rnn_cell(rnn_input, state)
使用像 tf.nn.dynamic_rnn 这样的高级 API 是不合适的 table 所以我创建了一个像
这样的变通方法
import tensorflow as tf
data = tf.placeholder(tf.float32, [2, None, 3])
step_number = tf.placeholder(tf.int32, None)
loop_counter_inital = tf.constant(0)
initi_state = tf.zeros([2,3], tf.float32)
def while_condition(loop_counter, rnn_states):
return loop_counter < step_number
def while_body(loop_counter, rnn_states):
loop_counter_current = loop_counter
current_states = tf.gather_nd(data, tf.stack([tf.range(0, 2), tf.zeros([2], tf.int32)+loop_counter_current], axis=1))
cell = tf.nn.rnn_cell.BasicRNNCell(3)
rnn_states = cell(current_states, rnn_states)
return [loop_counter_current, rnn_states]
_, _states = tf.while_loop(while_condition, while_body,
loop_vars=[loop_counter_inital, initi_state],
shape_invariants=[loop_counter_inital.shape, tf.TensorShape([2, 3])])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print (sess.run(_states, feed_dict={data:[[[3,1,6],[4,1,2]],[[5,8,1],[0,5,2]]], step_number:2 }))
想法是遍历 data
的每个二维张量中的每一行,以获得每个时间步长的特征。我收到一个错误
First structure (2 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32>]
Second structure (3 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, (<tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>, <tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>)]
似乎有一些相关的帖子。 None 确实有效。有人可以帮忙吗?
您需要知道每个 BasicRNNCell
都将使用签名 (output, next_state) = call(input, state)
实现 call()
。这意味着您的结果是形状列表 ((?,unit),(?,unit))
。所以你需要做如下操作。
rnn_states = cell(current_states, rnn_states)[1]
你这里也犯了一个错误。您忘记在 loop_counter_current
.
中加 1
return [loop_counter_current+1, rnn_states]
添加
第一个结构体表示你传入的参数loop_vars
的初始值,其中包含loop_counter_inital
和initi_state
的初始值。所以它的结构对应如下
[
<tf.Tensor 'while/Identity:0' shape=() dtype=int32> #---> loop_counter_inital
, <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32> #---> initi_state
]
第二个结构体表示循环后的参数loop_vars
。根据之前的错误,其结果对应如下。
[
<tf.Tensor 'while/Identity:0' shape=() dtype=int32> #---> loop_counter_inital
, (<tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32> #---> output
, <tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>) #---> initi_state
]
给定 data = tf.placeholder(tf.float32, [2, None, 3])
(batch_size * time_step * feature_size),理想情况下我想 tf.unstack(data, axis = 1)
得到许多张量,每个张量都有[2,3]
的形状,所以稍后将它们送入带有 for 循环的 rnn,例如
for rnn_input in rnn_inputs:
state = rnn_cell(rnn_input, state)
使用像 tf.nn.dynamic_rnn 这样的高级 API 是不合适的 table 所以我创建了一个像
这样的变通方法import tensorflow as tf
data = tf.placeholder(tf.float32, [2, None, 3])
step_number = tf.placeholder(tf.int32, None)
loop_counter_inital = tf.constant(0)
initi_state = tf.zeros([2,3], tf.float32)
def while_condition(loop_counter, rnn_states):
return loop_counter < step_number
def while_body(loop_counter, rnn_states):
loop_counter_current = loop_counter
current_states = tf.gather_nd(data, tf.stack([tf.range(0, 2), tf.zeros([2], tf.int32)+loop_counter_current], axis=1))
cell = tf.nn.rnn_cell.BasicRNNCell(3)
rnn_states = cell(current_states, rnn_states)
return [loop_counter_current, rnn_states]
_, _states = tf.while_loop(while_condition, while_body,
loop_vars=[loop_counter_inital, initi_state],
shape_invariants=[loop_counter_inital.shape, tf.TensorShape([2, 3])])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print (sess.run(_states, feed_dict={data:[[[3,1,6],[4,1,2]],[[5,8,1],[0,5,2]]], step_number:2 }))
想法是遍历 data
的每个二维张量中的每一行,以获得每个时间步长的特征。我收到一个错误
First structure (2 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32>]
Second structure (3 elements): [<tf.Tensor 'while/Identity:0' shape=() dtype=int32>, (<tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>, <tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>)]
似乎有一些相关的帖子。 None 确实有效。有人可以帮忙吗?
您需要知道每个 BasicRNNCell
都将使用签名 (output, next_state) = call(input, state)
实现 call()
。这意味着您的结果是形状列表 ((?,unit),(?,unit))
。所以你需要做如下操作。
rnn_states = cell(current_states, rnn_states)[1]
你这里也犯了一个错误。您忘记在 loop_counter_current
.
return [loop_counter_current+1, rnn_states]
添加
第一个结构体表示你传入的参数loop_vars
的初始值,其中包含loop_counter_inital
和initi_state
的初始值。所以它的结构对应如下
[
<tf.Tensor 'while/Identity:0' shape=() dtype=int32> #---> loop_counter_inital
, <tf.Tensor 'while/Identity_1:0' shape=(2, 3) dtype=float32> #---> initi_state
]
第二个结构体表示循环后的参数loop_vars
。根据之前的错误,其结果对应如下。
[
<tf.Tensor 'while/Identity:0' shape=() dtype=int32> #---> loop_counter_inital
, (<tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32> #---> output
, <tf.Tensor 'while/basic_rnn_cell/Tanh:0' shape=(2, 3) dtype=float32>) #---> initi_state
]