Using tf.where() to select from a 3d tensor by a 2d condition, and replacing elements at 2d indices with keys and values

There are two questions in the title. Both confuse me because TensorFlow is such a static framework (I really want to go back to PyTorch or Chainer).

I give two examples below. Please answer with TensorFlow code or point me to the relevant functions.

1) tf.where()

data0 = tf.zeros([2, 3, 4], dtype = tf.float32)
data1 = tf.ones([2, 3, 4], dtype = tf.float32)
cond = tf.constant([[0, 1, 1], [1, 0, 0]])
# cond.shape == (2, 3)
# tf.where() works for 1d condition with 2d data, 
# but not for 2d indices with 3d tensor
# currently, what I am doing is:
#    cond = tf.stack([cond] * 4, 2)
data = tf.where(cond > 0, data1, data0)
# data has shape (2, 3, 4); every slice data[:, :, j] should be [[0., 1., 1.], [1., 0., 0.]]

(I do not know how to broadcast cond to a 3d tensor.)

2) Changing elements in a 2d tensor

# all dtype == tf.int64
t2d = tf.Variable([[0, 1, 2], [3, 4, 5]])
k, v = tf.constant([[0, 2], [1, 0]]), tf.constant([-2, -3])
# TODO: change values at positions k to v
# I cannot do [t2d.copy()[i] = j for i, j in k, v]
t3d == [[[0, 1, -2], [3, 4, 5]],
        [[0, 1, 2], [-3, 4, 5]]]

Thanks in advance. XD

These are two quite different questions, and they should probably have been posted separately, but anyway.

1)

Yes, you need to manually broadcast all the inputs to [tf.where](https://www.tensorflow.org/api_docs/python/tf/where) if their shapes differ. For what it's worth, there is an (old) open issue about it, but so far implicit broadcasting has not been implemented. You can use tf.stack like you suggest, although tf.tile would probably be more obvious (and may save memory, although I am not sure how it is actually implemented):

cond = tf.tile(tf.expand_dims(cond, -1), (1, 1, 4))

Or simply with tf.broadcast_to:

cond = tf.broadcast_to(tf.expand_dims(cond, -1), tf.shape(data1))
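Newer releases relax this. As a minimal sketch, assuming TensorFlow 2.x with eager execution (an assumption, since the snippets here use the 1.x API): in 2.x, tf.where broadcasts its arguments, so giving cond a trailing axis of size 1 should be enough and no explicit tiling is needed. The name mask is only illustrative.

```python
import tensorflow as tf

data0 = tf.zeros([2, 3, 4], dtype=tf.float32)
data1 = tf.ones([2, 3, 4], dtype=tf.float32)
cond = tf.constant([[0, 1, 1], [1, 0, 0]])

# Add a trailing axis so the boolean mask has shape (2, 3, 1).
# Assumption: TF 2.x, where tf.where broadcasts the mask against
# the (2, 3, 4) data tensors.
mask = tf.expand_dims(cond > 0, -1)
data = tf.where(mask, data1, data0)

# The result keeps the 3d shape; each slice along the last axis follows cond.
print(data.shape)
print(data[:, :, 0].numpy())
```

On the 1.x API the explicit tf.tile or tf.broadcast_to step is still required.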

2)

This is one way to do it:

import tensorflow as tf

t2d = tf.constant([[0, 1, 2], [3, 4, 5]])
k, v = tf.constant([[0, 2], [1, 0]]), tf.constant([-2, -3])
# Tile t2d
n = tf.shape(k)[0]
t2d_tile = tf.tile(tf.expand_dims(t2d, 0), (n, 1, 1))
# Add additional coordinate to index
idx = tf.concat([tf.expand_dims(tf.range(n), 1), k], axis=1)
# Make updates tensor
s = tf.shape(t2d_tile)
t2d_upd = tf.scatter_nd(idx, v, s)
# Make updates mask
upd_mask = tf.scatter_nd(idx, tf.ones_like(v, dtype=tf.bool), s)
# Make final tensor
t3d = tf.where(upd_mask, t2d_upd, t2d_tile)
# Test
with tf.Session() as sess:
    print(sess.run(t3d))

Output:

[[[ 0  1 -2]
  [ 3  4  5]]

 [[ 0  1  2]
  [-3  4  5]]]
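The scatter, mask, and where steps above can probably be collapsed into a single call. A minimal sketch, assuming TensorFlow 2.x with eager execution (so no tf.Session) and the tf.tensor_scatter_nd_update function; the tiling and index construction are the same as in the code above.

```python
import tensorflow as tf

t2d = tf.constant([[0, 1, 2], [3, 4, 5]])
k = tf.constant([[0, 2], [1, 0]])
v = tf.constant([-2, -3])

n = tf.shape(k)[0]
# One copy of t2d per (index, value) pair: shape (n, 2, 3).
t2d_tile = tf.tile(tf.expand_dims(t2d, 0), (n, 1, 1))
# Prefix each 2d index with the number of the copy it applies to.
idx = tf.concat([tf.expand_dims(tf.range(n), 1), k], axis=1)
# Write v[i] at position k[i] of copy i; all other entries are kept.
t3d = tf.tensor_scatter_nd_update(t2d_tile, idx, v)

print(t3d.numpy())
```

This avoids building the separate boolean mask, at the cost of depending on a function that older 1.x releases may not have.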