DataLoader与DataSet

1. 为什么要DataSet和DataLoader

image-20200721221026985

相比一次喂入全部数据的方法,我们更希望使用小批次的方法。把一轮数据全部训练一次称为一个Epoch,一个Epoch分为若干次Iteration,每次Iteration的数据量为一个Batch-Size。通过这种方法,可以有效地加快训练速度,避免显卡内存爆炸。

使用DataSet和DataLoader可以帮助我们进行小批次训练。

image-20200722135852959

2. DataLoader

DataLoader位于torch.utils.data.DataLoader。Data loader包括一个Sampler和一个DataSet,它为给定的数据集提供迭代功能。

1
2
3
4
5
6
7
8
9
10
11
torch.utils.data.DataLoader(dataset, batch_size=1, 
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)

功能:构建可迭代的数据装载器

  • dataset: Dataset类,需要自己自定义一个dataset类,继承于torch.utils.data.Dataset
  • batchsize : 批大小
  • num_works: 是否多进程读取数据
  • shuffle: 每个epoch是否乱序
  • drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
image-20200722140219507

DataLoader主要目的是实现一个迭代器,使得可以用for循环获取自动获取数据,并记录Epoch,Iteration和Batch-Size,使得程序不用关心怎么取出小批次的问题。

image-20200722141215096

3. DataSet

Dataset位于torch.utils.data.Dataset。DataSet为一个抽象类,不能实例化,只能通过自定义一个继承DataSet的子类来实现DataSet的方法。

所有表示索引到数据样本映射的dataset都应该是Dataset的子类。所有的子类都应该重写__getitem__()方法,以获得给出索引相对应的数据样本。还需要重写 __len__()方法,这样就可以通过 Sampler 的方法以及 DataLoader返回dataset的大小,以帮助构建batch。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import torch
import os
import random
from PIL import Image
from torch.utils.data import Dataset

rmb_label = {"1": 0, "100": 1}

class RMBDataset(Dataset):
def __init__(self, data_dir, transform=None):
"""
rmb面额分类任务的Dataset
:param data_dir: str, 数据集所在路径
:param transform: torch.transform,数据预处理
"""
self.label_name = {"1": 0, "100": 1}
self.data_info = self.get_img_info(data_dir) # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
self.transform = transform

def __getitem__(self, index):
path_img, label = self.data_info[index]
img = Image.open(path_img).convert('RGB') # 0~255

if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等

return img, label

def __len__(self):
return len(self.data_info)

@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith('.jpg'), img_names))

# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = rmb_label[sub_dir]
data_info.append((path_img, int(label)))

return data_info

4. 深入探索

完成了dataset的继承自定义类RMBDataset,就可以对RMBDataset实例化。将RMBDataset的实例传入参数,创建DataLoader。

1
2
3
4
5
6
7
# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

然后就可以在训练的for循环中实现小批量训练,每次for i, data in enumerate(train_loader)都可以获得一个bath的数据和label。

1
2
3
4
for epoch in range(MAX_EPOCH):
...
for i, data in enumerate(train_loader):
...

image-20200722141750525

实验:

  1. 在训练for循环读数据的地方打断点,进入train_loader。image-20200723182320156

  2. 跳转到dataloader.py的class DataLoader(object)def __iter__(self),选择是单线程还是多线程读取。我们这里配置的是单线程。

    image-20200723182605399

  3. 设置好读取方式以后,会跳转到dataloader.py的class _BaseDataLoaderIter(object)def __next__(self),就开始真正读数据了。返回的是data。

    image-20200723191149366

  4. 我们进入self._next_data,观察data是如何给出的。会跳转到dataloader.py的class _SingleProcessDataLoaderIter(_BaseDataLoaderIter)_next_data(self),这里用self._next_index()获得了index,然后利用 self._dataset_fetcher.fetch(index)获得数据data。我们再看看这个_next_data(self)self._dataset_fetcher.fetch(index)的工作机制。

    image-20200723191437645

  5. 进入_next_data(self),跳转到dataloader.py的class _BaseDataLoaderIter(object)_next_index(self)。发现这里很简单直接return了self._sampler_iter这个迭代器的next值。再进去看看吧。

    image-20200723192056360

  6. 进入self._sampler_iter,跳转到sampler.py的class BatchSampler(Sampler)__iter__(self),返回一个batch的生成器,用于生成索引的index,具体就不进去看了。我们接下来看看self._dataset_fetcher.fetch(index),跳出跳出。

    image-20200723192447473

  7. 退回到dataloader.py的class _SingleProcessDataLoaderIter(_BaseDataLoaderIter)_next_data(self),发现index已经给出了。我们的Batch-Size等于16,给出了16个index。

    image-20200723193307928

  8. 进入self._dataset_fetcher.fetch(index),跳转到fetch.py的class _MapDatasetFetcher(_BaseDatasetFetcher)def fetch(self, possibly_batched_index)。可以看到这里用了一个列表推导式,将传入的index列表的每个元素idx在self.dataset[idx]中索引,得到data的列表。我们进入self.dataset[idx]看看。

    image-20200723193900452

  9. 进入self.dataset[idx],跳转到my_dataset.py的class RMBDataset(Dataset)__getitem__(self, index),这正是我们通过继承dataset自定义的数据集。__getitem__(self, index)真是我们重写的魔术方法,用于从数据集中根据索引index取出数据。(可以看到,第一个要取的index为86,正是我们之前得到的一个batch的第一个index)

    image-20200723194632775

  10. self.data_info[index]是我们在class RMBDataset(Dataset)__init__(self, data_dir, transform=None)中定义的一个属性。它通过我们自定义的self.get_img_info(data_dir)的函数来获取。继续看看self.get_img_info(data_dir)

    image-20200723195147697

  11. 静态方法声明@staticmethod使得可以在类里面的__init__(self, data_dir, transform=None)中直接调用。

    image-20200723195613421

    稍微看一下返回的data_info,就是一个包含了(照片路径,类别)的列表。

    image-20200723195759668
  12. 返回my_dataset.py的class RMBDataset(Dataset)__getitem__(self, index)。通过index索引self.data_info[index]列表,并对里面的元素(二元组)拆包,得到路径和label。然后通过from PIL import Image导入的模块Image对图片路径进行读取。这里的好处显而易见,每次只读一张图片,并不是直接全部导入的,一个batch完成以后会释放,不然显卡内存会炸的哦。最后经过变换和tensor转换得到img,返回一个图片和相应的label。

    image-20200723194632775

  13. 退回self._dataset_fetcher.fetch(index),结束对self.dataset[idx]的调用,得到数据集转成的tensor列表,我们稍微看看这个列表的情况:

    image-20200723204250287

    可以看出,这是一个tensor列表,长度是16,对应batch-size。每一个元素都是二元组,对应彩色三通道的图像和标签。最后我们再看看self.collate_fn(data)

    image-20200723204101111

  14. 经过self.collate_fn(data),可以将data列表打包为一个batch,变成了含有两个元素的列表,第一个元素是4维的tensor列表,是16个三维tensor列表在dim0的拼接,第二个元素是类别,数据类型也是tensor。

    image-20200723205807035

  15. 最后,我们回到for循环。可以看到,已经完成了batch的操作。

    image-20200723210131097