创建和使用 PyTorch DataLoader
Creating and Use a PyTorch DataLoader
我正在尝试使用示例数据创建 PyTorch 数据集和 DataLoader 对象。
这是制表符分隔的数据集:
1 0 0.171429 1 0 0 0.966805 0
0 1 0.085714 0 1 0 0.188797 1
1 0 0.000000 0 0 1 0.690871 2
1 0 0.057143 0 1 0 1.000000 1
0 1 1.000000 0 0 1 0.016598 2
1 0 0.171429 1 0 0 0.802905 0
0 1 0.171429 1 0 0 0.966805 1
1 0 0.257143 0 1 0 0.329876 0
这是创建上面的 Dataset 和 DataLoader 对象的代码:
import numpy as np
import torch as T
device = T.device("cpu") # to Tensor or Module
# ---------------------------------------------------
# predictors and label in same file
# data has been normalized and encoded like:
# sex age region income politic
# [0] [2] [3] [6] [7]
# 1 0 0.057143 0 1 0 0.690871 2
class PeopleDataset(T.utils.data.Dataset):
def __init__(self, src_file, num_rows=None):
x_tmp = np.loadtxt(src_file, max_rows=num_rows,
usecols=range(0,7), delimiter="\t",
skiprows=0, dtype=np.float32)
y_tmp = np.loadtxt(src_file, max_rows=num_rows,
usecols=7, delimiter="\t", skiprows=0,
dtype=np.long)
self.x_data = T.tensor(x_tmp,
dtype=T.float32).to(device)
self.y_data = T.tensor(y_tmp,
dtype=T.long).to(device)
def __len__(self):
return len(self.x_data) # required
def __getitem__(self, idx):
if T.is_tensor(idx):
idx = idx.tolist()
preds = self.x_data[idx, 0:7]
pol = self.y_data[idx]
sample = \
{ 'predictors' : preds, 'political' : pol }
return sample
# ---------------------------------------------------
def main():
print("\nBegin PyTorch DataLoader demo ")
# 0. miscellaneous prep
T.manual_seed(0)
np.random.seed(0)
print("\nSource data looks like: ")
print("1 0 0.171429 1 0 0 0.966805 0")
print("0 1 0.085714 0 1 0 0.188797 1")
print(" . . . ")
# 1. create Dataset and DataLoader object
print("\nCreating Dataset and DataLoader ")
train_file = "people_train.txt"
train_ds = PeopleDataset(train_file, num_rows=8)
bat_size = 3
train_ldr = T.utils.data.DataLoader(train_ds,
batch_size=bat_size, shuffle=True)
# 2. iterate thru training data twice
for epoch in range(2):
print("\n==============================\n")
print("Epoch = " + str(epoch))
for (batch_idx, batch) in enumerate(train_ldr):
print("\nBatch = " + str(batch_idx))
X = batch['predictors'] # [3,7]
# Y = T.flatten(batch['political']) #
Y = batch['political'] # [3]
print(X)
print(Y)
print("\n==============================")
print("\nEnd demo ")
if __name__ == "__main__":
main()
代码只是用文件名“demo.py”保存。在命令提示符屏幕上执行命令“python demo.py”后,代码应该会成功执行。我使用安装了 Torch (v 1.10) 的 Anaconda Prompt。
我已经尝试了很多方法来使上述工作正常进行,但我只收到一条错误消息:
Source data looks like:
1 0 0.171429 1 0 0 0.966805 0
0 1 0.085714 0 1 0 0.188797 1
. . .
Creating Dataset and DataLoader
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-8-cfb1177991f2> in <module>()
81
82 if __name__ == "__main__":
---> 83 main()
4 frames
<ipython-input-8-cfb1177991f2> in main()
59
60 train_file = "people_train.txt"
---> 61 train_ds = PeopleDataset(train_file, num_rows=8)
62
63 bat_size = 3
<ipython-input-8-cfb1177991f2> in __init__(self, src_file, num_rows)
20 x_tmp = np.loadtxt(src_file, max_rows=num_rows,
21 usecols=range(0,7), delimiter="\t",
---> 22 skiprows=0, dtype=np.float32)
23 y_tmp = np.loadtxt(src_file, max_rows=num_rows,
24 usecols=7, delimiter="\t", skiprows=0,
/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in loadtxt(fname, dtype, comments, delimiter, converters, skiprows, usecols, unpack, ndmin, encoding, max_rows)
1137 # converting the data
1138 X = None
-> 1139 for x in read_data(_loadtxt_chunksize):
1140 if X is None:
1141 X = np.array(x, dtype)
/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in read_data(chunk_size)
1058 continue
1059 if usecols:
-> 1060 vals = [vals[j] for j in usecols]
1061 if len(vals) != N:
1062 line_num = i + skiprows + 1
/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in <listcomp>(.0)
1058 continue
1059 if usecols:
-> 1060 vals = [vals[j] for j in usecols]
1061 if len(vals) != N:
1062 line_num = i + skiprows + 1
IndexError: list index out of range
我看不出索引的哪一部分有问题,因为我觉得索引似乎没有任何问题。有人可以帮我吗?
您的数据似乎是 space 分隔的,而不是制表符分隔的。因此,当您指定 delimiter="\t"
时,整行将作为单个列读取。但是由于 usecols=range(0,7)
,NumPy 期望有七列,并在尝试迭代它们时抛出错误。
要解决此问题,请将数据中的白色 space 更改为制表符,或将定界符参数更改为 delimiter=" "
。
我正在尝试使用示例数据创建 PyTorch 数据集和 DataLoader 对象。
这是制表符分隔的数据集:
1 0 0.171429 1 0 0 0.966805 0
0 1 0.085714 0 1 0 0.188797 1
1 0 0.000000 0 0 1 0.690871 2
1 0 0.057143 0 1 0 1.000000 1
0 1 1.000000 0 0 1 0.016598 2
1 0 0.171429 1 0 0 0.802905 0
0 1 0.171429 1 0 0 0.966805 1
1 0 0.257143 0 1 0 0.329876 0
这是创建上面的 Dataset 和 DataLoader 对象的代码:
import numpy as np
import torch as T
device = T.device("cpu") # to Tensor or Module
# ---------------------------------------------------
# predictors and label in same file
# data has been normalized and encoded like:
# sex age region income politic
# [0] [2] [3] [6] [7]
# 1 0 0.057143 0 1 0 0.690871 2
class PeopleDataset(T.utils.data.Dataset):
def __init__(self, src_file, num_rows=None):
x_tmp = np.loadtxt(src_file, max_rows=num_rows,
usecols=range(0,7), delimiter="\t",
skiprows=0, dtype=np.float32)
y_tmp = np.loadtxt(src_file, max_rows=num_rows,
usecols=7, delimiter="\t", skiprows=0,
dtype=np.long)
self.x_data = T.tensor(x_tmp,
dtype=T.float32).to(device)
self.y_data = T.tensor(y_tmp,
dtype=T.long).to(device)
def __len__(self):
return len(self.x_data) # required
def __getitem__(self, idx):
if T.is_tensor(idx):
idx = idx.tolist()
preds = self.x_data[idx, 0:7]
pol = self.y_data[idx]
sample = \
{ 'predictors' : preds, 'political' : pol }
return sample
# ---------------------------------------------------
def main():
print("\nBegin PyTorch DataLoader demo ")
# 0. miscellaneous prep
T.manual_seed(0)
np.random.seed(0)
print("\nSource data looks like: ")
print("1 0 0.171429 1 0 0 0.966805 0")
print("0 1 0.085714 0 1 0 0.188797 1")
print(" . . . ")
# 1. create Dataset and DataLoader object
print("\nCreating Dataset and DataLoader ")
train_file = "people_train.txt"
train_ds = PeopleDataset(train_file, num_rows=8)
bat_size = 3
train_ldr = T.utils.data.DataLoader(train_ds,
batch_size=bat_size, shuffle=True)
# 2. iterate thru training data twice
for epoch in range(2):
print("\n==============================\n")
print("Epoch = " + str(epoch))
for (batch_idx, batch) in enumerate(train_ldr):
print("\nBatch = " + str(batch_idx))
X = batch['predictors'] # [3,7]
# Y = T.flatten(batch['political']) #
Y = batch['political'] # [3]
print(X)
print(Y)
print("\n==============================")
print("\nEnd demo ")
if __name__ == "__main__":
main()
代码只是用文件名“demo.py”保存。在命令提示符屏幕上执行命令“python demo.py”后,代码应该会成功执行。我使用安装了 Torch (v 1.10) 的 Anaconda Prompt。
我已经尝试了很多方法来使上述工作正常进行,但我只收到一条错误消息:
Source data looks like:
1 0 0.171429 1 0 0 0.966805 0
0 1 0.085714 0 1 0 0.188797 1
. . .
Creating Dataset and DataLoader
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-8-cfb1177991f2> in <module>()
81
82 if __name__ == "__main__":
---> 83 main()
4 frames
<ipython-input-8-cfb1177991f2> in main()
59
60 train_file = "people_train.txt"
---> 61 train_ds = PeopleDataset(train_file, num_rows=8)
62
63 bat_size = 3
<ipython-input-8-cfb1177991f2> in __init__(self, src_file, num_rows)
20 x_tmp = np.loadtxt(src_file, max_rows=num_rows,
21 usecols=range(0,7), delimiter="\t",
---> 22 skiprows=0, dtype=np.float32)
23 y_tmp = np.loadtxt(src_file, max_rows=num_rows,
24 usecols=7, delimiter="\t", skiprows=0,
/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in loadtxt(fname, dtype, comments, delimiter, converters, skiprows, usecols, unpack, ndmin, encoding, max_rows)
1137 # converting the data
1138 X = None
-> 1139 for x in read_data(_loadtxt_chunksize):
1140 if X is None:
1141 X = np.array(x, dtype)
/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in read_data(chunk_size)
1058 continue
1059 if usecols:
-> 1060 vals = [vals[j] for j in usecols]
1061 if len(vals) != N:
1062 line_num = i + skiprows + 1
/usr/local/lib/python3.7/dist-packages/numpy/lib/npyio.py in <listcomp>(.0)
1058 continue
1059 if usecols:
-> 1060 vals = [vals[j] for j in usecols]
1061 if len(vals) != N:
1062 line_num = i + skiprows + 1
IndexError: list index out of range
我看不出索引的哪一部分有问题,因为我觉得索引似乎没有任何问题。有人可以帮我吗?
您的数据似乎是 space 分隔的,而不是制表符分隔的。因此,当您指定 delimiter="\t"
时,整行将作为单个列读取。但是由于 usecols=range(0,7)
,NumPy 期望有七列,并在尝试迭代它们时抛出错误。
要解决此问题,请将数据中的白色 space 更改为制表符,或将定界符参数更改为 delimiter=" "
。