如何在高维中做tensorflow segment_max

How to do tensorflow segment_max in high dimension

我希望能够调用tensorflow的tf.math.unsorted_segment_max 在大小为 [N, s, K] 的数据张量上。 N 是通道数,K 是 filters/feature 贴图的数量。 s 是单通道数据样本的大小。我有 segment_ids 的大小。例如,假设我的样本大小是 s=6,并且我想对两个元素做一个最大值(就像做通常的最大池化一样,所以在整个数据张量的第二个 s 维上)。那么我的segment_ids等于[0,0,1,1,2,2].

我试过运行宁

tf.math.unsorted_segment_max(data, segment_ids, num_segments)

为 segment_ids 扩展了 0 和 2 维,但由于随后重复了段 ID,结果当然是大小 [3] 而不是我想要的 [N,3,K]像。

所以我的问题是,如何构造一个合适的 segment_ids 张量来实现我想要的? IE。要根据原始 s 大小的 segment_ids 张量完成分段最大值,但在每个维度上分别进行?

基本上,回到示例,给定 1D 段 ID 列表 seg_id=[0,0,1,1,2,2],我想构造类似 segment_ids张量,其中:

segment_ids[i,:,j] = seg_id + num_segments*(i*K + j) 

所以当调用 tf.math.(unsorted_)segment_max 使用这个张量作为段 ids 时,我将得到大小为 [N, 3, K] 的结果,效果相同就好像每个数据 [x,:,y] 分别 运行 segment_max 并适当地堆叠结果。

任何方式都可以,只要它适用于 tensorflow。我想 tf.tile、tf.reshape 或 tf.concat 的组合应该可以解决问题,但我不知道如何按什么顺序。 另外,有没有更直接的方法呢?无需在每个 "pooling" 步骤中调整 segment_ids?

我还没有想出任何更优雅的解决方案,但至少我想出了如何结合平铺、重塑和转置来做到这一点。 我首先(使用提到的三个操作,请参见下面的代码)构造一个与数据大小相同的张量,并在张量中重复(但移位)原始 seg_id 向量的条目:

m = tf.reduce_max(seg_id) + 1
a = tf.constant([i*m for i in range(N*K) for j in range(s)])
b = tf.tile(seg_id, N*K)
#now reshape it:
segment_ids = tf.transpose(tf.reshape(a+b, shape=[N,K,s]), perm=[0,2,1])

这样就可以直接调用segment_max函数了:

result = tf.unsorted_segment_max(data=data, segment_ids=segment_ids, num_segments=m*N*K)

它也做了我想要的,只是结果变平了,如果需要的话需要再次整形。 等价地,您可以将原始数据张量重塑为一维张量,并用 a+b 在其上计算 segment_max 作为 segment_ids。如果需要,再次重塑最终结果。

这就是感觉路途遥远的结果...有没有更好的办法?我也不知道所描述的方法是否适合在 NN 内部使用,在反向传播期间......导数或计算图是否存在问题? 有没有人对如何解决这个问题有更好的想法?

我想你可以用tf.nn.pool实现你想要的:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.constant([
        [
            [ 1, 12, 13],
            [ 2, 11, 14],
            [ 3, 10, 15],
            [ 4,  9, 16],
            [ 5,  8, 17],
            [ 6,  7, 18],
        ],
        [
            [19, 30, 31],
            [20, 29, 32],
            [21, 28, 33],
            [22, 27, 34],
            [23, 26, 35],
            [24, 25, 36],
        ]], dtype=tf.int32)
    segments = tf.constant([0, 0, 1, 1, 2, 2], dtype=tf.int32)
    pool = tf.nn.pool(data, [2], 'MAX', 'VALID', strides=[2])
    print(sess.run(pool))

输出:

[[[ 2 12 14]
  [ 4 10 16]
  [ 6  8 18]]

 [[20 30 32]
  [22 28 34]
  [24 26 36]]]

如果你真的想要我们tf.unsorted_segment_max, you can do it as you suggest in 。这是避免转置并包括最终重塑的等效公式:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    data = ...
    segments = ...
    shape = tf.shape(data)
    n, k = shape[0], shape[2]
    m = tf.reduce_max(segments) + 1
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    result = tf.reshape(segmented, [n, m, k])
    print(sess.run(result))
    # Same output

就反向传播而言,这两种方法在神经网络中应该都能正常工作。

编辑:就性能而言,池化似乎比分段总和更具可扩展性(正如人们所期望的那样):

import tensorflow as tf
import numpy as np

def method_pool(data, window):
    return tf.nn.pool(data, [window], 'MAX', 'VALID', strides=[window])

def method_segment(data, window):
    shape = tf.shape(data)
    n, s, k = shape[0], shape[1], shape[2]
    segments = tf.range(s) // window
    m = tf.reduce_max(segments) + 1
    grid = tf.meshgrid(tf.range(n) * m * k,
                       segments * k,
                       tf.range(k), indexing='ij')
    segment_nd = tf.add_n(grid)
    segmented = tf.unsorted_segment_max(data, segment_nd, n * m * k)
    return tf.reshape(segmented, [n, m, k])

np.random.seed(100)
rand_data = np.random.rand(300, 500, 100)
window = 10
with tf.Graph().as_default(), tf.Session() as sess:
    data = tf.constant(rand_data, dtype=tf.float32)
    res_pool = method_pool(data, n)
    res_segment = method_segment(data, n)
    print(np.allclose(*sess.run([res_pool, res_segment])))
    # True
    %timeit sess.run(res_pool)
    # 2.56 ms ± 80.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    %timeit sess.run(res_segment)
    # 514 ms ± 6.29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)