Build my Dataset in Pytorch.
在使用Pytorch进行项目时,有时候需要读入自己的数据作为训练集和测试集,并按照自己指定的方式和格式处理。
Pytorch定义了Dataset类,在实际使用中可以通过继承Dataset类来构建数据集:
from torch.utils.data import Dataset, DataLoader
class myData(Dataset):
def __init__(self):
self.all_data = [] # 用于存放所有的数据
for i in range(N): # 遍历所有数据
self.all_data.append([x, y]) # 将一个样本和标签为一组存放进去
def __getitem__(self, index): # 返回一个样本和标签
return self.all_data[index][0], self.all_data[index][1]
def __len__(self): # 返回所有样本的数目
return len(self.all_data)
定义数据集后,通过标准类实例化可以创建并加载数据:
myDataSet = myData() # 实例化自己构建的数据集
train_loader = DataLoader(dataset=myDataSet, batch_size=BATCH_SIZE, shuffle=False)
创建数据集后,通过枚举获得数据并使用:
for iter, (data, label) in enumerate(train_loader):
print(data.shape)
print(label.shape)