Tensorflow:在 tensorflow 中创建等效的 torch.gather()

Tensorflow: Create the torch.gather() equivalent in tensorflow

我想在 TensorFlow 2.X 中复现 torch.gather() 函数。我有一个张量 A（形状：[2, 4, 3]）和一个相应的索引张量 I（形状：[2, 2, 3]）。使用 torch.gather() 会得到以下结果：

# Source tensor, shape [2, 3, 3].
A = torch.tensor([
    [[10, 20, 30], [100, 200, 300], [1000, 2000, 3000]],
    [[50, 60, 70], [500, 600, 700], [5000, 6000, 7000]],
])
# Row positions to pick along dim 1, shape [2, 2, 3] (= output shape).
I = torch.tensor([
    [[0, 1, 0], [1, 2, 1]],
    [[2, 1, 2], [1, 0, 1]],
])
# out[i][j][k] = A[i][I[i][j][k]][k]
torch.gather(A, dim=1, index=I)

>
tensor([[[  10,  200,   30], [ 100, 2000,  300]],
        [[5000,  600, 7000], [ 500,   60,  700]]])

我试过使用 tf.gather(),但这并没有产生类似 pytorch 的结果。我也试过 tf.gather_nd(),但找不到合适的解决方案。

我找到了这篇 Stack Overflow 帖子，但它似乎对我不起作用。

编辑: 使用 tf.gather_nd(A, I) 时,我得到以下结果:

# gather_nd reads each innermost length-3 vector of I as a full [i, j, k]
# index into A, so the result below has shape [2, 2] instead of [2, 2, 3].
# I[1, 0] = [2, 1, 2] is out of range for A's first axis (size 2); per the
# tf.gather_nd docs that raises on CPU and yields 0 on GPU (hence the 0).
tf.gather_nd(A, I)

>
[[100, 6000],
 [  0,   60]]

tf.gather(A, I) 的结果相当冗长。它的形状为 [2, 2, 3, 4, 3]

torch.gather 和 tf.gather_nd 的工作方式不同，因此使用相同的索引张量会得到不同的结果（在某些情况下还会报错）。torch.gather 的索引只给出沿 dim 轴的位置，而 tf.gather_nd 要求索引张量的最内层是一个完整的坐标 [i, j, k]。因此，要得到相同的结果，索引张量必须写成下面这样：

import tensorflow as tf

# Source tensor, shape [2, 3, 3].
A = tf.constant([
    [[10, 20, 30], [100, 200, 300], [1000, 2000, 3000]],
    [[50, 60, 70], [500, 600, 700], [5000, 6000, 7000]],
])

# gather_nd needs a full [batch, row, col] coordinate for every output
# element, so this index tensor has shape [2, 2, 3, 3]: the output shape
# [2, 2, 3] plus one trailing axis holding the coordinate.
I = tf.constant([
    [[[0, 0, 0], [0, 1, 1], [0, 0, 2]],
     [[0, 1, 0], [0, 2, 1], [0, 1, 2]]],
    [[[1, 2, 0], [1, 1, 1], [1, 2, 2]],
     [[1, 1, 0], [1, 0, 1], [1, 1, 2]]],
])


print(tf.gather_nd(A, I))
tf.Tensor(
[[[  10  200   30]
  [ 100 2000  300]]

 [[5000  600 7000]
  [ 500   60  700]]], shape=(2, 2, 3), dtype=int32)

所以，问题实际上在于您的索引是如何计算出来的——还是说它们总是硬编码的？另外，请参阅这篇帖子，了解这两种操作的区别。

至于您链接的那篇帖子中的方法对您不起作用的问题，您只需对索引做一下转换，就应该没有问题了：

def torch_gather(x, indices, gather_axis):
    """Gather values of ``x`` along ``gather_axis`` like ``torch.gather``.

    Every output element keeps its own coordinate on all axes except
    ``gather_axis``, where the coordinate is taken from ``indices``. For a
    rank-3 input and ``gather_axis == 1`` this is
    ``out[i][j][k] = x[i][indices[i][j][k]][k]``.

    Args:
        x: Source tensor.
        indices: Integer tensor with the same rank as ``x``; the output has
            the shape of ``indices``.
        gather_axis: Axis to gather along. Negative values count from the
            end, as in ``torch.gather``.

    Returns:
        A tensor with the shape of ``indices`` and the dtype of ``x``.

    Raises:
        ValueError: If ``gather_axis`` is outside ``[-rank, rank)``.
    """
    rank = len(indices.shape)
    # A negative (or out-of-range) axis would never equal a loop index below,
    # which silently produced ungathered values; validate and normalise it.
    if not -rank <= gather_axis < rank:
        raise ValueError(
            f"gather_axis {gather_axis} is out of range for rank-{rank} indices")
    if gather_axis < 0:
        gather_axis += rank

    # Dynamic shape, so partially unknown dimensions (e.g. inside a
    # tf.function) are handled as well as fully static ones.
    index_shape = tf.shape(indices)
    # Coordinates of every element of `indices`, shape [N, rank], row-major.
    all_indices = tf.where(tf.fill(index_shape, True))
    gather_locations = tf.cast(tf.reshape(indices, [-1]), dtype=tf.int64)

    # Full coordinate per output element: own position on every axis except
    # gather_axis, where the value stored in `indices` is used instead.
    gather_indices = tf.stack(
        [gather_locations if axis == gather_axis
         else tf.cast(all_indices[:, axis], dtype=tf.int64)
         for axis in range(rank)],
        axis=-1)
    gathered = tf.gather_nd(x, gather_indices)
    return tf.reshape(gathered, index_shape)

