PyTorch:在网格(grid)约定与图像(image)约定之间转换张量,以便使用 grid_sample

Pytorch : Pass from grid to image coordinate convention for use of grid_sample

我需要在 PyTorch 中对一些变形网格进行插值,因此决定使用函数 grid_sample(参见文档)。为此,我需要按照以下约定,在网格格式和图像格式之间来回转换张量:

图像格式在 2D 中为 (N,2,H,W)(在 3D 中为 (N,3,D,H,W)),而网格格式在 2D 中为 (N,H,W,2)(在 3D 中为 (N,D,H,W,3))。

我不能使用 reshape 或 view,因为它们不会按我需要的方式重新排列数据。例如,我需要满足:

grid_in_grid_convention[0,:,:,0] == grid_in_image_convention[0,0,:,:]

我写了下面这些函数来完成这种转换,但我确信还有更紧凑、更快速的方法可以做到这一点。你怎么看?

def grid2im(grid):
    """Convert a grid-convention tensor to the image (channels-first) layout.

    The trailing axis, which holds the coordinate components, is moved to
    position 1:

        2D  [N,H,W,2]   -> [N,2,H,W]
        3D  [N,D,H,W,3] -> [N,3,D,H,W]

    so that ``grid2im(grid)[n, c] == grid[n, ..., c]``.

    Args:
        grid: a 4-D tensor whose last axis has size 2, or a 5-D tensor whose
            last axis has size 3. Any batch size N is accepted.

    Returns:
        A new tensor (not a view of ``grid``) with the same dtype and device
        as ``grid``.

    Raises:
        ValueError: if ``grid`` is not shaped [N,H,W,2] or [N,D,H,W,3].
    """
    # The number of components must match the number of spatial axes,
    # otherwise the input is not a valid 2D or 3D sampling grid.
    is_2d = grid.dim() == 4 and grid.shape[-1] == 2
    is_3d = grid.dim() == 5 and grid.shape[-1] == 3
    if not (is_2d or is_3d):
        raise ValueError(
            "input argument expected is [N,H,W,2] or [N,D,H,W,3], "
            "got " + str(grid.shape) + " instead."
        )
    # Split along the component axis and re-stack the pieces on axis 1.
    # This handles the whole batch at once and, unlike filling a
    # torch.zeros buffer, keeps the input's dtype and device.
    return torch.stack(torch.unbind(grid, dim=-1), dim=1)
def im2grid(image):
    """Convert an image-convention (channels-first) tensor to the grid layout.

    Axis 1, which holds the coordinate components, is moved to the end:

        2D case [N,2,H,W]   ->  [N,H,W,2]
        3D case [N,3,D,H,W] ->  [N,D,H,W,3]

    so that ``im2grid(image)[n, ..., c] == image[n, c]``.

    Args:
        image: a 4-D tensor with 2 channels, or a 5-D tensor with 3 channels.
            Any batch size N is accepted.

    Returns:
        A new tensor (not a view of ``image``) with the same dtype and device
        as ``image``.

    Raises:
        ValueError: if ``image`` is not shaped [N,2,H,W] or [N,3,D,H,W].
    """
    # The channel count must match the number of spatial axes, otherwise the
    # input cannot be interpreted as a 2D or 3D coordinate field.
    is_2d = image.dim() == 4 and image.shape[1] == 2
    is_3d = image.dim() == 5 and image.shape[1] == 3
    if not (is_2d or is_3d):
        raise ValueError(
            "input argument expected is [N,2,H,W] or [N,3,D,H,W], "
            "got " + str(image.shape) + " instead."
        )
    # Split along the channel axis and re-stack the pieces on the last axis.
    # Stacking on dim=-1 (rather than a hard-coded dim=2) is what makes the
    # 3D case land in [N,D,H,W,3]; dtype and device follow the input.
    return torch.stack(torch.unbind(image, dim=1), dim=-1)

你可以使用 transpose()。例如,如果想把一个张量(比如 grid)从 [T,H,W,2] 转换为 [T,2,H,W],可以这样写:

# Shape trace for a [T,H,W,2] input: [T,H,W,2] -> [T,H,2,W] -> [T,2,H,W]
grid = grid.transpose(2,3).transpose(1,2)

ihdv 的答案是正确的。谢谢!

为了让这个回答更完整,我在这里用更高效、更优雅的方式重写了上面的两个函数。

def grid2im(grid):
    """Move the trailing coordinate axis of a grid tensor to position 1.

        2D  [T,H,W,2]   -> [T,2,H,W]
        3D  [T,D,H,W,3] -> [T,3,D,H,W]

    The case is selected from the size of the last axis only. The result is
    built purely from ``transpose`` calls on ``grid``.

    Raises:
        ValueError: if the last axis of ``grid`` has a size other than 2 or 3.
    """
    n_coords = grid.shape[-1]
    if n_coords not in (2, 3):
        raise ValueError("input argument expected is [N,H,W,2] or [N,D,H,W,3]",
                         "got "+str(grid.shape)+" instead.")

    # Bubble the last axis leftwards with adjacent swaps until it reaches
    # index 1: (2,3),(1,2) for 2D and (3,4),(2,3),(1,2) for 3D.
    result = grid
    for axis in range(n_coords, 0, -1):
        result = result.transpose(axis, axis + 1)
    return result



def im2grid(image):
    """Move the channel axis (index 1) of an image tensor to the last position.

        2D case [T,2,H,W]   ->  [T,H,W,2]
        3D case [T,3,D,H,W] ->  [T,D,H,W,3]

    The case is selected from the size of axis 1 only. The result is built
    purely from ``transpose`` calls on ``image``.

    Raises:
        ValueError: if axis 1 of ``image`` has a size other than 2 or 3.
    """
    n_channels = image.shape[1]
    if n_channels not in (2, 3):
        raise ValueError("input argument expected is [1,2,H,W] or [1,3,D,H,W]",
                         "got "+str(image.shape)+" instead.")

    # Bubble axis 1 rightwards with adjacent swaps until it is last:
    # (1,2),(2,3) for 2D and (1,2),(2,3),(3,4) for 3D.
    result = image
    for axis in range(1, n_channels + 1):
        result = result.transpose(axis, axis + 1)
    return result