Pytorch SSLError on Dataloader when Workers are greater than 1
I created a dataset object that loads some data from an API when an item is loaded:
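```python
import requests
import torch
from torch.utils.data import Dataset

# Session presumably created once at module level and shared by every
# DataLoader worker (this is what later turns out to cause the SSLError).
session = requests.Session()


class MyDataset(Dataset):
    def __init__(self, obj_ids=None):
        """Dataset that fetches each item from an API when it is accessed."""
        super().__init__()
        # Avoid a mutable default argument shared between instances.
        self.obj_ids = obj_ids if obj_ids is not None else []

    def __len__(self):
        return len(self.obj_ids)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # The host part of the URL was removed in the original post.
        result = session.get('/api/url/{}'.format(idx))
        ## Post processing work...
```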
Then I added it to my data loader:
```python
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=1,
    collate_fn=utils.collate_fn)
```
With `num_workers=1` everything works fine during training. But when I increase it to 2 or more, my training loop fails on this line:
```python
train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
```
```
SSLError: Caught SSLError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 600, in urlopen
    chunked=chunked)
  File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 384, in _make_request
    six.raise_from(e, None)
  File "<string>", line 2, in raise_from
  File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 380, in _make_request
    httplib_response = conn.getresponse()
  File "/usr/lib/python3.7/http/client.py", line 1373, in getresponse
    response.begin()
  File "/usr/lib/python3.7/http/client.py", line 319, in begin
    version, status, reason = self._read_status()
  File "/usr/lib/python3.7/http/client.py", line 280, in _read_status
    line = str(self.fp.readline(_MAXLINE + 1), "iso-8859-1")
  File "/usr/lib/python3.7/socket.py", line 589, in readinto
    return self._sock.recv_into(b)
  File "/usr/lib/python3.7/ssl.py", line 1071, in recv_into
    return self.read(nbytes, buffer)
  File "/usr/lib/python3.7/ssl.py", line 929, in read
    return self._sslobj.read(len, buffer)
ssl.SSLError: [SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC] decryption failed or bad record mac (_ssl.c:2570)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/requests/adapters.py", line 449, in send
    timeout=timeout
  File "/usr/local/lib/python3.7/dist-packages/urllib3/connectionpool.py", line 638, in urlopen
    _stacktrace=sys.exc_info()[2])
  File "/usr/local/lib/python3.7/dist-packages/urllib3/util/retry.py", line 399, in increment
    raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPSConnectionPool(host='mydomain.com', port=443): Max retries exceeded with url: 'url_with_error_is_here' (Caused by SSLError(SSLError(1, '[SSL: DECRYPTION_FAILED_OR_BAD_RECORD_MAC] decryption failed or bad record mac (_ssl.c:2570)')))
```
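If I remove the POST request, I no longer get the SSL error, so the problem most likely lies with `requests.post` or urllib3.

I replaced the domain and URL in the error with dummy values, but both the URL and the domain work when there is only 1 worker.

I'm running this in a GPU-enabled Google Colab environment, but I also tried it on my local machine and hit the same problem.

Can anyone help me with this?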
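After some debugging and reading more about multiprocessing and `requests.Session`, it seems the problem is that I can't use a `requests.Session` in the dataset, because PyTorch ends up using multiprocessing for the DataLoader workers in the training loop. On Linux (which includes Colab) the workers are started with `fork`, so every worker inherits the same session object together with any TLS connection it already has open. When two workers read and write on that single connection, their copies of the encryption state get out of sync, which is what OpenSSL reports as `DECRYPTION_FAILED_OR_BAD_RECORD_MAC`.

More on this issue: How to assign python requests sessions for single processes in multiprocessing pool?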
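The problem was solved by changing every `session.get` or `session.post` to `requests.get` and `requests.post`. Without a shared session, each call opens its own connection, so the workers no longer share one connection and the SSLError goes away.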
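Switching to `requests.get`/`requests.post` works, but it gives up connection reuse, since each call builds and tears down its own connection. If reuse matters, one option (an assumption on my part, not what the answer above did) is to keep a session but create it lazily, once per process, keyed on `os.getpid()`. The helper below is a minimal sketch; `PerProcessSession` and its `factory` parameter are hypothetical names, and `factory` would be `requests.Session` in practice.

```python
import os
from typing import Any, Callable, Dict, Optional


class PerProcessSession:
    """Hypothetical helper: lazily build one session object per OS process.

    `factory` is any zero-argument callable that returns a session,
    e.g. `requests.Session`. The cached session is tied to the PID that
    created it, so a forked DataLoader worker never reuses the parent's
    session or the connections it holds.
    """

    def __init__(self, factory: Callable[[], Any]) -> None:
        self._factory = factory
        self._session: Optional[Any] = None
        self._pid: Optional[int] = None

    def get(self) -> Any:
        pid = os.getpid()
        if self._session is None or self._pid != pid:
            # First use in this process, or we were forked from the parent:
            # build a fresh session. The inherited one is deliberately left
            # untouched rather than closed, since its sockets are shared
            # with the parent.
            self._session = self._factory()
            self._pid = pid
        return self._session

    def __getstate__(self) -> Dict[str, Any]:
        # With the "spawn" start method the dataset is pickled into each
        # worker; drop the cached session so the worker builds its own.
        state = self.__dict__.copy()
        state["_session"] = None
        state["_pid"] = None
        return state
```

In `MyDataset` this would mean storing `self.sessions = PerProcessSession(requests.Session)` in `__init__` and calling `self.sessions.get().get(url)` in `__getitem__`, so each worker reuses its own connection pool across items.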