Applying a 2D function to a 4D tensor PyTorch
I was given a 2D function that takes a matrix - a 2D tensor of shape (28, 28).
I have a tensor of shape, say, (64, 10, 28, 28) - a batch of 64 images after passing through a conv2d layer with 10 kernels.
I now want to apply the 2D function to the last two dimensions of the tensor, i.e. to each (28, 28) slice.
At the moment I do it in a very inefficient way:
def activation_func(input):
    for batch_idx in range(input.shape[0]):
        for channel_idx in range(input.shape[1]):
            input[batch_idx][channel_idx] = function_2d(input[batch_idx][channel_idx])
    return input
As I noted, this is very inefficient. Is there a way to do this efficiently?
I can post the full code if necessary.
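For a generic 2D function built from vectorized torch operations, one option (a minimal sketch, assuming PyTorch >= 2.0 for torch.vmap; my_2d_function is a placeholder, not the asker's actual function) is to let nested torch.vmap map the function over the two leading dimensions:

import torch

def my_2d_function(x):  # x: (28, 28); placeholder for the real 2D function
    return torch.relu(x)

batched = torch.vmap(torch.vmap(my_2d_function))  # maps over dims 0 and 1
out = batched(torch.rand(64, 10, 28, 28))
print(out.shape)  # torch.Size([64, 10, 28, 28])

Note that vmap cannot trace data-dependent Python control flow (such as the if in the function shown in the edit below), so a function like that still needs an explicit vectorization along the lines of the Kronecker-product answer.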
Edit:
from torch.nn.functional import relu

def function_2d(input):
    global indices  # yes I know, I will remove this global stuff later
    # indices = [(i, j) for i in range(1, 28, 4) for j in range(1, 28, 4)]
    for x, y in indices:
        relu_decision = relu(input[x, y])  # standard relu - relu(x) = (x > 0) * x
        if not relu_decision:
            # zero out the 4x4 patch controlled by this gate
            input[x - 1: x + 3, y - 1: y + 3] = 0
    return input
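To make the patch geometry concrete: the gates sit on a stride-4 grid starting at index 1, and each gate controls the 4x4 patch whose top-left corner is one row and one column before it, so the patches tile the image without overlap. A tiny self-contained sketch of the same logic (using a hypothetical 8x8 input instead of 28x28):

import torch

x = torch.rand(8, 8) - 0.5
# stride-4 grid starting at 1: gates at (1, 1), (1, 5), (5, 1), (5, 5)
indices = [(i, j) for i in range(1, 8, 4) for j in range(1, 8, 4)]
for i, j in indices:
    if x[i, j] <= 0:                      # the relu gate is zero
        x[i - 1:i + 3, j - 1:j + 3] = 0   # zero the 4x4 patch it controls
print(x)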
In this case, I use the Kronecker product trick:
import torch
torch.set_printoptions(linewidth=200) # you can better see how the mask is shaped
# simulating an input
input = torch.rand(1, 1, 28, 28) - 0.5
ids = torch.meshgrid((torch.arange(1, 28, 4), torch.arange(1, 28, 4)))
# note that relu(x) = (x > 0.) * x, so adjust it to your needs
relus = torch.nn.functional.relu(input[(slice(None), slice(None), *ids)]).to(bool)
A = torch.ones(4, 4)
# expand each boolean gate into a 4x4 block: ones (keep the patch) where relus is True, zeros (drop it) where it is False
mask = torch.kron(relus, A)
print(mask.shape)
output = input * mask
print(mask[0, 0])
print(output[0, 0])
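The same trick scales directly to the batched shape from the question, since torch.kron unsqueezes the operand with fewer dimensions until the ranks match, so the (4, 4) block matrix works unchanged. A sketch with a (64, 10, 28, 28) input (the indexing="ij" argument, available since PyTorch 1.10, silences the meshgrid warning while keeping the original behaviour):

import torch

input = torch.rand(64, 10, 28, 28) - 0.5               # full batch from the question
ids = torch.meshgrid(torch.arange(1, 28, 4),
                     torch.arange(1, 28, 4), indexing="ij")
relus = (input[(slice(None), slice(None), *ids)] > 0.)  # (64, 10, 7, 7) boolean gates
mask = torch.kron(relus, torch.ones(4, 4))              # expands to (64, 10, 28, 28)
output = input * mask
print(output.shape)  # torch.Size([64, 10, 28, 28])

Compared with the double Python loop, this is a handful of vectorized ops, so it also runs on the GPU without per-image host round-trips.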