从 100x100 的 pytorch 张量中获取 10x10 的补丁,圆环样式环绕边界
Get a 10x10 patch from a 100x100 pytorch tensor with torus style wrap around the boundries
我怎样才能从 100x100 的 pytorch 张量中得到一个 10x10 的补丁,并增加了一个约束,即如果一个补丁超出了数组的边界,那么它就会环绕边缘(就好像数组是一个环面,顶部连接到底部,左侧连接到右侧)?
我写了这段代码来完成这项工作,我正在寻找更优雅、高效和清晰的代码:
def shift_matrix(a, distances) -> Tensor:
x, y = distances
a = torch.cat((a[x:], a[0:x]), dim=0)
a = torch.cat((a[:, y:], a[:, :y]), dim=1)
return a
def randomly_shift_matrix(a) -> Tensor:
return shift_matrix(a, np.random.randint(low = 0, high = a.size()))
def random_patch(a, size) -> Tensor:
full_shifted_matrix = randomly_shift_matrix(a)
return full_shifted_matrix[0:size[0], 0:size[1]]
我觉得带有负索引切片的东西应该可以工作。不过我还没找到。
您正在寻找torch.roll
def random_patch(a, size) -> Tensor:
shifts = np.random.randint(low = 0, high = a.size())
return torch.roll(a, shifts=shifts, dims=(0, 1))[:size[0], :size[1]]
我怎样才能从 100x100 的 pytorch 张量中得到一个 10x10 的补丁,并增加了一个约束,即如果一个补丁超出了数组的边界,那么它就会环绕边缘(就好像数组是一个环面,顶部连接到底部,左侧连接到右侧)?
我写了这段代码来完成这项工作,我正在寻找更优雅、高效和清晰的代码:
def shift_matrix(a, distances) -> Tensor:
x, y = distances
a = torch.cat((a[x:], a[0:x]), dim=0)
a = torch.cat((a[:, y:], a[:, :y]), dim=1)
return a
def randomly_shift_matrix(a) -> Tensor:
return shift_matrix(a, np.random.randint(low = 0, high = a.size()))
def random_patch(a, size) -> Tensor:
full_shifted_matrix = randomly_shift_matrix(a)
return full_shifted_matrix[0:size[0], 0:size[1]]
我觉得带有负索引切片的东西应该可以工作。不过我还没找到。
您正在寻找torch.roll
def random_patch(a, size) -> Tensor:
shifts = np.random.randint(low = 0, high = a.size())
return torch.roll(a, shifts=shifts, dims=(0, 1))[:size[0], :size[1]]