如何在(至少)批量维度上对这个 pytorch 代码进行矢量化?

How to vectorize this pytorch code over (at least) the batch dimension?

我想实现一个代码来构建一个邻接矩阵,这样(例如):

如果 X[0] : [0, 1, 2, 0, 1, 0], 那么,

A[0, 1] = 1
A[1, 2] = 1
A[2, 0] = 1
A[0, 1] = 1
A[1, 0] = 1

下面的代码工作正常,但是太慢了!所以,请帮助我至少在批量(第一)维度上对这段代码进行矢量化处理:

A = torch.zeros((3, 3, 3), dtype = torch.float)
X = torch.tensor([[0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1]])
for a, x in zip(A, X):
    for i, j in zip(x, x[1:]):
        a[i, j] = 1

谢谢! :)

使 X 成为元组而不是张量:

A = torch.zeros((3, 3, 3), dtype = torch.float)
X = ([0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1])
A[X] = 1

例如,通过这样转换:A[tuple(X)]

我很确定有一种更简单的方法可以做到这一点,但我试图保持在 torch 函数调用的范围内,以确保可以正确跟踪任何梯度操作。

如果这不是反向传播所必需的,我强烈建议您研究可能利用一些 numpy 函数的解决方案,因为我认为可以更有力地保证在这里找到合适的东西。但是,事不宜迟,这是我想出的解决方案。

它实质上是将您的 X 向量转换为一系列与 A 中的位置相对应的元组条目。为此,我们需要对齐一些索引(具体来说,第一个维度仅在 X 中隐式给出,因为 X 中的第一个列表对应于 A[0,:,:],第二个列表为A[1,:,:],等等。 这也可能是您可以开始优化代码的地方,因为我没有找到对这种矩阵的合理描述,因此不得不想出自己的创建方式。

# Start by "aligning" your shifted view of X
# Essentially, take the all but the last element, 
# and put it on top of all but the first element.
X_shift = torch.stack([X[:,:-1], X[:,1:]], dim=2)
# X_shift.shape: (3,5,2) in your example

# To assign this properly, we need to turn it into a "concatenated" list,
# where each entry corresponds to a 2D tuple in the respective dimension of A.
temp_tuples = X_shift.view(-1,2).transpose(0,1)
# temp_tuples.shape: (2,15) in your example. Below are the values:
tensor([[0, 1, 2, 0, 1, 1, 0, 0, 2, 1, 0, 0, 2, 2, 1],
        [1, 2, 0, 1, 0, 0, 0, 2, 1, 1, 0, 2, 2, 1, 1]])

# Now we have to create a matrix do indicate the proper "first dimension index"
fix_dims = torch.repeat_interleave(torch.arange(0,3,1), len(X[0])-1, 0).unsqueeze(dim=0)
# fix_dims.shape: (1,15)
# Long story short, this creates the following vector.
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]])
# Note that the unsqueeze is necessary to properly concatenate the two matrices:
access_tuples = tuple(torch.cat([fix_dims, temp_tuples], dim=0))
A[access_tuples] = 1

这进一步假设 X 中的每个维度都更改了相同数量的元组。如果不是这种情况,则您必须手动创建一个 fix_dims 向量,其中每个增量重复 X[i] 次的长度。如果它与您的示例相同,您可以安全地使用建议的解决方案。