PyTorch 小课堂开课啦!带你解析数据处理全流程(一)

本文涉及的产品
云解析 DNS,旗舰版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
全局流量管理 GTM,标准版 1个月
简介: OK,在正式解析 PyTorch 中的 torch.utils.data 模块之前,我们需要理解一下 Python 中的迭代器(Iterator),因为在源码的 Dataset, Sampler 和 DataLoader 这三个类中都会用到包括 __len__(self),__getitem__(self) 和 __iter__(self) 的抽象类的魔法方法。

640.png

一张图带你看懂全文


最近被迫开始了居家办公,这不,每天认真工(mo)作(yu)之余,也有了更多时间重新学习分析起了 PyTorch 源码分享,属于是直接站在巨人的肩膀上了。在简单捋一捋思路之后,就从 torch.utils.data 数据处理模块开始,一步步重新学习 PyTorch 的一些源码模块解析,希望也能让大家重新认识已经不陌生的 PyTorch 这个小伙伴。

640.gif

1. 迭代器介绍



OK,在正式解析 PyTorch 中的 torch.utils.data 模块之前,我们需要理解一下 Python 中的迭代器(Iterator),因为在源码的 Dataset, Sampler 和 DataLoader 这三个类中都会用到包括 __len__(self),__getitem__(self) 和 __iter__(self) 的抽象类的魔法方法。


· __len__(self):定义当被 len() 函数调用时的行为,一般返回迭代器中元素的个数。


· __getitem__(self):定义获取容器中指定元素时的行为,相当于 self[key] ,即允许类对象拥有索引操作。


· __iter__(self):定义当迭代容器中的元素时的行为。


除此之外,我们也需要清楚两个概念:


· 迭代(Iteration):当我们用一个循环(比如 for 循环)来遍历容器(比如列表,元组)中的元素时,这种遍历的过程可称为迭代。


· 可迭代对象(Iterable):一般指含有 __iter__() 方法或 __getitem__() 方法的对象。我们通常接触的数据结构,如序列(列表、元组和字符串)还有字典等,都支持迭代操作,也可称为可迭代对象。


那什么是迭代器(Iterator)呢?简而言之,迭代器就是一种可以被遍历的容器类对象,但它又比较特别,它需要遵循迭代器协议,那什么又是迭代器协议呢?迭代器协议(iterator protocol)是指要实现对象的__iter()____next__() 方法。一个容器或者类如果是迭代器,那么就必须实现 __iter__() 方法以及重点实现 __next__() 方法,前者会返回一个迭代器(通常是迭代器对象本身),而后者决定了迭代的规则。现在,为更好地理解迭代器的内部运行机制,我们可以看一个斐波那契数列的迭代器实现例子:

class Fibs:
    def __init__(self, n=20):
        self.a = 0
        self.b = 1
        self.n = n
    def __iter__(self):
        return self
    def __next__(self):
        self.a, self.b = self.b, self.a + self.b
        if self.a > self.n:
            raise StopIteration
        return self.a
fibs = Fibs()
for each in fibs:
    print(each)
# 输出 
# 1 1 2 3 5 8 13

一般而言,迭代器满足以下几种特性:


· 迭代器是⼀个对象,但比较特别,需要满足迭代器协议,他还可以被 for 语句循环迭代直到终⽌。


· 迭代器可以被 next() 函数调⽤,并返回⼀个值,亦可以被 iter() 函数调⽤,但返回的是一个迭代器(可以是自身)。


· 迭代器连续被 next() 函数调⽤时,依次返回⼀系列的值,但如果到了迭代的末尾,则抛出 StopIteration 异常,另外他可以没有末尾,但只要被 next() 函数调⽤,就⼀定会返回⼀个值。


· Python3 中, next() 内置函数调⽤的是对象的 __next__() ⽅法,iter() 内置函数调⽤的是对象的 __iter__() ⽅法。


