如何在 PyTorch 中“点”权重以批处理数据?

How to `dot` weights to batch data in PyTorch?

我有批量数据,想dot()到数据。 W是可训练参数。 批次数据和权重如何打点?

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)

更新

这个怎么样?

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = tdata.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim, 1) # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)

展开W张量以匹配data张量的形状。以下应该有效。

hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim)
W = W.unsqueeze(0).unsqueeze(0).expand(*data.size())
result = torch.sum(data * W, 2)
result = result.view(10, 2, 3)

编辑:您更新的代码是正确的。由于您正在将 W 转换为 Bxhid_dimx1 并且您的数据的形状为 Bxdxhid_dim,因此进行批量矩阵乘法将导致 Bxdx1,这实际上是 [= 之间的点积11=] 参数和 data (dxhid_dim) 中的所有行向量。