《Pytorch学习指南》- Dataset和Dataloader用法详解

《Pytorch学习指南》- Dataset和Dataloader用法详解

Scroll Down

前言

本章节主要介绍如何使用torch.utils.data 中的Dataset和Dataloader来构建数据集, 重点要看使用细节

DataSet

  • torch.utils.data.Dataset
    • 功能 : Dataset抽象;类, 所有自定义的Dataset都需要继承他, 并重写相应的方法
    • getitem(self, index)
      1. 接收一个索引, 返回一个样本 : index => label, data
      2. 返回的样本的大小要一样

DataLoader

  • torch.utils.data.DataLoader
    • 功能 : 创建可以迭代的数据装载器
    • 参数 :
      1. dataset : Dataset类对象, 决定数据从哪读取以及如何读取
      2. batchsize: 决定数据批次大小
      3. num_works: 多进程读取数据的线程数
      4. shuffle: 每个 epoch 是否乱序
      5. 当样本数不能被batchsize整除时, 是否舍去最后一个batch的数据
    • 名词解释 :
      1. 样本总数 : 80, batchsize : 8 => 1 Epoch = 10 iteration

数据构建

1. 创建Dataset 类 ✨

class WeiBoDataset(Dataset):
	pass

2. 读取数据 🚑

注意 : 我们一般会在初始化的时候就加载进数据, 读取数据函数需要自定义

class WeiBoDataset(Dataset):

    def __init__(self, data_path):
        # 读取数据
        self.label, self.data = self.read_data(data_path)

3. 返回数据 ⚡

  • 这里需要注意的是, len 是必须要设置的, 返回的是你数据集的大小
  • 根据返回的len来构建索引, 然后把构建好的索引传入__getitem__里
  • getitem 根据传进来的索引获取对应的数据, 可以在这个方法里对数据进行处理
class WeiBoDataset(Dataset):

    def __init__(self, data_path):
        # 读取数据
        self.label, self.data = self.read_data(data_path)

    def __len__(self):
        """
            这个必须要设置, getitem中的index就是根据这个来设置的
        :return:
        """
        return len(self.data)

    def __getitem__(self, index):
        label = 1
        # features = [str(i) for i in range(10)]
        features = np.array([i for i in range(10)])
        return label, features

读取数据 🎨

weibo_dataset=WeiBoDataset("../../datasets/weibo_test_data.csv)
dataloader=DataLoader(weibo_dataset,batch_size=1024,shuffle=True)
for i, batch in enumerate(dataloader):
	# batch : [label, features] 组成
    print(type(batch[0]), type(batch[1]))

注意细节 🚀

  1. 先获取数据集的大小 len
  2. 根据len生成index, 然后shuffle
  3. 根据shuffle后的数据以及batch_size生成索引列表batch_index, 索引列表的大小为 batch_size
  4. 获取每个batch的数据时, 根据batch_index传入到 getitem 获取对应的数据
  5. 注意 : batch的数据类型取决于__getitem__返回的类型, 一般都会转换为tensor
  6. 有的数据类型是无法转换为tensor的, 比如 元素类型为str的list
  7. default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists found
  8. 上面报错原因就是 因为数据无法转换为 tensor , 而类型又不属于 tensors, numpy arrays, numbers, dicts or lists 这几种
  9. 如果返回的数据是集合类型, 可以直接使用 np.array() 转换为ndarray类型, 这样会被自动转换为tensor, 当然要求这个集合类型的元素类型是tensor有的
  10. 如果是tensor没有的,比如 str 类型的, 反而会报错, 比如 7. 报错

对比实验

注意 features的元素类型是str, 那么可以看到下面的输出结果中 label 是 tensor, features 是 list类型的

def __getitem__(self, index):
	label = 1
	# 转换为 ndarray 会报错
	# features = np.array([str(i) for i in range(10)]) 
    features = [str(i) for i in range(10)]
    return label, features
<class 'torch.Tensor'> <class 'list'>
<class 'torch.Tensor'> <class 'list'>
<class 'torch.Tensor'> <class 'list'>
<class 'torch.Tensor'> <class 'list'>

下面将feature中的数据元素换成了int类型的, 并且对将list转换为ndarray, 这样在获取batch时数据会自动转换为tensor , 但是这里需要注意的是, 上面的数据是不能用np.array()的, 这是因为 batch 必须包含 tensors, numpy arrays, numbers, dicts or lists 这几种类型, 其他的都会报错, 具体可以查看

def __getitem__(self, index):
	label = 1
    features = np.array([i for i in range(10)])
    return label, features
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>
<class 'torch.Tensor'> <class 'torch.Tensor'>