那么,了解了什么是迭代器后,我们马上开始解析 torch.utils.data 模块,对于 torch.utils.data 而言,重点是其 Dataset,Sampler,DataLoader 三个模块,辅以 collate,fetch,pin_memory 等组件对特定功能予以支持。


Tips:涉及的源码皆以 PyTorch 1.7 为准。


2. Dataset



Dataset 主要负责对 raw data source 封装,将其封装成 Python 可识别的数据结构,其必须提供提取数据个体的接口。Dataset 中共有 Map-style datasets 和 Iterable-style datasets 两种:


1.1 Map-style dataset


torch.utils.data.Dataset 它是一种通过实现  __len__()__getitem__()方法来获取数据的 Dataset,它表示从(可能是非整数)索引/关键字到数据样本的映射。因而,在我们访问 Map-style 的数据集时,使用 dataset[idx] 即可访问 idx 对应的数据。通常,我们使用 Map-style 类型的 dataset 居多,可以看到其数据接口定义如下:

class Dataset(Generic[T_co]):
    # Generic is an Abstract base class for generic types.
    def __getitem__(self, index) -> T_co:
        raise NotImplementedError
    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

在 PyTorch 1.7 源码中所有定义的 Dataset 都是其子类,而对于一般计算机视觉任务,我们通常也会在其中进行一些 resize,crop,flip 等预处理的操作。


值得一提的是,PyTorch 源码中并没有提供默认的 __len__() 方法实现,原因是 return NotImplemented 或者 raise NotImplementedError() 之类的默认实现都会存在各自的问题,这点我们在源码 pytorch/torch/utils/data/sampler.py 中的注释也可以得到解释。


1.2 Iterable-style dataset


torch.utils.data.IterableDataset 它是一种实现 __iter__() 来获取数据的 Dataset,Iterable-style 的数据集特别适用于以下情况:随机读取代价很大甚至不可能,且 batch size 取决于获取到的数据。其接口定义如下:

class IterableDataset(Dataset[T_co]):
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError
    def __add__(self, other: Dataset[T_co]):
        return ChainDataset([self, other])
    # No `def __len__(self)` default? Subclasses raise `TypeError` when needed.
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]


特别地
,当 DataLoader 的 num_workers > 0 时, 每个 worker 都将具有数据对象的不同样本。因此需要独立地对每个副本进行配置,以防止每个 worker 产生的数据不重复。同时,数据加载顺序完全由用户定义的可迭代样式控制。这允许更容易地实现块读取和动态批次大小(例如,通过每次产生一个批次的样本)。

1.3 其他 Dataset


除了 Map-style dataset 和 Iterable-style dataset 以外,PyTorch 也在此基础上提供了其他类型的 Dataset 子类:


· torch.utils.data.ConcatDataset:用于连接多个 ConcatDataset 数据集。


· torch.utils.data.ChainDataset:用于连接多个 IterableDataset 数据集,在 IterableDataset 的 __add__() 方法中被调用。


· torch.utils.data.Subset:用于获取指定一个索引序列对应的子数据集。

class Subset(Dataset[T_co]):
    dataset: Dataset[T_co]
    indices: Sequence[int]
    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices
    def __getitem__(self, idx):
        return self.dataset[self.indices[idx]]
    def __len__(self):
        return len(self.indices)

· torch.utils.data.TensorDataset:用于获取封装成 tensor 的数据集,每一个样本都可通过索引张量来获得。

class TensorDataset(Dataset):
    def __init__(self, *tensor):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in tensors
    def __len__(self):
        return self.tensors[0].size(0)


3. Sampler



torch.utils.data.Sampler 主要负责提供一种遍历数据集所有元素索引的方式。可支持我们自定义,也可以使用 PyTorch 本身提供的,其基类接口定义如下:

