如何在 tensorflow-js 中获取张量中特定值的索引?

How to get indices of a specific value in a tensor in tensorflow-js?

例如,如果我有一个 [[1,3],[2,1]] 的二维张量,我如何获得值 1 的索引? (应该return[[0,0],[1,1]])。

我查看了 tf.where,但 API 很复杂,我认为这不能解决我的问题

您可以使用 tf.whereAsync 来实现。

只需创建一个掩码,检查输入 Tensor 中的值是否为 1 并将它们转换为布尔值。

掩码:

"Tensor
    [[true , false],
     [false, true ]]"

tf.whereAsync() returns 条件的真实元素的坐标,在这种情况下来自掩码。

(async function getData() {
  const x = tf.tensor2d([[1, 3], [2, 1]])

  const mask = x.equal([1]).asType('bool');
  const coords = await tf.whereAsync(mask);
  coords.print();
}());

输入:

"Tensor
    [[1, 3],
     [2, 1]]"

输出:

"Tensor
    [[0, 0],
     [1, 1]]"