PyTorch:为使用 grid_sample,在网格(grid)约定与图像(image)约定之间转换张量布局
PyTorch: Converting between the grid and image layout conventions for use with grid_sample
我需要在 PyTorch 中对一些形变网格进行插值,因此决定使用函数 grid_sample(见官方文档)。为此,我需要在以下约定之间来回转换张量布局,把网格当作图像处理:
- N:批量大小
- D:网格深度(对于 3D 图像)
- H:网格高度
- W:网格宽度
- d:网格维度(= 2 或 3)
2D 图像格式为 (N,2,H,W)(3D 图像为 (N,3,D,H,W)),而网格格式为 (N,H,W,2)(3D 中对应 (N,D,H,W,3))。
我不能使用 reshape 或 view,因为它们不会按我需要的方式重新排列数据。例如,我需要满足:
grid_in_grid_convention[0,:,:,0] == grid_in_image_convention[0,0,:,:]  # desired: coordinate 0 of sample 0 is the same plane in both layouts
我写了下面这些函数来正确完成这种重排,但我确信还有更简洁、更快的方法。你们怎么看?
def grid2im(grid):
    """Convert a grid-convention tensor into an image-convention tensor.

    Moves the trailing coordinate axis to position 1 (the channel axis):

    * 2D: ``[N, H, W, 2]``    -> ``[N, 2, H, W]``
    * 3D: ``[N, D, H, W, 3]`` -> ``[N, 3, D, H, W]``

    Any rank-4 or rank-5 tensor whose last axis has size 2 or 3 is accepted;
    its last axis is simply moved to position 1. The work is a single
    vectorised ``permute`` for every batch size (no per-sample Python loop),
    and the input's dtype, device and autograd history are preserved.

    Args:
        grid: tensor of rank 4 or 5 whose last dimension has size 2 or 3.

    Returns:
        A new contiguous tensor (never a view of ``grid``), so in-place edits
        of the result do not affect the input.

    Raises:
        ValueError: if ``grid`` is not rank 4/5 with a last axis of size 2/3.
    """
    if grid.dim() not in (4, 5) or grid.shape[-1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,H,W,2] or [N,D,H,W,3], "
            f"got {tuple(grid.shape)} instead."
        )
    last = grid.dim() - 1
    # Permutation (0, last, 1, ..., last-1): batch, coordinates, spatial axes.
    image = grid.permute(0, last, *range(1, last))
    # The stacking implementation always returned a freshly allocated tensor;
    # clone keeps that copy semantics and makes the result contiguous.
    return image.clone(memory_format=torch.contiguous_format)
def im2grid(image):
    """Convert an image-convention tensor into a grid-convention tensor.

    Moves the channel axis (position 1) to the last position:

    * 2D: ``[N, 2, H, W]``    -> ``[N, H, W, 2]``
    * 3D: ``[N, 3, D, H, W]`` -> ``[N, D, H, W, 3]``

    Any rank-4 or rank-5 tensor whose axis 1 has size 2 or 3 is accepted; its
    axis 1 is simply moved to the end (so e.g. ``[1, 3, H, W]`` still maps to
    ``[1, H, W, 3]``). Works for every batch size with one vectorised
    ``permute`` and preserves dtype, device and autograd history.

    Args:
        image: tensor of rank 4 or 5 whose dimension 1 has size 2 or 3.

    Returns:
        A new contiguous tensor (never a view of ``image``).

    Raises:
        ValueError: if ``image`` is not rank 4/5 with axis 1 of size 2/3.
    """
    if image.dim() not in (4, 5) or image.shape[1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,2,H,W] or [N,3,D,H,W], "
            f"got {tuple(image.shape)} instead."
        )
    # Permutation (0, 2, ..., ndim-1, 1): batch, spatial axes, coordinates.
    grid = image.permute(0, *range(2, image.dim()), 1)
    # The stacking implementation returned fresh tensors; keep copy semantics.
    return grid.clone(memory_format=torch.contiguous_format)
