如何在 PyTorch 中“点”权重以批处理数据?
How to `dot` weights to batch data in PyTorch?
我有批量数据,想dot()
到数据。 W是可训练参数。
批次数据和权重如何打点?
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)
更新
这个怎么样?
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = tdata.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim, 1) # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)
展开W
张量以匹配data
张量的形状。以下应该有效。
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim)
W = W.unsqueeze(0).unsqueeze(0).expand(*data.size())
result = torch.sum(data * W, 2)
result = result.view(10, 2, 3)
编辑:您更新的代码是正确的。由于您正在将 W
转换为 Bxhid_dimx1
并且您的数据的形状为 Bxdxhid_dim
,因此进行批量矩阵乘法将导致 Bxdx1
,这实际上是 [= 之间的点积11=] 参数和 data
(dxhid_dim
) 中的所有行向量。
我有批量数据,想dot()
到数据。 W是可训练参数。
批次数据和权重如何打点?
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim) # assume trainable parameters via nn.Parameter
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)
更新
这个怎么样?
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = tdata.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim, 1) # assume trainable parameters via nn.Parameter
W = W.unsqueeze(0).expand(10, hid_dim, 1)
result = torch.bmm(data, W).squeeze() # error, want (N, 6)
result = result.view(10, 2, 3)
展开W
张量以匹配data
张量的形状。以下应该有效。
hid_dim = 32
data = torch.randn(10, 2, 3, hid_dim)
data = data.view(10, 2*3, hid_dim)
W = torch.randn(hid_dim)
W = W.unsqueeze(0).unsqueeze(0).expand(*data.size())
result = torch.sum(data * W, 2)
result = result.view(10, 2, 3)
编辑:您更新的代码是正确的。由于您正在将 W
转换为 Bxhid_dimx1
并且您的数据的形状为 Bxdxhid_dim
,因此进行批量矩阵乘法将导致 Bxdx1
,这实际上是 [= 之间的点积11=] 参数和 data
(dxhid_dim
) 中的所有行向量。