在 PyTorch 中组合参数列表

Combining Parameterlist in PyTorch

我正在尝试在 Pytorch 中组合两个 ParameterList。我已经实现了以下代码片段:

import torch

list = nn.ParameterList()
for i in sub_list_1:
    list.append(i)
for i in sub_list_2:
    list.append(i)

是否有任何函数可以解决这个问题而不需要遍历每个列表?

您可以使用 nn.ParameterList.extend,它的工作方式类似于 python 的内置 list.extend

plist = nn.ParameterList()
plist.extend(sub_list_1)
plist.extend(sub_list_2)

或者,您可以使用 +=,它只是 extend

的别名
plist = nn.ParameterList()
plist += sub_list_1
plist += sub_list_2

除了@jodag 的回答之外,另一种解决方案是展开列表并用它们构建一个新列表

param_list = nn.ParameterList([*sub_list_1, *sub_list_2])