Tensorflow 2 - what is 'index depth' in tensor_scatter_nd_update?
Please explain what the index depth of tf.tensor_scatter_nd_update is.
tf.tensor_scatter_nd_update(
tensor, indices, updates, name=None
)
Why is indices 2D for a 1D tensor?
indices has at least two axes; the last axis is the depth of the index vectors. For a higher-rank input tensor, scalar updates can be inserted by using an index_depth that matches tf.rank(tensor):
tensor = [0, 0, 0, 0, 0, 0, 0, 0] # tf.rank(tensor) == 1
indices = [[1], [3], [4], [7]] # num_updates == 4, index_depth == 1 # <--- what is depth and why 2D for 1D tensor?
updates = [9, 10, 11, 12] # num_updates == 4
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
tensor = [[1, 1], [1, 1], [1, 1]] # tf.rank(tensor) == 2
indices = [[0, 1], [2, 0]] # num_updates == 2, index_depth == 2
updates = [5, 10] # num_updates == 2
print(tf.tensor_scatter_nd_update(tensor, indices, updates))
For indices, the index depth is the size (length) of each index vector. For example:
indicesA = [[1], [3], [4], [7]] # index vector with 1 element: index_depth = 1
indicesB = [[0, 1], [2, 0]] # index vector with 2 elements: index_depth = 2
indices is 2D because it carries two pieces of information: the number of updates (num_updates) and the length of each index vector (index_depth). Two conditions must hold:
The index depth of indices must equal the rank of the input tensor.
The length of updates must equal the length of indices.
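Both conditions can be seen in a pure-Python sketch of the scalar-update case. This is only an emulation of the op's semantics, not TensorFlow's implementation, and the helper name scatter_nd_update is made up for illustration:

```python
import copy

def scatter_nd_update(tensor, indices, updates):
    # Emulates the scalar-update case, where index_depth == tf.rank(tensor):
    # each index vector pinpoints exactly one element of the tensor.
    assert len(updates) == len(indices)   # condition 2: one update per index vector
    out = copy.deepcopy(tensor)
    for idx, val in zip(indices, updates):
        target = out
        for axis in idx[:-1]:             # walk down all but the last axis
            target = target[axis]
        target[idx[-1]] = val             # condition 1: idx addresses a scalar
    return out

# rank-1 tensor, index_depth == 1
print(scatter_nd_update([0, 0, 0, 0, 0, 0, 0, 0],
                        [[1], [3], [4], [7]],
                        [9, 10, 11, 12]))
# [0, 9, 0, 10, 11, 0, 0, 12]

# rank-2 tensor, index_depth == 2
print(scatter_nd_update([[1, 1], [1, 1], [1, 1]],
                        [[0, 1], [2, 0]],
                        [5, 10]))
# [[1, 5], [1, 1], [10, 1]]
```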
So, in the example code:
# tf.rank(tensor) == 1
tensor = [0, 0, 0, 0, 0, 0, 0, 0]
# num_updates == 4, index_depth == 1 | tf.rank(indices).numpy() == 2
indices = [[1], [3], [4], [7]]
# num_updates == 4 | tf.rank(output).numpy() == 1
updates = [9, 10, 11, 12]
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
tf.Tensor([ 0 9 0 10 11 0 0 12], shape=(8,), dtype=int32)
Likewise:
# tf.rank(tensor) == 2
tensor = [[1, 1], [1, 1], [1, 1]]
# num_updates == 2, index_depth == 2 | tf.rank(indices).numpy() == 2
indices = [[0, 1], [2, 0]]
# num_updates == 2 | tf.rank(output).numpy() == 2
updates = [5, 10]
output = tf.tensor_scatter_nd_update(tensor, indices, updates)
tf.Tensor(
[[ 1 5]
[ 1 1]
[10 1]], shape=(3, 2), dtype=int32)
num_updates, index_depth = tf.convert_to_tensor(indices).shape.as_list()
print([num_updates, index_depth])  # [2, 2]
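When index_depth is smaller than tf.rank(tensor), each index vector addresses a slice rather than a scalar, so each entry of updates must have the shape of that slice. A pure-Python sketch of the rank-2, index_depth == 1 case (again an emulation of the semantics, with a made-up helper name, not TensorFlow's implementation):

```python
def scatter_rows(tensor, indices, updates):
    # Each index selects a whole row of a rank-2 tensor, so each
    # update must itself be a full row.
    out = [row[:] for row in tensor]
    for (i,), row in zip(indices, updates):
        out[i] = list(row)
    return out

tensor  = [[1, 1], [1, 1], [1, 1]]   # tf.rank(tensor) == 2
indices = [[0], [2]]                 # num_updates == 2, index_depth == 1
updates = [[5, 5], [10, 10]]         # each update is a row of length 2
print(scatter_rows(tensor, indices, updates))
# [[5, 5], [1, 1], [10, 10]]
```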