Pytorch:如何沿轴获取切片的平均值,其中切片索引值在不同的张量上定义并且梯度仅流入切片

Pytorch: How to get mean of slices along an axis where the slices indices value are defined on a different tensor and gradients only flow into slices

我想取沿张量轴的平均值,张量由包含多个切片的张量定义。

所以这将是我的样本张量,我想从中获取切片的平均值,沿着第一维

import torch

sample = torch.arange(0,40).reshape(10,-1)
sample
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31],
        [32, 33, 34, 35],
        [36, 37, 38, 39]])

这将是包含开始和结束索引的张量,我想得到

的平均值
mean_slices = torch.tensor([
    [3, 5],
    [1, 8],
    [4, 8],
    [6, 9],
])

由于 Pytorch 没有参差不齐的张量,我使用此处描述的技巧

通过我想从中获取平均值的整个轴计算 cumsum,然后检索每个结束切片索引的行,并从起始切片索引之前的 cumsum 行中减去。最后将结果除以切片的长度。

padded = torch.nn.functional.pad(
    sample.cumsum(dim=0), (0, 0, 1, 0)
)
padded
tensor([[  0,   0,   0,   0],
        [  0,   1,   2,   3],
        [  4,   6,   8,  10],
        [ 12,  15,  18,  21],
        [ 24,  28,  32,  36],
        [ 40,  45,  50,  55],
        [ 60,  66,  72,  78],
        [ 84,  91,  98, 105],
        [112, 120, 128, 136],
        [144, 153, 162, 171],
        [180, 190, 200, 210]])
pools = torch.diff(
    padded[mean_slices], dim=1
).squeeze()/torch.diff(mean_slices, dim=1)

pools
tensor([[14., 15., 16., 17.],
        [16., 17., 18., 19.],
        [22., 23., 24., 25.],
        [28., 29., 30., 31.]])

这个解决方案的唯一问题是,最初我只是想获得由切片定义的特定行的平均值,虽然我当前的解决方案是这样做的,但计算也涉及切片索引之前的所有行.因此向后传递可能无法按预期工作。

这个猜测是否正确?

是否有更精确且计算效率更高的方法来计算张量中定义的切片的平均值?

为什么你认为梯度计算包括切片外的像素值?

当您使用 torch.cumsum 计算切片上的总和时,您将切片外的所有值相加 两次 :一次估计它们的总和,存储在之前的行中切片和第二次对切片求和,这些值将此值存储在切片的最后一行。 最重要的是你从最后一行中减去row-before-first:也就是说,你消除所有值之外的总和从等式中切片。因此这些值对计算和梯度没有影响。

这是一个简单的例子:
考虑函数 f(x,y,z) = x + y + z - zfw.r.tz的梯度是多少? z一经消去,对f的值和梯度没有影响。

底线:你的向后传递是正确的,并且不受切片外值的影响。


关于更有效的实施:
如果最小切片起始索引很高(即 sample 的很大一部分被所有切片忽略),您可能会完全删除它:

mn, mx = mean_slices.min(), mean_slices.max()  # only the relevant par of sample
padded_ef = torch.nn.functional.pad(
    sample[mn:mx, :].cumsum(dim=0), (0, 0, 1, 0)
)
# sum the slices - need to shift the index
pools_ef = torch.diff(
    padded_ef[mean_slices-mn], dim=1
).squeeze()/torch.diff(mean_slices, dim=1)

具有相同 pools 的结果,但如果切片被“打包”,则可能涉及较少的 sample 元素。
但是,除非 sample 非常大 w.r.t 切片,否则我认为这不会在 运行 时间内给您带来显着的提升。