DataLoader与DataSet
1. 为什么要DataSet和DataLoader
相比一次喂入全部数据的方法,我们更希望使用小批次的方法。把一轮数据全部训练一次称为一个Epoch,一个Epoch分为若干次Iteration,每次Iteration的数据量为一个Batch-Size。通过这种方法,可以有效地加快训练速度,避免显卡内存爆炸。
使用DataSet和DataLoader可以帮助我们进行小批次训练。
2. DataLoader
DataLoader位于torch.utils.data.DataLoader。Data loader包括一个Sampler和一个DataSet,它为给定的数据集提供迭代功能。
1 | torch.utils.data.DataLoader(dataset, batch_size=1, |
功能:构建可迭代的数据装载器
- dataset: Dataset类,需要自己自定义一个dataset类,继承于torch.utils.data.Dataset
- batchsize : 批大小
- num_works: 是否多进程读取数据
- shuffle: 每个epoch是否乱序
- drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
DataLoader主要目的是实现一个迭代器,使得可以用for循环获取自动获取数据,并记录Epoch,Iteration和Batch-Size,使得程序不用关心怎么取出小批次的问题。
3. DataSet
Dataset位于torch.utils.data.Dataset。DataSet为一个抽象类,不能实例化,只能通过自定义一个继承DataSet的子类来实现DataSet的方法。
所有表示索引到数据样本映射的dataset都应该是Dataset的子类。所有的子类都应该重写__getitem__()
方法,以获得给出索引相对应的数据样本。还需要重写 __len__()
方法,这样就可以通过 Sampler
的方法以及 DataLoader
返回dataset的大小,以帮助构建batch。
1 | import numpy as np |
4. 深入探索
完成了dataset的继承自定义类RMBDataset,就可以对RMBDataset实例化。将RMBDataset的实例传入参数,创建DataLoader。
1 | # 构建MyDataset实例 |
然后就可以在训练的for循环中实现小批量训练,每次for i, data in enumerate(train_loader)
都可以获得一个bath的数据和label。
1 | for epoch in range(MAX_EPOCH): |
实验:
-
在训练for循环读数据的地方打断点,进入train_loader。
-
跳转到dataloader.py的
class DataLoader(object)
的def __iter__(self)
,选择是单线程还是多线程读取。我们这里配置的是单线程。 -
设置好读取方式以后,会跳转到dataloader.py的
class _BaseDataLoaderIter(object)
的def __next__(self)
,就开始真正读数据了。返回的是data。 -
我们进入
self._next_data
,观察data是如何给出的。会跳转到dataloader.py的class _SingleProcessDataLoaderIter(_BaseDataLoaderIter)
的_next_data(self)
,这里用self._next_index()
获得了index,然后利用self._dataset_fetcher.fetch(index)
获得数据data。我们再看看这个_next_data(self)
和self._dataset_fetcher.fetch(index)
的工作机制。 -
进入
_next_data(self)
,跳转到dataloader.py的class _BaseDataLoaderIter(object)
的_next_index(self)
。发现这里很简单直接return了self._sampler_iter
这个迭代器的next值。再进去看看吧。 -
进入
self._sampler_iter
,跳转到sampler.py的class BatchSampler(Sampler)
的__iter__(self)
,返回一个batch的生成器,用于生成索引的index,具体就不进去看了。我们接下来看看self._dataset_fetcher.fetch(index)
,跳出跳出。 -
退回到dataloader.py的
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter)
的_next_data(self)
,发现index已经给出了。我们的Batch-Size等于16,给出了16个index。 -
进入
self._dataset_fetcher.fetch(index)
,跳转到fetch.py的class _MapDatasetFetcher(_BaseDatasetFetcher)
的def fetch(self, possibly_batched_index)
。可以看到这里用了一个列表推导式,将传入的index列表的每个元素idx在self.dataset[idx]
中索引,得到data的列表。我们进入self.dataset[idx]
看看。 -
进入
self.dataset[idx]
,跳转到my_dataset.py的class RMBDataset(Dataset)
的__getitem__(self, index)
,这正是我们通过继承dataset自定义的数据集。__getitem__(self, index)
真是我们重写的魔术方法,用于从数据集中根据索引index取出数据。(可以看到,第一个要取的index为86,正是我们之前得到的一个batch的第一个index) -
self.data_info[index]
是我们在class RMBDataset(Dataset)
的__init__(self, data_dir, transform=None)
中定义的一个属性。它通过我们自定义的self.get_img_info(data_dir)
的函数来获取。继续看看self.get_img_info(data_dir)
-
静态方法声明
@staticmethod
使得可以在类里面的__init__(self, data_dir, transform=None)
中直接调用。稍微看一下返回的data_info,就是一个包含了(照片路径,类别)的列表。
-
返回my_dataset.py的
class RMBDataset(Dataset)
的__getitem__(self, index)
。通过index索引self.data_info[index]
列表,并对里面的元素(二元组)拆包,得到路径和label。然后通过from PIL import Image
导入的模块Image对图片路径进行读取。这里的好处显而易见,每次只读一张图片,并不是直接全部导入的,一个batch完成以后会释放,不然显卡内存会炸的哦。最后经过变换和tensor转换得到img,返回一个图片和相应的label。 -
退回
self._dataset_fetcher.fetch(index)
,结束对self.dataset[idx]
的调用,得到数据集转成的tensor列表,我们稍微看看这个列表的情况:可以看出,这是一个tensor列表,长度是16,对应batch-size。每一个元素都是二元组,对应彩色三通道的图像和标签。最后我们再看看
self.collate_fn(data)
-
经过
self.collate_fn(data)
,可以将data列表打包为一个batch,变成了含有两个元素的列表,第一个元素是4维的tensor列表,是16个三维tensor列表在dim0的拼接,第二个元素是类别,数据类型也是tensor。 -
最后,我们回到for循环。可以看到,已经完成了batch的操作。