修复 torch random_split() 的种子

Fixing the seed for torch random_split()

拆分数据集时是否可以修复 torch.utils.data.random_split() 的种子,以便可以重现测试结果?

您可以使用 torch.manual_seed 函数在全局播种脚本:

import torch
torch.manual_seed(0)

有关详细信息,请参阅 reproducibility documentation

如果你想专门播种 torch.utils.data.random_split 你可以 "reset" 之后将种子设置为它的初始值。只需像这样使用 torch.initial_seed()

torch.manual_seed(torch.initial_seed())

AFAIK pytorch 提供像 seedrandom_state 这样的参数(例如可以在 sklearn 中看到).

正如您从 documentation 中看到的那样,可以将生成器传递给 random_split

random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
generator = torch.Generator()
generator.manual_seed(0)

train, val, test = random_split(dataset=dataset,
                                lengths=[train_size, val_size, test_size],
                                generator=generator)