# Source tensor, shape [2, 3, 3].
A = tf.constant([
    [[10, 20, 30], [100, 200, 300], [1000, 2000, 3000]],
    [[50, 60, 70], [500, 600, 700], [5000, 6000, 7000]],
])
# torch-style indices: positions along axis 1, same shape as the output.
I = tf.constant([
    [[0, 1, 0], [1, 2, 1]],
    [[2, 1, 2], [1, 0, 1]],
])
print(torch_gather(A, I, gather_axis=1))
tf.Tensor(
[[[  10  200   30]
  [ 100 2000  300]]

 [[5000  600 7000]
  [ 500   60  700]]], shape=(2, 2, 3), dtype=int32)

您也可以使用下面这个与 torch.gather 等价的实现：

import random
import numpy as np
import tensorflow as tf
import torch

# torch.gather equivalent
def tf_gather(x: tf.Tensor, indices: tf.Tensor, axis: int) -> tf.Tensor:
    """Gather values of ``x`` along ``axis`` like ``torch.gather``.

    The index arithmetic is done in NumPy (``indices`` is converted with
    ``np.asarray``), so this only works on concrete tensors in eager mode,
    not inside a ``tf.function``.

    Args:
        x: Source tensor with a fully known shape.
        indices: Integer tensor with the same rank as ``x``; the output has
            the shape of ``indices``.
        axis: Axis to gather along; negative values count from the end.

    Returns:
        A tensor with the shape of ``indices`` and the dtype of ``x``.

    Raises:
        ValueError: If an index is negative or out of range for ``x``
            (raised by ``np.ravel_multi_index``).
    """
    index_array = np.asarray(indices)
    # Coordinates of every position in `indices`, one row per axis, row-major.
    # Unlike np.where(indices > -1) this does not depend on the index values,
    # so a negative index reaches the range check in ravel_multi_index
    # instead of causing an unrelated shape-mismatch error.
    coords = np.indices(index_array.shape).reshape(index_array.ndim, -1)
    # torch.gather semantics: the coordinate along `axis` comes from `indices`.
    coords[axis] = index_array.reshape(-1)
    flat_ind = np.ravel_multi_index(tuple(coords), tuple(x.shape))
    return tf.reshape(tf.gather(tf.reshape(x, [-1]), flat_ind), index_array.shape)


# ======= test program ========
if __name__ == '__main__':
    # Compare np.max (ground truth), torch.gather and tf_gather on the argmax
    # indices of a random array, for every axis of that array.

    # ANSI SGR escape sequences (ESC = \033) for coloured terminal output.
    BOLD_RED = '\033[1m\033[31m'
    BOLD_GREEN = '\033[1m\033[32m'
    RESET = '\033[0m'

    # Fixed seed so a failing case can be reproduced.
    a = np.random.default_rng(0).random((2, 5, 3, 4))
    keepdim = False

    # Every valid axis, negative ones included (torch.gather accepts them).
    for dim in range(-a.ndim, a.ndim):
        ind = np.expand_dims(np.argmax(a, axis=dim), axis=dim)

        # ========== np: groundtruth ==========
        np_max = np.expand_dims(np.max(a, axis=dim), axis=dim)

        # ========= torch: gather =========
        torch_max = torch.gather(torch.tensor(a), dim=dim, index=torch.tensor(ind))

        # ========= tensorflow: torch-like gather =========
        tf_max = tf_gather(tf.convert_to_tensor(a), axis=dim, indices=tf.convert_to_tensor(ind))

        if not keepdim:
            np_max = np.squeeze(np_max, axis=dim)
            torch_max = torch.squeeze(torch_max, dim=dim)
            tf_max = tf.squeeze(tf_max, axis=dim)

        assert np.allclose(np_max, torch_max.numpy()), \
            f'{BOLD_RED}Error with torch (dim={dim}){RESET}'
        assert np.allclose(np_max, tf_max.numpy()), \
            f'{BOLD_RED}Error with tensorflow (dim={dim}){RESET}'

    print(f'{BOLD_GREEN}Success!{RESET}')