如何在 tensorflow-js 中获取张量中特定值的索引?
How to get indices of a specific value in a tensor in tensorflow-js?
例如,如果我有一个 [[1,3],[2,1]]
的二维张量,我如何获得值 1
的索引? (应该return[[0,0],[1,1]]
)。
我查看了 tf.where
,但 API 很复杂,我认为这不能解决我的问题
您可以使用 tf.whereAsync 来实现。
只需创建一个掩码,检查输入 Tensor 中的值是否为 1 并将它们转换为布尔值。
掩码:
"Tensor
[[true , false],
[false, true ]]"
tf.whereAsync()
returns 条件的真实元素的坐标,在这种情况下来自掩码。
(async function getData() {
const x = tf.tensor2d([[1, 3], [2, 1]])
const mask = x.equal([1]).asType('bool');
const coords = await tf.whereAsync(mask);
coords.print();
}());
输入:
"Tensor
[[1, 3],
[2, 1]]"
输出:
"Tensor
[[0, 0],
[1, 1]]"
例如,如果我有一个 [[1,3],[2,1]]
的二维张量,我如何获得值 1
的索引? (应该return[[0,0],[1,1]]
)。
我查看了 tf.where
,但 API 很复杂,我认为这不能解决我的问题
您可以使用 tf.whereAsync 来实现。
只需创建一个掩码,检查输入 Tensor 中的值是否为 1 并将它们转换为布尔值。
掩码:
"Tensor
[[true , false],
[false, true ]]"
tf.whereAsync()
returns 条件的真实元素的坐标,在这种情况下来自掩码。
(async function getData() {
const x = tf.tensor2d([[1, 3], [2, 1]])
const mask = x.equal([1]).asType('bool');
const coords = await tf.whereAsync(mask);
coords.print();
}());
输入:
"Tensor
[[1, 3],
[2, 1]]"
输出:
"Tensor
[[0, 0],
[1, 1]]"