Loading PyTorch model in multiprocessing Process flashes cmd
Question:
When I open a pytorch model in a subprocess (read: load a state_dict from disk), a cmd window pops up for a few milliseconds, which makes other programs lose focus - annoying while working on something else, etc.
I have traced the cause down to 2 lines, both of which trigger it in some cases, and managed to reproduce one of them (the second one is when executing model.to(device)).
main.py
model_path = r'testing\agent\model_test.pth'  # raw string so the backslashes are not treated as escapes
# create model
from testing.agent.torch_model import Net
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # I have cuda available
m = Net()
m.to(device)
# save it
torch.save(m.state_dict(), model_path)
# open it in subprocess
from testing.agent.AgentOpenSim_Process import Open_Agent_Sim
p = Open_Agent_Sim(p=model_path, msgLogger=None)
p.start()
torch_model.py
(source - pytorch docs: https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html)
import torch.nn as nn
import torch.nn.functional as F  # note: torch.functional has no relu; nn.functional does

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
AgentOpenSim_Process.py
from multiprocessing import Queue, Process
import os, time, torch
from testing.agent.torch_model import Net

class Open_Agent_Sim(Process):
    def __init__(self, p: str, **kwargs):
        super().__init__(daemon=True)
        self.path = p
        self._msgLogger = kwargs['msgLogger'] if kwargs['msgLogger'] is not None else Queue()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device_cpu = torch.device("cpu")

    def __print(self, msg: str, verbosity):
        # personal tool for debugging multiprocessing (it puts messages on a Queue and the main
        # process reads them in a thread and prints them to the console... and yes, I know about
        # the multiprocessing logger, it is used as well)
        self._msgLogger.put(('Open_Agent_Sim: ' + msg, verbosity))

    def run(self):
        self.__pid = os.getpid()
        try:
            self.__print('opening model', 0)
            self.init_agent()
            self.__print('opening model - done', 0)
        except Exception as e:
            # handled by a custom exception wrapper
            pass
        else:
            self.__print('has ended', 0)
        return

    def init_agent(self):
        # init instance
        self.__print('0a', 0)
        m = Net()
        self.__print('0b', 0)
        time.sleep(2)
        # load state dict
        self.__print('1a', 0)
        l = torch.load(self.path, map_location=self.device_cpu)
        self.__print('1b', 0)
        time.sleep(2)
        self.__print('2a', 0)
        m.load_state_dict(l)
        # set to device
        self.__print('2b', 0)
        time.sleep(2)
        try:
            self.__print('3a', 0)
            m.to(self.device)  # ----> This line pops up cmd
            self.__print('3b', 0)
        except RuntimeError as e:
            self.__print(str(e), 0)
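(For context on the msgLogger: on the parent side it is roughly a daemon thread that drains the Queue and prints. The sketch below is a hypothetical reconstruction for illustration, not my exact tool:)

import threading
from multiprocessing import Queue

def drain_logger(q: Queue):
    # runs in the main process: print every (message, verbosity) tuple the children put on the queue
    while True:
        msg, verbosity = q.get()
        print(msg)

logger_queue = Queue()
threading.Thread(target=drain_logger, args=(logger_queue,), daemon=True).start()
# logger_queue is then passed as msgLogger to Open_Agent_Sim(...)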
When visually debugging where those cmd windows pop up, it is always at step 1 (m.load_state_dict(torch.load(self.path, map_location=self.device))).
I have tried things like disabling console output, but it did not help:
import contextlib
with contextlib.redirect_stdout(None):
...
Using if __name__=='__main__': makes no difference either, and all of this is part of heavy multiprocessing in some lower subprocess.
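One more idea I have not verified (it assumes, and this is only a guess, that the flash comes from a hidden console child process spawned on Windows rather than from Python's own stdout): force child processes to be created without a console window before touching torch, e.g. by patching subprocess:

import sys, subprocess, functools

if sys.platform == 'win32':
    _orig_popen_init = subprocess.Popen.__init__

    @functools.wraps(_orig_popen_init)
    def _popen_no_window(self, *args, **kwargs):
        # add CREATE_NO_WINDOW so any console process spawned under the hood stays hidden
        kwargs['creationflags'] = kwargs.get('creationflags', 0) | subprocess.CREATE_NO_WINDOW
        _orig_popen_init(self, *args, **kwargs)

    subprocess.Popen.__init__ = _popen_no_window

Whether torch actually goes through subprocess here is a guess on my part, so this may well do nothing.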
Update
I have traced the problem down to switching devices - if I use torch.load(self.path, map_location=self.device_cpu) and later .to(self.device_gpu), the cmd pops up on the .to(...) line, but if I use torch.load(self.path, map_location=self.device_gpu), it pops up on that line instead. Another note: it does not matter on which device the model was saved.
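Reduced to the two code paths (a stand-alone sketch of the same behaviour, same model file in both cases):

import torch
from testing.agent.torch_model import Net

m = Net()

# variant A: load onto CPU first, move to GPU afterwards -> the window flashes at .to(...)
state = torch.load(r'testing\agent\model_test.pth', map_location=torch.device('cpu'))
m.load_state_dict(state)
m.to(torch.device('cuda'))

# variant B: map straight onto the GPU -> the window flashes at torch.load(...) instead
state = torch.load(r'testing\agent\model_test.pth', map_location=torch.device('cuda'))
m.load_state_dict(state)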
I am open to any workaround.
Updating the pytorch version via the install command from their website solved this problem.
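For reference, a quick way to confirm which build actually ended up installed after the upgrade:

import torch

print(torch.__version__)          # installed PyTorch version string, e.g. '1.10.0+cu113'
print(torch.version.cuda)         # CUDA toolkit the wheel was built against, or None for a CPU-only build
print(torch.cuda.is_available())  # True if that build can actually see a CUDA device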