Pytorch 加权张量

Pytorch weighted Tensor

我正在将一些复杂的 TF2 代码移植到 Pytorch。由于TF2不区分Tensor和numpy array,直接就可以了。但是,遇到好几个'you cannot mix Tensor and numpy array here in Pytorch!'的错误,感觉又回到了TF1时代。这是原始的 TF2 代码:

def get_weighted_imgs(points, centers, imgs):
  weights = np.array([[tf.norm(p - c) for c in centers] for p in points], dtype=np.float32)
  weighted_imgs = np.array([[w * img for w, img in zip(weight, imgs)] for weight in weights])

  weights = tf.expand_dims(1 / tf.reduce_sum(weights, axis=1), axis=-1)
  weighted_imgs = tf.reshape(tf.reduce_sum(weighted_imgs, axis=1), [len(weights), 64*64*3])

  return weights * weighted_imgs

还有我有问题的 Pytorch 代码:

def get_weighted_imgs(points, centers, imgs):
  weights = torch.Tensor([[torch.norm(p - c) for c in centers] for p in points])
  weighted_imgs = torch.Tensor([[w * img for w, img in zip(weight, imgs)] for weight in weights])

  weights = torch.unsqueeze(1 / torch.sum(weights, dim=1), dim=-1)
  weighted_imgs = torch.sum(weighted_imgs, dim=1).view([len(weights), 64*64*3])

  return weights * weighted_imgs

def reproducible():
  points = torch.Tensor(np.random.random((128, 5)))
  centers = torch.Tensor(np.random.random((10, 5)))
  imgs = torch.Tensor(np.random.random((10, 64, 64, 3)))

  weighted_imgs = get_weighted_imgs(points, centers, imgs)

我可以保证 tensors/arrays 的尺寸顺序或形状没有问题。我得到的错误信息是

ValueError: only one element tensors can be converted to Python scalars

来自

weighted_imgs = torch.Tensor([[w * img for w, img in zip(weight, imgs)] for weight in weights])

有人可以帮我解决这个问题吗?将不胜感激。

也许这会对您有所帮助,但我不确定权重和 weighted_imgs 之间的最终乘法,因为它们不具有相同的形状,即使在按照您可能想要的方式重塑之后也是如此。我不确定我是否正确理解了您的逻辑:

import torch
def get_weighted_imgs(points, centers, imgs):
  weights = torch.Tensor([[torch.norm(p - c) for c in centers] for p in points])
  
  imgs = imgs.unsqueeze(0).repeat(weights.shape[0],1,1,1,1)
  dims_to_rep = list(imgs.shape[-3:])
  weights = weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1,1,*dims_to_rep)
  weights /= torch.sum(weights[...,0:1,0:1,0:1],dim=1, keepdim=True)
  weighted_imgs =  torch.sum(imgs * weights, dim=1).view(weights.shape[0], -1)
  
  return weighted_imgs #weights.view(weighted_imgs.shape[0],-1) *\
         #weighted_imgs # Shapes are torch.Size([128, 122880]) and torch.Size([128, 12288])

def reproducible():
  points = torch.Tensor(np.random.random((128, 5)))
  centers = torch.Tensor(np.random.random((10, 5)))
  imgs = torch.Tensor(np.random.random((10, 64, 64, 3)))

  weighted_imgs = get_weighted_imgs(points, centers, imgs)
#Test:
reproducible()