你可以使用 transpose()。例如,如果你想把张量 grid 从 [T,H,W,2] 变为 [T,2,H,W],可以这样写:
grid = grid.permute(0, 3, 1, 2)  # [T,H,W,2] -> [T,2,H,W] as a single view, no copy
ihdv 的答案是正确的,谢谢!
为完善这个答案,我在下面用更高效、更优雅的方式改写了上面的两个函数。
def grid2im(grid):
    """Convert a grid-convention tensor into an image-convention tensor.

    Moves the trailing coordinate axis to position 1 (the channel axis):

    * 2D: ``[N, H, W, 2]``    -> ``[N, 2, H, W]``
    * 3D: ``[N, D, H, W, 3]`` -> ``[N, 3, D, H, W]``

    The axis move is derived from the tensor's rank. Chained transposes
    chosen only from the last axis size scramble e.g. a ``[N, D, H, W, 2]``
    grid and fail with ``IndexError`` on ``[N, H, W, 3]``.

    Args:
        grid: tensor of rank 4 or 5 whose last dimension has size 2 or 3.

    Returns:
        A permuted *view* of ``grid`` (no data is copied); call
        ``.contiguous()`` on it if a contiguous tensor is required.

    Raises:
        ValueError: if ``grid`` is not rank 4/5 with a last axis of size 2/3.
    """
    if grid.dim() not in (4, 5) or grid.shape[-1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,H,W,2] or [N,D,H,W,3], "
            f"got {tuple(grid.shape)} instead."
        )
    last = grid.dim() - 1
    # Permutation (0, last, 1, ..., last-1): batch, coordinates, spatial axes.
    return grid.permute(0, last, *range(1, last))
def im2grid(image):
    """Convert an image-convention tensor into a grid-convention tensor.

    Moves the channel axis (position 1) to the last position:

    * 2D: ``[N, 2, H, W]``    -> ``[N, H, W, 2]``
    * 3D: ``[N, 3, D, H, W]`` -> ``[N, D, H, W, 3]``

    Works for any batch size. The axis move is derived from the tensor's
    rank. Chained transposes chosen only from ``shape[1]`` scramble e.g. a
    ``[N, 2, D, H, W]`` image and fail with ``IndexError`` on ``[N, 3, H, W]``.

    Args:
        image: tensor of rank 4 or 5 whose dimension 1 has size 2 or 3.

    Returns:
        A permuted *view* of ``image`` (no data is copied).

    Raises:
        ValueError: if ``image`` is not rank 4/5 with axis 1 of size 2/3.
    """
    if image.dim() not in (4, 5) or image.shape[1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,2,H,W] or [N,3,D,H,W], "
            f"got {tuple(image.shape)} instead."
        )
    # Permutation (0, 2, ..., ndim-1, 1): batch, spatial axes, coordinates.
    return image.permute(0, *range(2, image.dim()), 1)
我需要在 PyTorch 中对一些形变网格进行插值,因此决定使用函数 grid_sample(见官方文档)。为此,我需要在以下约定之间来回转换张量布局,把网格当作图像处理:
- N:批量大小
- D:网格深度(对于 3D 图像)
- H:网格高度
- W:网格宽度
- d:网格维度(= 2 或 3)
2D 图像格式为 (N,2,H,W)(3D 图像为 (N,3,D,H,W)),而网格格式为 (N,H,W,2)(3D 中对应 (N,D,H,W,3))。
我不能使用 reshape 或 view,因为它们不会按我需要的方式重新排列数据。例如,我需要满足:
grid_in_grid_convention[0,:,:,0] == grid_in_image_convention[0,0,:,:]  # desired: coordinate 0 of sample 0 is the same plane in both layouts
我写了下面这些函数来正确完成这种重排,但我确信还有更简洁、更快的方法。你们怎么看?
def grid2im(grid):
    """Convert a grid-convention tensor into an image-convention tensor.

    Moves the trailing coordinate axis to position 1 (the channel axis):

    * 2D: ``[N, H, W, 2]``    -> ``[N, 2, H, W]``
    * 3D: ``[N, D, H, W, 3]`` -> ``[N, 3, D, H, W]``

    Any rank-4 or rank-5 tensor whose last axis has size 2 or 3 is accepted;
    its last axis is simply moved to position 1. The work is a single
    vectorised ``permute`` for every batch size (no per-sample Python loop),
    and the input's dtype, device and autograd history are preserved.

    Args:
        grid: tensor of rank 4 or 5 whose last dimension has size 2 or 3.

    Returns:
        A new contiguous tensor (never a view of ``grid``), so in-place edits
        of the result do not affect the input.

    Raises:
        ValueError: if ``grid`` is not rank 4/5 with a last axis of size 2/3.
    """
    if grid.dim() not in (4, 5) or grid.shape[-1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,H,W,2] or [N,D,H,W,3], "
            f"got {tuple(grid.shape)} instead."
        )
    last = grid.dim() - 1
    # Permutation (0, last, 1, ..., last-1): batch, coordinates, spatial axes.
    image = grid.permute(0, last, *range(1, last))
    # The stacking implementation always returned a freshly allocated tensor;
    # clone keeps that copy semantics and makes the result contiguous.
    return image.clone(memory_format=torch.contiguous_format)
