Libtorch:无法加载跟踪的 lstm 脚本模型
Libtorch: cannot load traced lstm scriptmodel
我保存了一个 pytorch ScriptModule 并使用 libtorch 加载它。但是我遇到了以下问题
我用的是win10下的linux子系统,我用的是pytorch 1.2。
要重现我的问题,您可以运行这段python代码来保存一个pt模型
import torch
import torch.nn as nn
# TODO: https://github.com/pytorch/pytorch/issues/23930
class test(torch.jit.ScriptModule):
def __init__(self, vocab_size=10, rnn_dims=512):
super().__init__()
self.word_embeds = nn.Embedding(vocab_size, rnn_dims)
self.emb_drop = nn.Dropout(0.1)
self.rnn = nn.LSTM(input_size=rnn_dims, hidden_size=rnn_dims, batch_first=True,
num_layers=2, dropout=0.1)
# delattr(self.rnn, 'forward_packed')
@torch.jit.script_method
def forward(self, x):
h1 = (torch.zeros(2, 1, 512), torch.zeros(2, 1, 512))
embeds = self.emb_drop(self.word_embeds(x))
out, h1 = self.rnn(embeds, h1)
return h1
model = test()
input = torch.ones((1,3)).long()
output = model(input)
print('output', output)
# torch.onnx.export(model, # model being run
# input,
# 'test.onnx',
# example_outputs=output)
#torch.jit.trace(model, (torch.ones((1,3)).long(), torch.ones((3,1))), check_trace=False)
model.save('lstm_test.pt')
然后在libtorch中加载模型。
我不知道为什么会出现这个错误。我根本不使用 PackedSequence。希望有人能帮帮我。
我现在知道怎么回事了。 libtorch版本官网版本不对。现在当我使用正确的 libtorch 1.2 时就可以了。参考问题 https://github.com/pytorch/pytorch/issues/24382
我保存了一个 pytorch ScriptModule 并使用 libtorch 加载它。但是我遇到了以下问题
我用的是win10下的linux子系统,我用的是pytorch 1.2。
要重现我的问题,您可以运行这段python代码来保存一个pt模型
import torch
import torch.nn as nn
# TODO: https://github.com/pytorch/pytorch/issues/23930
class test(torch.jit.ScriptModule):
def __init__(self, vocab_size=10, rnn_dims=512):
super().__init__()
self.word_embeds = nn.Embedding(vocab_size, rnn_dims)
self.emb_drop = nn.Dropout(0.1)
self.rnn = nn.LSTM(input_size=rnn_dims, hidden_size=rnn_dims, batch_first=True,
num_layers=2, dropout=0.1)
# delattr(self.rnn, 'forward_packed')
@torch.jit.script_method
def forward(self, x):
h1 = (torch.zeros(2, 1, 512), torch.zeros(2, 1, 512))
embeds = self.emb_drop(self.word_embeds(x))
out, h1 = self.rnn(embeds, h1)
return h1
model = test()
input = torch.ones((1,3)).long()
output = model(input)
print('output', output)
# torch.onnx.export(model, # model being run
# input,
# 'test.onnx',
# example_outputs=output)
#torch.jit.trace(model, (torch.ones((1,3)).long(), torch.ones((3,1))), check_trace=False)
model.save('lstm_test.pt')
然后在libtorch中加载模型。
我不知道为什么会出现这个错误。我根本不使用 PackedSequence。希望有人能帮帮我。
我现在知道怎么回事了。 libtorch版本官网版本不对。现在当我使用正确的 libtorch 1.2 时就可以了。参考问题 https://github.com/pytorch/pytorch/issues/24382