在 C++ 火炬中设置神经网络初始权重值
Set neural network initial weight values in C++ torch
我正在寻找 API 来在 libtorch 中设置初始权重值。在 python 版本中,(即 pytorch
)可以轻松使用 torch.nn.functional.weight.data.fill_(xx)
和 torch.nn.functional.bias.data.fill_(xx)
。但是,C++ 中似乎还没有这样的 API 。
对于实现此类功能的任何帮助或评论,我将不胜感激。
谢谢,
阿夫欣
我开发了这个功能来做到这一点:
void set_weights(fc_model &src_net) {
// torch::NoGradGuard no_grad;
torch::autograd::GradMode::set_enabled(false);
for (int k=0; k < src_net.no_layers-1; k++ ) {
src_net.layers[k]->weight.uniform_(0.001, 0.001);
src_net.layers[k]->bias.uniform_(0.0, 0.0);
}
torch::autograd::GradMode::set_enabled(true);
}
其中 src_net
是一个 nn
对象,其所有图层都聚集在一个列表中,名称为“图层”。
我得到的这个解决方案比之前的解决方案更好,其中 model
是类型 torch::nn::Sequential
的对象:
torch::NoGradGuard no_grad;
for (auto &p : model->named_parameters()) {
std::string y = p.key();
auto z = p.value(); // note that z is a Tensor, same as &p : layers->parameters
if (y.compare(2, 6, "weight") == 0)
z.uniform_(l, u);
else if (y.compare(2, 4, "bias") == 0)
z.uniform_(l, u);
}
您可以使用 normal_
而不是 uniform_
,... 可以在手电筒上使用。此解决方案不限于 torch::nn::Linear
层,可用于任何层类型。
我正在寻找 API 来在 libtorch 中设置初始权重值。在 python 版本中,(即 pytorch
)可以轻松使用 torch.nn.functional.weight.data.fill_(xx)
和 torch.nn.functional.bias.data.fill_(xx)
。但是,C++ 中似乎还没有这样的 API 。
对于实现此类功能的任何帮助或评论,我将不胜感激。
谢谢, 阿夫欣
我开发了这个功能来做到这一点:
void set_weights(fc_model &src_net) {
// torch::NoGradGuard no_grad;
torch::autograd::GradMode::set_enabled(false);
for (int k=0; k < src_net.no_layers-1; k++ ) {
src_net.layers[k]->weight.uniform_(0.001, 0.001);
src_net.layers[k]->bias.uniform_(0.0, 0.0);
}
torch::autograd::GradMode::set_enabled(true);
}
其中 src_net
是一个 nn
对象,其所有图层都聚集在一个列表中,名称为“图层”。
我得到的这个解决方案比之前的解决方案更好,其中 model
是类型 torch::nn::Sequential
的对象:
torch::NoGradGuard no_grad;
for (auto &p : model->named_parameters()) {
std::string y = p.key();
auto z = p.value(); // note that z is a Tensor, same as &p : layers->parameters
if (y.compare(2, 6, "weight") == 0)
z.uniform_(l, u);
else if (y.compare(2, 4, "bias") == 0)
z.uniform_(l, u);
}
您可以使用 normal_
而不是 uniform_
,... 可以在手电筒上使用。此解决方案不限于 torch::nn::Linear
层,可用于任何层类型。