Pytorch for 循环效率低下

Pytorch for loop inefficience

我有一个关于循环张量的高效问题。

我正在通过图像数据加载器从 CNN 的最后一层提取特征(我使用的批量大小为 8)。我正在获取批量张量的欧氏距离和具有先前特征的 table。

每当 table 中的所有张量都高于阈值时,我想向 table 添加一个张量。我已经实现了一个成功的 运行 代码,但是我使用的循环效率不高,我想知道如何使用更有效的方法而不是这种连续的方式来做类似的事情。

for i, data in enumerate(dataloader, 0):
  input, label = data
  input, label = input.to(device), label.to(device)
  n,c h,w = input.size()
  outputs = model(input)
  if (i == 0):
    features_list = torch.cat( (features_list, outputs[0].view(1,-1)), 0)
  dist_tensores = torch.cdist(outputs, features_list, p=2.0)
  activation = torch.gt(dist_tensores, AVG, out=torch.cuda.FloatTensor(len(outputs), len(features_list)))
  counter = len(features_list)
  activation_list = torch.sum(activation, dim=0)
  for x in range(len(activation)):
    if (torch.sum(activation[x], dim=0) == counter):
      features_list = torch.cat( (features_list, outputs[x].view(1,-1)), 0)

最后一个循环是我想更改的部分,但我真的不知道如何分配和添加我想要的张量,如果不是通过创建一个我可以控制要添加的张量的循环。

最后的循环效率低下,因为它重复连接到同一个张量上。每个连接都必须复制整个现有张量,以便在末尾添加更多元素。运行时间将是串联数量的二次方。

只进行一次连接会更有效率:

outputs_to_concat = []
for x in range(len(activation)):
    if (torch.sum(activation[x], dim=0) == counter):
      outputs_to_concat.append(outputs[x].view(1,-1))
features_list = torch.cat(outputs_to_concat, dim=0)

下面是相同的代码并进行了一些其他清理:

outputs_to_concat = []
for act, output in zip(activation, outputs):
    if torch.sum(act, dim=0) == counter:
        outputs_to_concat.append(output.flatten())
features_list = torch.stack(outputs_to_concat, dim=0)
idx = activation.sum(1) == counter
features_list = torch.cat((features_list, outputs[idx]), 0)

这将取代循环并避免计算和低效问题。