在 PyTorch 中,“*”运算符在张量大小之前做什么?

What does "*" operator do before the tensor size in PyTorch?

我现在正在学习在 PyTorch 中构建神经网络。以下是从 .py 文件中截取的代码:

x = torch.unsqueeze(torch.linspace(-1, 1, 1000), dim=1)
y = x.pow(2) + 0.1*torch.normal(torch.zeros(*x.size()))

我对 x.size() 之前的 * 运算符的效用感到很困惑。我尝试删除它并绘制散点图,结果证明它与未删除*的散点图相同。

我也在https://pytorch.org/docs/stable/tensors.html查了size的官方文档,没搞明白

Image of torch.size item in documentation

如果你能帮助我,我将不胜感激。

*在Python中这样使用表示(参数)解包。当您将它添加到可迭代对象(即 x.size() returns)之前时,它会解包并(此处)将其项目作为位置参数传递给函数。例如:

def f(a1, a2):
    print(a1, a2)

f(*["Hello", "World"])

您可以查看 another example and more detailed description 的文档链接。

这里 * 对结果没有影响的原因是因为 torch.zero 除了 可变数量的参数 像列表或元组 如前所述 here。不代表*本身没用

然后,由于 torch.Size() class 是 python 元组的子 class,可以使用 * 将其解包。 (x.size() 将 return 一个 torch.Size() 对象)

所以总结一下,x.size() 会给你 (1000, 1)*x.size() 在参数 会给你 1000, 1 torch.zeros()

都接受了