如何使用 python 在文件中有效地存储多个浮点数组

How to store multiple float arrays efficiently in a file using python

我正在尝试从 LSTM 的隐藏层中提取嵌入。我有一个包含多个句子的庞大数据集,因此它们会生成多个 numpy 向量。我想将所有这些向量有效地存储到一个文件中。这是我目前所拥有的

with open(src_vectors_save_file, "wb") as s_writer, open(tgt_vectors_save_file, "wb") as t_writer:
    for batch in data_iter:
        encoder_hidden_layer, decoder_hidden_layer = self.extract_lstm_hidden_states_for_batch(
            batch, data.src_vocabs, attn_debug
        )
        encoder_hidden_layer = encoder_hidden_layer.detach().numpy()
        decoder_hidden_layer = decoder_hidden_layer.detach().numpy()

        enc_hidden_bytes = pickle.dumps(encoder_hidden_layer)
        dec_hidden_bytes = pickle.dumps(decoder_hidden_layer)

        s_writer.write(enc_hidden_bytes)
        s_writer.write("\n")
        t_writer.write(dec_hidden_bytes)
        t_writer.write("\n")

本质上,我使用 picklenp.array 中获取 bytes 并将其写入二进制文件。我试图天真地用 ASCII 换行符分隔每个字节编码的数组,这显然会引发错误。我计划在下一个程序中使用 .readlines() 函数或使用 for 循环每行读取每个字节编码数组。但是,现在这不可能了。

我没有任何想法有人可以提出替代方案吗?我怎样才能有效地将所有数组以压缩方式存储在一个文件中,我怎样才能从该文件中读回它们?

使用 \n 作为分隔符存在问题,因为来自 pickle (enc_hidden_bytes) 的转储可能包含 \n,因为数据不是 ASCII 编码的。

有两种解决方法。您可以将数据中出现的 \n 转义,然后使用 \n 作为终止符。但这甚至在阅读时也增加了复杂性。

另一个解决方案是在开始实际数据之前将数据大小放入文件中。这就像某种 header 并且是通过连接发送数据时非常常见的做法。

可以写出下面两个函数-

import struct

def write_bytes(handle, data):
        total_bytes = len(data)
        handle.write(struct.pack(">Q", total_bytes))
        handle.write(data)

def read_bytes(handle):
        size_bytes = handle.read(8)
        if len(size_bytes) == 0:
            return None
        total_bytes = struct.unpack(">Q", size_bytes)[0]
        return handle.read(total_bytes)

现在您可以替换

s_writer.write(enc_hidden_bytes)
s_writer.write("\n")

write_bytes(s_writer, enc_hidden_bytes)

其他变量也一样。

在循环中从文件中读回时,您可以以类似的方式使用 read_bytes 函数。