在for循环以外的批处理数据中查找数字首次出现索引的有效方法
Efficient method for finding index of first occurrence of a number in batched data other than for loop
我正在执行一项任务,其中我有批量存储的帧形式的数据。批次的维度就像 (batch_size,400),我想在每个 400 长度的帧中找到第一次出现数字 1 的索引。
目前我正在对批量大小使用 for 循环,但由于数据非常大,因此非常耗时
在 tensorflow 或 numpy 中使用某些矩阵运算的任何其他有效方法都会
在 TensorFlow 中:
import tensorflow as tf
def index_of_first_tf(batch, value):
eq = tf.equal(batch, value)
has_value = tf.reduce_any(eq, axis=-1)
_, idx = tf.math.top_k(tf.cast(eq, tf.int8))
idx = tf.squeeze(idx, -1)
return tf.where(has_value, idx, -tf.ones_like(idx))
在 NumPy 中:
import numpy as np
def index_of_first_np(batch, value):
eq = np.equal(batch, value)
has_value = np.any(eq, axis=-1)
idx = np.argmax(eq, axis=-1)
idx[~has_value] = -1
return idx
测试:
import tensorflow as tf
batch = [[0, 1, 2, 3],
[1, 2, 1, 0],
[0, 2, 3, 4]]
value = 1
print(index_of_first_np(batch, value))
# [ 1 0 -1]
with tf.Graph().as_default(), tf.Session() as sess:
print(sess.run(index_of_first_tf(batch, value)))
# [ 1 0 -1]
我正在执行一项任务,其中我有批量存储的帧形式的数据。批次的维度就像 (batch_size,400),我想在每个 400 长度的帧中找到第一次出现数字 1 的索引。
目前我正在对批量大小使用 for 循环,但由于数据非常大,因此非常耗时
在 tensorflow 或 numpy 中使用某些矩阵运算的任何其他有效方法都会
在 TensorFlow 中:
import tensorflow as tf
def index_of_first_tf(batch, value):
eq = tf.equal(batch, value)
has_value = tf.reduce_any(eq, axis=-1)
_, idx = tf.math.top_k(tf.cast(eq, tf.int8))
idx = tf.squeeze(idx, -1)
return tf.where(has_value, idx, -tf.ones_like(idx))
在 NumPy 中:
import numpy as np
def index_of_first_np(batch, value):
eq = np.equal(batch, value)
has_value = np.any(eq, axis=-1)
idx = np.argmax(eq, axis=-1)
idx[~has_value] = -1
return idx
测试:
import tensorflow as tf
batch = [[0, 1, 2, 3],
[1, 2, 1, 0],
[0, 2, 3, 4]]
value = 1
print(index_of_first_np(batch, value))
# [ 1 0 -1]
with tf.Graph().as_default(), tf.Session() as sess:
print(sess.run(index_of_first_tf(batch, value)))
# [ 1 0 -1]