关于DataLoader,DataSet,Sampler
关于DataLoader,DataSet,Sampler
自上而下理解三者关系
DataLoader.__next__的源码:
class DataLoader(object):
...
def __next__(self):
if self.num_workers == 0:
indices = next(self.samper_iter) # Sampler
batch = self.collate_fn([self.dataset[i] for i in indices]) # Dataset
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
每张图片对应一个index,即上图的indices;Sampler完成选取index的方式;拿到indices后,根据index对数据进行读取即可.
数据读取
- 读哪些数据? Sampler输出的Index
- 从哪读数据? Dataset中的data_dir
- 怎么读数据? Dataset中的getitem
Dataset
Pytorch支持两种类型的数据集Map-style Dataset和Iterable-style Dataset,提供表示数据集的抽象类,任何自定义的Dataset都需要继承该类并覆盖相关方法
Map-style Dataset
- 需要继承
torch.utils.data.Dataset - 需要重写两个方法
__getitem__(self, index)__len__(self)
- 本质上构建了index到data的映射,
dataset[idx]返回数据集中第idx个item- idx可以不是int类型
len(dataset)返回数据集的大小
lterable-style Dataset
- 需要继承
torch.utils.data.IterableDataset - 需要重写一个方法
__iter__(self) - 本质上是一个可迭代对象,通过
next(dataset)调用__iter__(self)方法返回数据集的下一个item
自定义的Dataset如下:
from torch.utils.data import Dataset
class Dataset(Dataset):
def __init__(self):
...
def __getitem__(self):
...
def __len__(self):
return ...
__getitem__(self)- 是最主要的方法,规定了如何读取数据
- python built-in方法,主要作用是让该类可以像list一样通过索引值对数据进行访问
- 如果
__getitem__(self)方法每次读数据不仅仅返回img, label则需要自定义colloate_fn来对应合并成一个batch数据
Sampler
Sampler本质上是迭代器,用于产生数据集的索引值序列
查看DataLoader的源码,如下:
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=default_collate,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
初始化参数里有两种sampler:sampler和batch_sampler,都默认为None.
sampler:生成一系列的index.
batch_sampler:将sampler生成的indices打包分组,得到一个又一个batch的index.
DataLoader的部分初始化参数之间存在互斥关系:
- 如果自定义了
batch_sampler,那么这些参数必须使用默认值:batch_size,shuffle,sampler
drop_last
-
如果自定义了
sampler,那么shuffle需要设置为False -
如果
sampler和batch_sampler都为None,则batch_sampler使用Pytorch已经定义好的BatchSampler(BatchSampler 将 Sampler 采样得到的索引值进行合并,当数量等于一个 batch 大小后就将这一批的索引值返回),而sampler分两种情况:if shuffle = true,则sampler=RandomSampler(dataset)if shuffle = False,则sampler=SequentialSampler(dataset)
DataLoder
DataLoader有两种模式Automatic batching 和 Disable automatic batching
- 当
batch_size和drop_last均为None的时候,使用Disable automatic batching模式 - 否则使用Automatic batching
Automatic batching 的处理逻辑可以简化为:
sampler采样datasetbatch_sampler依次将sampler采样得到的indices进行合并,当数量等于batch_size时,将这个batch的indices返回.drop_last决定是否丢弃最后不足一个batch的部分.- DataLoder依次按照
batch_sampler提供的batch indices将数据从dataset中读出,传给collate_fn进行整理,返回Tensor
# map-style dataset
for indices in batch_sampler:
yield collate_fn([dataset[i] for i in indices])
# iterable-style dataset
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
Disable automatic batching 的处理逻辑可以简化为:
sampler采样dataset- DataLoder依次按照
batch_sampler提供的batch indices将数据从dataset中读出,传给collate_fn进行整理,返回Tensor
# map-style dataset
for index in sampler:
yield collate_fn(dataset[index])
# iterable-style dataset
for data in iter(dataset):
yield collate_fn(data)