关于DataLoader,DataSet,Sampler

自上而下理解三者关系

DataLoader.__next__的源码:

class DataLoader(object):
	...
    
    def __next__(self):
        if self.num_workers == 0:
            indices = next(self.samper_iter) # Sampler
            batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
            if self.pin_memory:
                batch = _utils.pin_memory.pin_memory_batch(batch)
               return batch

每张图片对应一个index,即上图的indicesSampler完成选取index的方式;拿到indices后,根据index对数据进行读取即可.

数据读取
  • 读哪些数据? Sampler输出的Index
  • 从哪读数据? Dataset中的data_dir
  • 怎么读数据? Dataset中的getitem
Dataset

Pytorch支持两种类型的数据集Map-style Dataset和Iterable-style Dataset,提供表示数据集的抽象类,任何自定义的Dataset都需要继承该类并覆盖相关方法

Map-style Dataset
  • 需要继承torch.utils.data.Dataset
  • 需要重写两个方法
    • __getitem__(self, index)
    • __len__(self)
  • 本质上构建了index到data的映射,dataset[idx]返回数据集中第idx个item
    • idx可以不是int类型
  • len(dataset)返回数据集的大小
lterable-style Dataset
  • 需要继承torch.utils.data.IterableDataset
  • 需要重写一个方法__iter__(self)
  • 本质上是一个可迭代对象,通过next(dataset)调用__iter__(self)方法返回数据集的下一个item

自定义的Dataset如下:

from torch.utils.data import Dataset
class Dataset(Dataset):
    def __init__(self):
        ...
    def __getitem__(self):
        ...
    def __len__(self):
        return ...
  • __getitem__(self)
    • 是最主要的方法,规定了如何读取数据
    • python built-in方法,主要作用是让该类可以像list一样通过索引值对数据进行访问
    • 如果__getitem__(self)方法每次读数据不仅仅返回img, label则需要自定义colloate_fn来对应合并成一个batch数据
Sampler

Sampler本质上是迭代器,用于产生数据集的索引值序列

查看DataLoader的源码,如下:

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

初始化参数里有两种sampler:samplerbatch_sampler,都默认为None.

sampler:生成一系列的index.

batch_sampler:将sampler生成的indices打包分组,得到一个又一个batch的index.

DataLoader的部分初始化参数之间存在互斥关系:

  • 如果自定义了batch_sampler,那么这些参数必须使用默认值:batch_size,shuffle,sampler

drop_last

  • 如果自定义了sampler,那么shuffle需要设置为False

  • 如果samplerbatch_sampler都为None,则batch_sampler使用Pytorch已经定义好的BatchSampler(BatchSampler 将 Sampler 采样得到的索引值进行合并,当数量等于一个 batch 大小后就将这一批的索引值返回),而sampler分两种情况:

    • if shuffle = true,则sampler=RandomSampler(dataset)
    • if shuffle = False,则sampler=SequentialSampler(dataset)
DataLoder

DataLoader有两种模式Automatic batching 和 Disable automatic batching

  • batch_sizedrop_last均为None的时候,使用Disable automatic batching模式
  • 否则使用Automatic batching

Automatic batching 的处理逻辑可以简化为:

  1. sampler采样dataset
  2. batch_sampler依次将sampler采样得到的indices进行合并,当数量等于batch_size时,将这个batch的indices返回.drop_last决定是否丢弃最后不足一个batch的部分.
  3. DataLoder依次按照batch_sampler提供的batch indices将数据从dataset中读出,传给collate_fn进行整理,返回Tensor
# map-style dataset
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

# iterable-style dataset
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

Disable automatic batching 的处理逻辑可以简化为:

  1. sampler采样dataset
  2. DataLoder依次按照batch_sampler提供的batch indices将数据从dataset中读出,传给collate_fn进行整理,返回Tensor
# map-style dataset
for index in sampler:
    yield collate_fn(dataset[index])

# iterable-style dataset
for data in iter(dataset):
    yield collate_fn(data)