How to iterate over a tf.Tensor in eager mode
I'm trying to iterate over a tensor in eager mode, but I can't get it to work.
Naturally, you would do something like this:
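import numpy as np
import tensorflow as tf

probs = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
indexs = tf.convert_to_tensor(np.array([1, 2, 3]))  # note: index 3 is out of range for rows of length 3

@tf.function
def iterate_tensor(probs, indexs):
    # Pick probs[i, indexs[i]] for each row i
    return [output[label] for output, label in zip(probs, indexs)]

iterate_tensor(probs, indexs)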
But this raises the error OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed:
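A likely cause, assuming TensorFlow 2.x: `@tf.function` traces the Python function into a graph, so inside it `probs` and `indexs` are symbolic tensors rather than eager ones, and Python's `zip` cannot iterate over them. The sketch below contrasts the same list comprehension run without the decorator (truly eager) with a graph-friendly loop over `tf.range` that AutoGraph can convert, collecting results in a `tf.TensorArray`. It uses indices `[0, 1, 2]` because index 3 in the original `indexs` is out of range; the function names are hypothetical.

```python
import numpy as np
import tensorflow as tf

probs = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
# Assumption: 0-based indices that are in range for rows of length 3.
indexs = tf.convert_to_tensor(np.array([0, 1, 2]))


def iterate_eager(probs, indexs):
    # No @tf.function: this runs eagerly, where a tf.Tensor can be
    # iterated row by row like a Python sequence.
    # int(label) turns the scalar tensor into a plain Python index.
    return [output[int(label)] for output, label in zip(probs, indexs)]


@tf.function
def iterate_graph(probs, indexs):
    # Inside tf.function, iterate over tf.range so AutoGraph converts the
    # loop into a tf.while_loop; per-row results go into a TensorArray.
    idx = tf.cast(indexs, tf.int32)
    n = tf.shape(probs)[0]
    result = tf.TensorArray(dtype=probs.dtype, size=n)
    for i in tf.range(n):
        result = result.write(i, probs[i, idx[i]])
    return result.stack()


eager_values = iterate_eager(probs, indexs)  # Python list of scalar tensors
graph_values = iterate_graph(probs, indexs)  # one rank-1 tensor
print([int(v) for v in eager_values])  # [1, 5, 9]
print(graph_values.numpy())            # [1 5 9]
```

The eager version is fine for small inputs; the `TensorArray` loop is the pattern that keeps working once the function is decorated.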
Another thing I tried:
probs = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
indexs = tf.convert_to_tensor(np.array([1, 2, 3]))

@tf.function
def iterate_tensor(probs, indexs):
    return tf.map_fn(lambda i: i[0][i[1]], (probs, indexs), dtype=(tf.int64, tf.int64))

iterate_tensor(probs, indexs)
This raises the error ValueError: The two structures don't have the same nested structure.
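In `tf.map_fn`, the `dtype` argument describes the structure of what `fn` *returns*, not of its input. The lambda returns a single scalar, but `dtype=(tf.int64, tf.int64)` declares a two-element tuple, which is probably where the structure mismatch comes from. A minimal sketch, assuming TensorFlow 2.3 or later, where `fn_output_signature` supersedes the deprecated `dtype` (on older versions, `dtype=tf.int64` plays the same role). Using `tf.gather` for the per-row lookup is a choice of this sketch, and the indices are again the in-range `[0, 1, 2]`.

```python
import numpy as np
import tensorflow as tf

probs = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
indexs = tf.convert_to_tensor(np.array([0, 1, 2]))  # in-range indices


@tf.function
def iterate_tensor(probs, indexs):
    # map_fn unstacks both tensors along axis 0, so fn receives one
    # (row, index) pair per step and returns ONE scalar; the output
    # signature is therefore a single dtype, not a (dtype, dtype) tuple.
    return tf.map_fn(
        lambda elems: tf.gather(elems[0], elems[1]),
        (probs, indexs),
        fn_output_signature=probs.dtype,
    )


print(iterate_tensor(probs, indexs).numpy())  # [1 5 9]
```

Passing `probs.dtype` rather than a hard-coded `tf.int64` keeps the sketch working on platforms where NumPy's default integer is 32-bit.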
This workaround seems to work:
probs = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
indexs = tf.convert_to_tensor(np.array([1, 1, 1]))

@tf.function
def iterate_tensor(probs, indexs):
    return tf.linalg.diag_part(tf.gather(probs, indexs, axis=1))

iterate_tensor(probs, indexs)
Output: <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 5, 8])>
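The `diag_part` workaround is correct in general: `tf.gather(probs, indexs, axis=1)[i, j]` equals `probs[i, indexs[j]]`, so its diagonal is `probs[i, indexs[i]]`. It does, however, build an n x n intermediate just to keep n values. A sketch of two direct alternatives, assuming TensorFlow 2.x: `tf.gather` with `batch_dims=1`, which pairs row `i` with `indexs[i]`, and `tf.gather_nd` with explicit (row, column) coordinates. The function names are hypothetical; the data is the `[1, 1, 1]` example from the question.

```python
import numpy as np
import tensorflow as tf

probs = tf.convert_to_tensor(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
indexs = tf.convert_to_tensor(np.array([1, 1, 1]))


@tf.function
def pick_with_gather(probs, indexs):
    # batch_dims=1 (axis defaults to batch_dims) pairs row i of probs with
    # entry i of indexs, so result[i] == probs[i, indexs[i]].
    return tf.gather(probs, indexs, batch_dims=1)


@tf.function
def pick_with_gather_nd(probs, indexs):
    # Build explicit coordinates [[0, c0], [1, c1], ...] and gather them.
    rows = tf.range(tf.shape(probs)[0])
    coords = tf.stack([rows, tf.cast(indexs, tf.int32)], axis=1)
    return tf.gather_nd(probs, coords)


@tf.function
def pick_with_diag(probs, indexs):
    # The workaround from the question, kept for comparison.
    return tf.linalg.diag_part(tf.gather(probs, indexs, axis=1))


print(pick_with_gather(probs, indexs).numpy())     # [2 5 8]
print(pick_with_gather_nd(probs, indexs).numpy())  # [2 5 8]
print(pick_with_diag(probs, indexs).numpy())       # [2 5 8]
```

All three agree here; the first two only materialize n values instead of n x n.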