将字符串转换为字节以供 pytorch 加载程序使用

Convert string to byte for pytorch loader

下载pytorch模型路径的方法不是我能控制的,我正在想办法将下载的字符串数据转换为字节数据。下面的代码从 Dropbox 下载我保存的模型并使用带有 utf-8 编码的字节来编码字符串。问题是当我将 torch.load 与 BytesIO 一起使用时,我得到一个 UnpicklingError 和无效的加载键 '<'.

    data = bytes(self.Download("https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"), 'utf-8')

    self.agent.local.load_state_dict(torch.load(BytesIO(data ), map_location=lambda storage, loc: storage))

在请求被禁用之前,下面的代码工作正常,我现在正在尝试使用上面的方法。

    dropbox_url = "https://www.dropbox.com/s/exampleurl/checkpoint.pth?dl=1"

    data = requests.get(dropbox_url )

    self.agent.local.load_state_dict(torch.load(BytesIO(data.content), map_location=lambda storage, loc: storage))

我只需要找出一种方法将字符串正确转换为字节数据。

我必须将字节数据转换为 base64 并以该格式保存文件。在我上传到 Dropbox 并使用内置方法下载后,我将 base64 文件转换回字节并且成功了!

import base64
from io import BytesIO

with open("checkpoint.pth", "rb") as f:
    byte = f.read(1)

# Base64 Encode the bytes
data_e = base64.b64encode(byte)

filename ='base64_checkpoint.pth'

with open(filename, "wb") as output:
    output.write(data_e)

# Save file to Dropbox

# Download file on server
b64_str= self.Download('url')

# String Encode to bytes
byte_data = b64_str.encode("UTF-8")

# Decoding the Base64 bytes
str_decoded = base64.b64decode(byte_data)

# String Encode to bytes
byte_decoded = str_decoded.encode("UTF-8")

# Decoding the Base64 bytes
decoded = base64.b64decode(byte_decoded)

torch.load(BytesIO(decoded))