Loading PyTorch model in multiprocessing Process flashes cmd

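Question:

When I open a PyTorch model in a subprocess (that is, load its state_dict from disk), a cmd window pops up for a few milliseconds. This makes other programs lose focus, which is annoying when I am working on something else. I have traced the cause to two lines, each of which triggers it in some cases, and managed to reproduce one of them (the second one is the call to model.to(device)).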

main.py

import torch

from testing.agent.AgentOpenSim_Process import Open_Agent_Sim
from testing.agent.torch_model import Net

# raw string, so that "\a" is not read as an escape sequence
model_path = r'testing\agent\model_test.pth'

if __name__ == '__main__':
    # create model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # I have cuda available
    m = Net()
    m.to(device)
    # save it
    torch.save(m.state_dict(), model_path)

    # open it in subprocess
    p = Open_Agent_Sim(p=model_path, msgLogger=None)
    p.start()
    p.join()  # the process is a daemon, so wait for it before main exits

torch_model.py

(source: PyTorch docs: https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html)

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

AgentOpenSim_Process.py

from multiprocessing import Queue, Process
import os, time, torch

from testing.agent.torch_model import Net


class Open_Agent_Sim(Process): 
    def __init__(self, p:str, **kwargs):          
        super().__init__(daemon=True)
        self.path = p
        self._msgLogger = kwargs['msgLogger'] if kwargs.get('msgLogger') is not None else Queue()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device_cpu = torch.device("cpu")

    def __print(self, msg: str, verbosity):
        # personal tool for debugging multiprocessing (it sends messages to a Queue; the main process
        # reads them inside a thread and prints them to the console... and yes, I know about the
        # multiprocessing logger, it's used as well)
        self._msgLogger.put(('Open_Agent_Sim: '+msg, verbosity)) 

    def run(self):
        self.__pid = os.getpid()
        try:
            self.__print('opening model',0)
            self.init_agent()
            self.__print('opening model - done',0)
        except Exception as e:           
            # solved by custom exception wrapper
            pass
        else:            
            self.__print('has ended',0)
            return

    def init_agent(self):
        # init instance
        self.__print('0a', 0) 
        m = Net()    
        self.__print('0b', 0)  
        time.sleep(2)
        # load state dict
        self.__print('1a', 0)        
        l = torch.load(self.path, map_location=self.device_cpu)
        self.__print('1b', 0)
        time.sleep(2)
        self.__print('2a', 0)
        m.load_state_dict(l)
        # set to device
        self.__print('2b', 0)   
        time.sleep(2)
        try:
            self.__print('3a', 0)
            m.to(self.device) # ----> This line pops up cmd
            self.__print('3b', 0)             
        except RuntimeError as e:
            self.__print(str(e), 0)

When visually debugging those cmd pop-ups, it always happens at step 1 (m.load_state_dict(torch.load(self.path, map_location=self.device)))

I tried things like disabling console output, but that did not help.

import contextlib
with contextlib.redirect_stdout(None):
    ...
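
A likely reason the redirect makes no difference, sketched below: contextlib.redirect_stdout only swaps Python's sys.stdout object, so it cannot affect anything that happens below the Python level. The assumption here (not confirmed in the question) is that the flashing window is a console created at the process level rather than ordinary Python output. The names buffer and captured are illustrative only.

```python
import contextlib
import io
import os

buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    # Goes through sys.stdout, so it lands in the buffer.
    print("python-level print")
    # Writes straight to file descriptor 1, so the redirect never sees it.
    os.write(1, b"raw write to fd 1\n")

captured = buffer.getvalue()
print(repr(captured))
```

Only the print() call ends up in captured; the raw write still reaches the real stdout, which is why this kind of redirect is unlikely to suppress a window opened by native code.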

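Adding if __name__=='__main__': makes no difference, and this is all part of heavy multiprocessing inside some lower-level subprocesses.

Update

I traced the problem to switching devices: if I use torch.load(self.path, map_location=self.device_cpu) and later .to(self.device_gpu), the cmd pops up on .to(...), but if I use torch.load(self.path, map_location=self.device_gpu), it pops up on that line instead. Also worth noting: it does not matter which device the model was saved on.

I am open to any workaround.

Updating the PyTorch version via the install command on their website solved the problem.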
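
Since the pop-up follows whichever line first touches the GPU, one way to narrow it down is to run a child process that does nothing except a first CUDA operation, with no model and no state_dict involved. This is only a sketch under the assumption that the flash is tied to the first CUDA use in the child; first_cuda_touch and _child are hypothetical helper names, not part of the code above.

```python
import importlib.util
import multiprocessing
import time


def first_cuda_touch():
    """Force the first CUDA use in this process and report what happened."""
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch

    if not torch.cuda.is_available():
        return "cuda not available"
    start = time.perf_counter()
    torch.zeros(1, device="cuda")  # first CUDA use in this process
    torch.cuda.synchronize()
    return f"first cuda use took {time.perf_counter() - start:.3f}s"


def _child():
    # Runs in the child process; watch for the cmd flash while this executes.
    print(first_cuda_touch(), flush=True)


if __name__ == "__main__":
    p = multiprocessing.Process(target=_child)
    p.start()
    p.join()
```

If the window still flashes with this stripped-down child, the model code and the saved file can probably be ruled out as the trigger.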
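
To compare the environment before and after such an update, a small check like the following can record which build is actually installed. It is a minimal sketch; describe_torch_install is a hypothetical helper name, and the question does not say which versions were involved.

```python
import importlib.util


def describe_torch_install():
    """Return basic facts about the installed PyTorch build, if any."""
    if importlib.util.find_spec("torch") is None:
        return {"installed": False}
    import torch

    return {
        "installed": True,
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    print(describe_torch_install())
```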