将 2D 函数应用于 4D 张量 PyTorch

Applying a 2D function to a 4D tensor PyTorch

我得到一个采用矩阵的二维函数 - 形状为 (28, 28) 的二维张量 我得到了一个张量,比方说 (64, 10, 28, 28) - 这是一个包含一批 64 张图像的张量,这些图像通过(10 个内核)conv2d 层。

现在,我想激活张量的最后两个维度,即 (28,28) 位,一个二维函数。

现在我以一种非常低效的方式做到了:

def activation_func(input):

    for batch_idx in range(input.shape[0]):
        for channel_inx in range(input.shape[1]):       
            input[batch_idx][channel_inx] = 2D_function(input[batch_idx][channel_inx])

    return input

正如我所注意到的,这是非常低效的。 有什么方法可以有效地做到这一点?

如有必要,我可以编写完整的代码

编辑:

def 2D_function(input):
    global indices # yes I know, I will remove this global stuff later
    # indices = [(i, j) for i in range(1, 28, 4) for j in range(1, 28, 4)]

    for x, y in indices:
        relu_decision = relu(input[x, y]) # standard relu - relu(x)=(x>1)*x
        if not relu_decision:
            # zero out the patch
            input[x - 1: x + 3, y - 1: y + 3] = 0
    return input

在这种情况下,我使用克罗内克乘积技巧:

import torch

torch.set_printoptions(linewidth=200)  # you can better see how the mask is shaped

# simulating an input
input = torch.rand(1, 1, 28, 28) - 0.5


ids = torch.meshgrid((torch.arange(1, 28, 4), torch.arange(1, 28, 4)))

# note that relu(x) = (x > 0.) * x, so adjust it to your needs
relus = torch.nn.functional.relu(input[(slice(None), slice(None), *ids)]).to(bool)

A = torch.ones(4, 4)
# generate a block matrix with ones in positions where blocks are set to 0 in correspondence of relus = 0
mask = torch.kron(relus, A)
print(mask.shape)
output = input * mask

print(mask[0, 0])
print(output[0, 0])