从 100x100 的 pytorch 张量中获取 10x10 的补丁,圆环样式环绕边界

Get a 10x10 patch from a 100x100 pytorch tensor with torus style wrap around the boundries

我怎样才能从 100x100 的 pytorch 张量中得到一个 10x10 的补丁,并增加了一个约束,即如果一个补丁超出了数组的边界,那么它就会环绕边缘(就好像数组是一个环面,顶部连接到底部,左侧连接到右侧)?

我写了这段代码来完成这项工作,我正在寻找更优雅、高效和清晰的代码:

def shift_matrix(a, distances) -> Tensor:
  x, y = distances
  a = torch.cat((a[x:], a[0:x]), dim=0)
  a = torch.cat((a[:, y:], a[:, :y]), dim=1)
  return a

def randomly_shift_matrix(a) -> Tensor:
  return shift_matrix(a, np.random.randint(low = 0, high = a.size()))

def random_patch(a, size) -> Tensor:
  full_shifted_matrix = randomly_shift_matrix(a)
  return full_shifted_matrix[0:size[0], 0:size[1]]

我觉得带有负索引切片的东西应该可以工作。不过我还没找到。

你可以see the code in google colab here.

您正在寻找torch.roll

def random_patch(a, size) -> Tensor:
  shifts = np.random.randint(low = 0, high = a.size())
  return torch.roll(a, shifts=shifts, dims=(0, 1))[:size[0], :size[1]]