从 PyTorch 中的批量张量中索引元素
Indexing elements from a batch tensor in PyTorch
说,我在 PyTorch 中有一批图像。对于每张图片,我还有一个像素位置,比如 (x, y)
。可以使用 img[x, y]
读取一幅图像的像素值。我正在尝试读取批处理中每个图像的像素值。请看下面的代码片段:
import torch
# create tensors to represent random images in torch format
img_1 = torch.rand(1, 200, 300)
img_2 = torch.rand(1, 200, 300)
img_3 = torch.rand(1, 200, 300)
img_4 = torch.rand(1, 200, 300)
# for each image, x-y value are know, so creating a tuple
img1_xy = (0, 10, 70)
img2_xy = (0, 40, 20)
img3_xy = (0, 30, 50)
img4_xy = (0, 80, 60)
# this is what I am doing right now
imgs = [img_1, img_2, img_3, img_4]
imgs_xy = [img1_xy, img2_xy, img3_xy, img4_xy]
x = [img[xy] for img, xy in zip(imgs, imgs_xy)]
x = torch.as_tensor(x)
我的疑虑和问题
- 在每个图像中,像素位置即
(x, y)
是已知的。但是,我必须创建一个包含另一个元素的元组,即 0 以确保该元组与图像的形状相匹配。有什么优雅的方法吗?
- 不使用
tuple
,我们不能使用张量然后获取像素值吗?
- 所有图像都可以连接成一个批次
img_batch = torch.cat((img_1, img_2, img_3, img_4))
。但是元组呢?
您可以连接图像以形成 (4, 200, 300)
形状的堆叠张量。然后,我们可以使用每个图像的已知 (x, y)
对对其进行索引,如下所示:第一张图像需要 [0, x1, y1]
,第二张图像需要 [1, x2, y2]
,第三张图像需要 [2, x3, y3]
等等。这些可以通过“花式索引”来实现:
# stacking as you did
>>> stacked_imgs = torch.cat(imgs)
>>> stacked_imgs.shape
(4, 200, 300)
# no need for 0s in front
>>> imgs_xy = [(10, 70), (40, 20), (30, 50), (80, 60)]
# need xs together and ys together: take transpose of `imgs_xy`
>>> inds_x, inds_y = torch.tensor(imgs_xy).T
>>> inds_x
tensor([10, 40, 30, 80])
>>> inds_y
tensor([70, 20, 50, 60])
# now we index into the batch
>>> num_imgs = len(imgs)
>>> result = stacked_imgs[range(num_imgs), inds_x, inds_y]
>>> result
tensor([0.5359, 0.4863, 0.6942, 0.6071])
我们可以查看结果:
>>> torch.tensor([img[0, x, y] for img, (x, y) in zip(imgs, imgs_xy)])
tensor([0.5359, 0.4863, 0.6942, 0.6071])
回答您的问题:
1: 因为我们堆叠了图像,所以这个问题得到了缓解,我们使用 range(4)
来索引每个单独的图像。
2: 是的,我们确实将 x, y
个位置转换为张量。
3: 分离成张量后直接索引
说,我在 PyTorch 中有一批图像。对于每张图片,我还有一个像素位置,比如 (x, y)
。可以使用 img[x, y]
读取一幅图像的像素值。我正在尝试读取批处理中每个图像的像素值。请看下面的代码片段:
import torch
# create tensors to represent random images in torch format
img_1 = torch.rand(1, 200, 300)
img_2 = torch.rand(1, 200, 300)
img_3 = torch.rand(1, 200, 300)
img_4 = torch.rand(1, 200, 300)
# for each image, x-y value are know, so creating a tuple
img1_xy = (0, 10, 70)
img2_xy = (0, 40, 20)
img3_xy = (0, 30, 50)
img4_xy = (0, 80, 60)
# this is what I am doing right now
imgs = [img_1, img_2, img_3, img_4]
imgs_xy = [img1_xy, img2_xy, img3_xy, img4_xy]
x = [img[xy] for img, xy in zip(imgs, imgs_xy)]
x = torch.as_tensor(x)
我的疑虑和问题
- 在每个图像中,像素位置即
(x, y)
是已知的。但是,我必须创建一个包含另一个元素的元组,即 0 以确保该元组与图像的形状相匹配。有什么优雅的方法吗? - 不使用
tuple
,我们不能使用张量然后获取像素值吗? - 所有图像都可以连接成一个批次
img_batch = torch.cat((img_1, img_2, img_3, img_4))
。但是元组呢?
您可以连接图像以形成 (4, 200, 300)
形状的堆叠张量。然后,我们可以使用每个图像的已知 (x, y)
对对其进行索引,如下所示:第一张图像需要 [0, x1, y1]
,第二张图像需要 [1, x2, y2]
,第三张图像需要 [2, x3, y3]
等等。这些可以通过“花式索引”来实现:
# stacking as you did
>>> stacked_imgs = torch.cat(imgs)
>>> stacked_imgs.shape
(4, 200, 300)
# no need for 0s in front
>>> imgs_xy = [(10, 70), (40, 20), (30, 50), (80, 60)]
# need xs together and ys together: take transpose of `imgs_xy`
>>> inds_x, inds_y = torch.tensor(imgs_xy).T
>>> inds_x
tensor([10, 40, 30, 80])
>>> inds_y
tensor([70, 20, 50, 60])
# now we index into the batch
>>> num_imgs = len(imgs)
>>> result = stacked_imgs[range(num_imgs), inds_x, inds_y]
>>> result
tensor([0.5359, 0.4863, 0.6942, 0.6071])
我们可以查看结果:
>>> torch.tensor([img[0, x, y] for img, (x, y) in zip(imgs, imgs_xy)])
tensor([0.5359, 0.4863, 0.6942, 0.6071])
回答您的问题:
1: 因为我们堆叠了图像,所以这个问题得到了缓解,我们使用 range(4)
来索引每个单独的图像。
2: 是的,我们确实将 x, y
个位置转换为张量。
3: 分离成张量后直接索引