def im2grid(image):
    """Convert an image-convention tensor into a grid-convention tensor.

    Moves the channel axis (position 1) to the last position:

    * 2D: ``[N, 2, H, W]``    -> ``[N, H, W, 2]``
    * 3D: ``[N, 3, D, H, W]`` -> ``[N, D, H, W, 3]``

    Any rank-4 or rank-5 tensor whose axis 1 has size 2 or 3 is accepted; its
    axis 1 is simply moved to the end (so e.g. ``[1, 3, H, W]`` still maps to
    ``[1, H, W, 3]``). Works for every batch size with one vectorised
    ``permute`` and preserves dtype, device and autograd history.

    Args:
        image: tensor of rank 4 or 5 whose dimension 1 has size 2 or 3.

    Returns:
        A new contiguous tensor (never a view of ``image``).

    Raises:
        ValueError: if ``image`` is not rank 4/5 with axis 1 of size 2/3.
    """
    if image.dim() not in (4, 5) or image.shape[1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,2,H,W] or [N,3,D,H,W], "
            f"got {tuple(image.shape)} instead."
        )
    # Permutation (0, 2, ..., ndim-1, 1): batch, spatial axes, coordinates.
    grid = image.permute(0, *range(2, image.dim()), 1)
    # The stacking implementation returned fresh tensors; keep copy semantics.
    return grid.clone(memory_format=torch.contiguous_format)
你可以使用 transpose()。例如,如果你想把张量 grid 从 [T,H,W,2] 变为 [T,2,H,W],可以这样写:
grid = grid.permute(0, 3, 1, 2)  # [T,H,W,2] -> [T,2,H,W] as a single view, no copy
ihdv 的答案是正确的,谢谢!
为完善这个答案,我在下面用更高效、更优雅的方式改写了上面的两个函数。
def grid2im(grid):
    """Convert a grid-convention tensor into an image-convention tensor.

    Moves the trailing coordinate axis to position 1 (the channel axis):

    * 2D: ``[N, H, W, 2]``    -> ``[N, 2, H, W]``
    * 3D: ``[N, D, H, W, 3]`` -> ``[N, 3, D, H, W]``

    The axis move is derived from the tensor's rank. Chained transposes
    chosen only from the last axis size scramble e.g. a ``[N, D, H, W, 2]``
    grid and fail with ``IndexError`` on ``[N, H, W, 3]``.

    Args:
        grid: tensor of rank 4 or 5 whose last dimension has size 2 or 3.

    Returns:
        A permuted *view* of ``grid`` (no data is copied); call
        ``.contiguous()`` on it if a contiguous tensor is required.

    Raises:
        ValueError: if ``grid`` is not rank 4/5 with a last axis of size 2/3.
    """
    if grid.dim() not in (4, 5) or grid.shape[-1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,H,W,2] or [N,D,H,W,3], "
            f"got {tuple(grid.shape)} instead."
        )
    last = grid.dim() - 1
    # Permutation (0, last, 1, ..., last-1): batch, coordinates, spatial axes.
    return grid.permute(0, last, *range(1, last))
def im2grid(image):
    """Convert an image-convention tensor into a grid-convention tensor.

    Moves the channel axis (position 1) to the last position:

    * 2D: ``[N, 2, H, W]``    -> ``[N, H, W, 2]``
    * 3D: ``[N, 3, D, H, W]`` -> ``[N, D, H, W, 3]``

    Works for any batch size. The axis move is derived from the tensor's
    rank. Chained transposes chosen only from ``shape[1]`` scramble e.g. a
    ``[N, 2, D, H, W]`` image and fail with ``IndexError`` on ``[N, 3, H, W]``.

    Args:
        image: tensor of rank 4 or 5 whose dimension 1 has size 2 or 3.

    Returns:
        A permuted *view* of ``image`` (no data is copied).

    Raises:
        ValueError: if ``image`` is not rank 4/5 with axis 1 of size 2/3.
    """
    if image.dim() not in (4, 5) or image.shape[1] not in (2, 3):
        raise ValueError(
            "input argument expected is [N,2,H,W] or [N,3,D,H,W], "
            f"got {tuple(image.shape)} instead."
        )
    # Permutation (0, 2, ..., ndim-1, 1): batch, spatial axes, coordinates.
    return image.permute(0, *range(2, image.dim()), 1)