lass Sampler(Generic[T_co]):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
    way to iterate over indices of dataset elements, and a :meth:`__len__` method
    that returns the length of the returned iterators.
    .. note:: The :meth:`__len__` method isn't strictly required by
              :class:`~torch.utils.data.DataLoader`, but is expected in any
              calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
    """
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass
    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

特别地__len()__ 方法虽不是必要的,但是当 DataLoader 需要计算 length 的时候必须定义,这点在源码中也有注释加以体现。


同样,PyTorch 也在此基础上提供了其他类型的 Sampler 子类:


· torch.utils.data.SequentialSampler:顺序采样样本,始终按照同一个顺序。


· torch.utils.data.RandomSampler:可指定有无放回地,进行随机采样样本元素。


· torch.utils.data.SubsetRandomSampler:无放回地按照给定的索引列表采样样本元素。


· torch.utils.data.WeightedRandomSampler:按照给定的概率来采样样本。样本元素来自 [0,…,len(weights)-1] ,给定概率(权重)。


· torch.utils.data.BatchSampler:在一个 batch 中封装一个其他的采样器, 返回一个 batch 大小的 index 索引。


· torch.utils.data.DistributedSample:将数据加载限制为数据集子集的采样器。与 torch.nn.parallel.DistributedDataParallel 结合使用。在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递。


4. DataLoader


torch.utils.data.DataLoader 是 PyTorch 数据加载的核心,负责加载数据,同时支持 Map-style 和 Iterable-style Dataset,支持单进程/多进程,还可以通过参数设置如 sampler, batch size, pin memory 等自定义数据加载顺序以及控制数据批处理功能。其接口定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)

对于每个参数的含义,下面通过一个表格进行直观地介绍:

640.png

从参数定义中,我们可以看到 DataLoader 主要支持以下几个功能:


· 支持加载 map-style 和 iterable-style 的 dataset,主要涉及到的参数是 dataset。


· 自定义数据加载顺序,主要涉及到的参数有 shuffle,sampler,batch_sampler,collate_fn。


· 自动把数据整理成batch序列,主要涉及到的参数有 batch_size,batch_sampler,collate_fn,drop_last。


· 单进程和多进程的数据加载,主要涉及到的参数有 num_workers,worker_init_fn。


· 自动进行锁页内存读取 (memory pinning),主要涉及到的参数 pin_memory。


· 支持数据预加载,主要涉及的参数 prefetch_factor。


3.1 批处理

3.1.1 自动批处理(默认)


DataLoader 支持通过参数 batch_size, drop_last, batch_sampler,自动地把取出的数据整理(collate)成批次样本(batch),其中 batch_size 和 drop_last 参数用于指定 DataLoader 如何获取 dataset 的 key。特别地,对于 map-style 类型的 dataset,用户可以选择指定 batch_sample 参数,一次就生成一个 keys list。


在使用 sampler 产生的 indices 获取采样到的数据时,DataLoader 使用 collate_fn 参数将样本列表整理成 batch。抽象整个过程,其表示方式大致如下:

# For Map-style
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])
# For Iterable-style
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])


3.1.2 关闭自动批处理


当我们想用 dataset 代码手动处理 batch,或仅加载单个 sample data 时,可将 batch_size 和 batch_sampler 设为 None, 将关闭自动批处理。此时,由 Dataset 产生的 sample 将会直接被 collate_fn 处理。抽象整个过程,其表示方式大致如下:

# For Map-style
for index in sampler:
    yield collate_fn(dataset[index])
# For Iterable-style
for data in iter(dataset):
    yield collate_fn(data)


3.1.3 collate_fn


当关闭自动批处理 (automatic batching) 时,collate_fn 作用于单个数据样本,只是在 PyTorch 张量中转换 NumPy 数组。


而当开启自动批处理 (automatic batching) 时,collate_fn 作用于数据样本列表,将输入样本整理为一个 batch,一般做下面 3 件事情:


· 添加新的批次维度(一般是第一维)。


· 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。


· 它保留数据结构,例如,如果每个样本都是 dict,则输出具有相同键集但批处理过的张量作为值的字典(或 list,当数据类型不能转换的时候)。这在 list,tuples,namedtuples 同样适用。


自定义 collate_fn 可用于自定义排序规则,例如,将顺序数据填充到批处理的最大长度,添加对自定义数据类型的支持等。


5. 三者关系



通过以上解析的三者工作内容,不难可以推出其内在关系:


1)设置 Dataset,将数据 data source 包装成 Dataset 类,暴露出提取接口。


2)设置 Sampler,决定采样方式。我们虽然能从 Dataset 中提取元素了,但还是需要设置 Sampler 告诉程序提取 Dataset 的策略。


3)将设置好的 Dataset 和 Sampler 传入 DataLoader,同时可以设置 shuffle,batch_size 等参数。使用 DataLoader 对象可以方便快捷地在数据集上遍历。


至此我们就可以了解到了 Dataset,Sampler,Dataloader 三个类的基本定义以及对应实现功能,同时也介绍了批处理对应参数组件。总结来说,我们需要记得的是三点,即 Dataloader 负责总的调度,命令 Sampler 定义遍历索引的方式,然后用索引去 Dataset 中提取元素。于是就实现了对给定数据集的遍历。


文章来源:【OpenMMLab

2022-04-08 18:05

目录
相关文章
|
1月前
|
缓存 前端开发 中间件
[go 面试] 前端请求到后端API的中间件流程解析
[go 面试] 前端请求到后端API的中间件流程解析
|
1月前
|
人工智能 PyTorch 算法框架/工具
Xinference实战指南:全面解析LLM大模型部署流程,携手Dify打造高效AI应用实践案例,加速AI项目落地进程
【8月更文挑战第6天】Xinference实战指南:全面解析LLM大模型部署流程,携手Dify打造高效AI应用实践案例,加速AI项目落地进程
Xinference实战指南:全面解析LLM大模型部署流程,携手Dify打造高效AI应用实践案例,加速AI项目落地进程
手机上网流程解析
【9月更文挑战第5天】
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
PyTorch提供了几种张量乘法的方法,每种方法都是不同的,并且有不同的应用。我们来详细介绍每个方法,并且详细解释这些函数有什么区别:
34 4
掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
|
23天前
|
测试技术 持续交付 UED
|
19天前
|
持续交付 jenkins Devops
WPF与DevOps的完美邂逅:从Jenkins配置到自动化部署,全流程解析持续集成与持续交付的最佳实践
【8月更文挑战第31天】WPF与DevOps的结合开启了软件生命周期管理的新篇章。通过Jenkins等CI/CD工具,实现从代码提交到自动构建、测试及部署的全流程自动化。本文详细介绍了如何配置Jenkins来管理WPF项目的构建任务,确保每次代码提交都能触发自动化流程,提升开发效率和代码质量。这一方法不仅简化了开发流程,还加强了团队协作,是WPF开发者拥抱DevOps文化的理想指南。
39 1
|
21天前
|
机器学习/深度学习 PyTorch 数据处理
PyTorch数据处理:torch.utils.data模块的7个核心函数详解
在机器学习和深度学习项目中,数据处理是至关重要的一环。PyTorch作为一个强大的深度学习框架,提供了多种灵活且高效的数据处理工具
15 1
|
22天前
|
机器学习/深度学习 算法 PyTorch
PyTorch Lightning:简化研究到生产的工作流程
【8月更文第29天】深度学习项目往往面临着从研究阶段到生产部署的挑战。研究人员和工程师需要处理大量的工程问题,比如数据加载、模型训练、性能优化等。PyTorch Lightning 是一个轻量级的封装库,旨在通过减少样板代码的数量来简化 PyTorch 的使用,从而让开发者更专注于算法本身而不是工程细节。
43 1
|
24天前
|
缓存 运维 Linux
深入解析:一步步掌握 CentOS 7 安装全流程及运维实战技巧
深入解析:一步步掌握 CentOS 7 安装全流程及运维实战技巧
|
11天前
|
缓存 网络协议 Linux
DNS的执行流程是什么?
DNS的执行流程是什么?
23 0