# torch.utils.data > 原文:[`pytorch.org/docs/stable/data.html`](https://pytorch.org/docs/stable/data.html)> > > 译者:[飞龙](https://github.com/wizardforcel) > > 协议:[CC BY-NC-SA 4.0](http://creativecommons.org/licenses/by-nc-sa/4.0/) PyTorch 数据加载实用程序的核心是`torch.utils.data.DataLoader`类。它表示数据集的 Python 可迭代对象,支持 + 映射风格和可迭代风格数据集, + 自定义数据加载顺序, + 自动批处理, + 单进程和多进程数据加载, + 自动内存固定。 这些选项由`DataLoader`的构造函数参数配置,其签名为: ```py DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False) ``` 下面的部分详细描述了这些选项的效果和用法。 ## 数据集类型 `DataLoader`构造函数最重要的参数是`dataset`,表示要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集: + 映射风格数据集, + 可迭代风格数据集。 ### 映射风格数据集 映射风格数据集是实现了`__getitem__()`和`__len__()`协议的数据集,表示从(可能是非整数)索引/键到数据样本的映射。 例如,这样一个数据集,当使用`dataset[idx]`访问时,可以从磁盘上的文件夹读取第`idx`个图像及其对应的标签。 更多详情请参阅`Dataset`。 ### 可迭代风格数据集 可迭代风格数据集是`IterableDataset`子类的实例,实现了`__iter__()`协议,表示数据样本的可迭代。这种类型的数据集特别适用于随机读取昂贵甚至不太可能的情况,批量大小取决于获取的数据。 例如,这样一个数据集,当调用`iter(dataset)`时,可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流。 更多详情请参阅`IterableDataset`。 注意 当使用带有多进程数据加载的`IterableDataset`时。相同的数据集对象在每个工作进程上复制,因此必须对副本进行不同配置以避免重复数据。请参阅`IterableDataset`文档以了解如何实现此目的。 ## 数据加载顺序和`Sampler` 对于可迭代风格数据集,数据加载顺序完全由用户定义的可迭代对象控制。这允许更容易实现分块读取和动态批量大小(例如,每次产生一个批量样本)。 本节剩余部分涉及映射风格数据集的情况。`torch.utils.data.Sampler`类用于指定数据加载中使用的索引/键的顺序。它们表示数据集索引的可迭代对象。例如,在随机梯度下降(SGD)的常见情况下,`Sampler`可以随机排列索引列表并逐个产生每个索引,或者为小批量 SGD 产生少量索引。 基于传递给`DataLoader`的`shuffle`参数,将自动构建顺序或随机采样器。或者,用户可以使用`sampler`参数指定自定义`Sampler`对象,该对象在每次产生下一个索引/键以获取时。 可以将一次产生一批批次索引列表的自定义`Sampler`作为`batch_sampler`参数传递。也可以通过`batch_size`和`drop_last`参数启用自动批处理。有关此内容的更多详细信息,请参见下一节。 注意 既不`sampler`也不`batch_sampler`与可迭代样式数据集兼容,因为这类数据集没有键或索引的概念。 ## 加载批量和非批量数据[](#加载批量和非批量数据“链接到此标题”) `DataLoader`支持通过参数`batch_size`,`drop_last`,`batch_sampler`和`collate_fn`(具有默认函数)自动整理单独获取的数据样本为批次。 ### 自动批处理(默认)[](#自动批处理默认“链接到此标题”) 这是最常见的情况,对应于获取一批数据并将它们整理成批量样本,即包含张量的一个维度为批量维度(通常是第一个)。 当`batch_size`(默认为`1`)不是`None`时,数据加载器产生批量样本而不是单个样本。`batch_size`和`drop_last`参数用于指定数据加载器如何获取数据集键的批次。对于映射样式数据集,用户可以选择指定`batch_sampler`,它一次产生一个键列表。 注意 `batch_size`和`drop_last`参数本质上用于从`sampler`构建`batch_sampler`。对于映射样式数据集,`sampler`由用户提供或基于`shuffle`参数构建。对于可迭代样式数据集,`sampler`是一个虚拟的无限循环。有关采样器的更多详细信息,请参见此部分。 注意 从可迭代样式数据集中获取数据时,使用多处理,`drop_last`参数会丢弃每个工作进程数据集副本的最后一个非完整批次。 使用来自采样器的索引获取样本列表后,作为`collate_fn`参数传递的函数用于将样本列表整理成批次。 在这种情况下,从映射样式数据集加载大致相当于: ```py for indices in batch_sampler: yield collate_fn([dataset[i] for i in indices]) ``` 从可迭代样式数据集加载大致相当于: ```py dataset_iter = iter(dataset) for indices in batch_sampler: yield collate_fn([next(dataset_iter) for _ in indices]) ``` 可以使用自定义的`collate_fn`来自定义整理,例如,将顺序数据填充到批次的最大长度。有关`collate_fn`的更多信息,请参见此部分。 ### 禁用自动批处理[](#禁用自动批处理“链接到此标题”) 在某些情况下,用户可能希望在数据集代码中手动处理批处理,或者仅加载单个样本。例如,直接加载批量数据可能更便宜(例如,从数据库进行批量读取或读取连续的内存块),或者批处理大小取决于数据,或者程序设计为处理单个样本。在这些情况下,最好不使用自动批处理(其中`collate_fn`用于整理样本),而是让数据加载器直接返回`dataset`对象的每个成员。 当`batch_size`和`batch_sampler`都为`None`时(`batch_sampler`的默认值已经是`None`),自动批处理被禁用。从`dataset`获取的每个样本都会使用作为`collate_fn`参数传递的函数进行处理。 **当禁用自动批处理时**,默认的`collate_fn`只是将 NumPy 数组转换为 PyTorch 张量,并保持其他所有内容不变。 在这种情况下,从映射样式数据集加载大致等同于: ```py for index in sampler: yield collate_fn(dataset[index]) ``` 从可迭代样式数据集加载大致等同于: ```py for data in iter(dataset): yield collate_fn(data) ``` 查看关于`collate_fn`的更多信息此部分。 ### 使用`collate_fn` 当启用或禁用自动批处理时,使用`collate_fn`略有不同。 **当禁用自动批处理时**,`collate_fn`会对每个单独的数据样本进行调用,并且输出会从数据加载器迭代器中产生。在这种情况下,默认的`collate_fn`只是将 NumPy 数组转换为 PyTorch 张量。 **当启用自动批处理时**,`collate_fn`会在每次调用时传入一个数据样本列表。它预期将输入样本整理成一个批次以便从数据加载器迭代器中产生。本节的其余部分描述了默认`collate_fn`的行为(`default_collate()`)。 例如,如果每个数据样本由一个 3 通道图像和一个整数类标签组成,即数据集的每个元素返回一个元组`(image, class_index)`,默认的`collate_fn`将这样的元组列表整理成一个批量图像张量和一个批量类标签张量的单个元组。特别是,默认的`collate_fn`具有以下属性: + 它总是在批处理维度之前添加一个新维度。 + 它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。 + 它保留了数据结构,例如,如果每个样本是一个字典,它会输出一个具有相同键集的字典,但批量化的张量作为值(或者如果值无法转换为张量,则为列表)。对于`list`、`tuple`、`namedtuple`等也是一样的。 用户可以使用定制的`collate_fn`来实现自定义的批处理,例如,沿着不同于第一个维度进行整理,填充不同长度的序列,或者添加对自定义数据类型的支持。 如果您发现`DataLoader`的输出维度或类型与您的期望不同,您可能需要检查您的`collate_fn`。 ## 单进程和多进程数据加载 `DataLoader`默认使用单进程数据加载。 在 Python 进程中,[全局解释器锁(GIL)](https://wiki.python.org/moin/GlobalInterpreterLock)阻止了真正将 Python 代码在线程间完全并行化。为了避免用数据加载阻塞计算代码,PyTorch 提供了一个简单的开关,通过将参数`num_workers`设置为正整数来执行多进程数据加载。 ### 单进程数据加载(默认) 在此模式下,数据获取是在初始化`DataLoader`的同一进程中完成的。因此,数据加载可能会阻塞计算。但是,当用于在进程之间共享数据的资源(例如共享内存、文件描述符)有限时,或者整个数据集很小且可以完全加载到内存中时,可能更喜欢此模式。此外,单进程加载通常显示更易读的错误跟踪,因此对于调试很有用。 ### 多进程数据加载 将参数`num_workers`设置为正整数将打开具有指定数量的加载器工作进程的多进程数据加载。 警告 经过几次迭代,加载器工作进程将消耗与父进程中从工作进程访问的所有 Python 对象相同的 CPU 内存量。如果数据集包含大量数据(例如,在数据集构建时加载了非常大的文件名列表),和/或您使用了大量工作进程(总内存使用量为`工作进程数量 * 父进程大小`),这可能会有问题。最简单的解决方法是用 Pandas、Numpy 或 PyArrow 对象等非引用计数表示替换 Python 对象。查看[问题#13246](https://github.com/pytorch/pytorch/issues/13246#issuecomment-905703662)以获取更多关于为什么会发生这种情况以及如何解决这些问题的示例代码。 在此模式下,每次创建`DataLoader`的迭代器(例如,当调用`enumerate(dataloader)`时),将创建`num_workers`个工作进程。在这一点上,`dataset`、`collate_fn`和`worker_init_fn`被传递给每个工作进程,它们用于初始化和获取数据。这意味着数据集访问以及其内部 IO、转换(包括`collate_fn`)在工作进程中运行。 `torch.utils.data.get_worker_info()`在工作进程中返回各种有用信息(包括工作进程 ID、数据集副本、初始种子等),在主进程中返回`None`。用户可以在数据集代码和/或`worker_init_fn`中使用此函数来单独配置每个数据集副本,并确定代码是否在工作进程中运行。例如,在对数据集进行分片时,这可能特别有帮助。 对于映射式数据集,主进程使用`sampler`生成索引并将其发送给工作进程。因此,任何洗牌随机化都是在主进程中完成的,主进程通过分配索引来指导加载。 对于可迭代式数据集,由于每个工作进程都会获得一个`dataset`对象的副本,天真的多进程加载通常会导致数据重复。使用`torch.utils.data.get_worker_info()`和/或`worker_init_fn`,用户可以独立配置每个副本。(请参阅`IterableDataset`文档以了解如何实现。)出于类似的原因,在多进程加载中,`drop_last`参数会删除每个工作进程的可迭代式数据集副本的最后一个非完整批次。 工作进程在迭代结束时关闭,或者当迭代器被垃圾回收时关闭。 警告 通常不建议在多进程加载中返回 CUDA 张量,因为在使用 CUDA 和在多进程中共享 CUDA 张量时存在许多微妙之处(请参阅 CUDA 在多进程中的使用)。相反,我们建议使用自动内存固定(即设置`pin_memory=True`),这样可以实现快速数据传输到支持 CUDA 的 GPU。 #### 平台特定行为 由于工作人员依赖于 Python [`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing "(在 Python v3.12 中)"),与 Unix 相比,在 Windows 上工作启动行为是不同的。 + 在 Unix 上,`fork()`是默认的[`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing "(在 Python v3.12 中)")启动方法。使用`fork()`,子工作进程通常可以直接通过克隆的地址空间访问`dataset`和 Python 参数函数。 + 在 Windows 或 MacOS 上,`spawn()`是默认的[`multiprocessing`](https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing "(在 Python v3.12 中)")启动方法。使用`spawn()`,会启动另一个解释器来运行您的主脚本,然后是接收`dataset`、`collate_fn`和其他参数的内部工作函数,通过[`pickle`](https://docs.python.org/3/library/pickle.html#module-pickle "(在 Python v3.12 中)")序列化。 这种单独的序列化意味着在使用多进程数据加载时,您应该采取两个步骤来确保与 Windows 兼容: + 将大部分主脚本代码放在`if __name__ == '__main__':`块中,以确保在启动每个工作进程时不会再次运行(很可能会生成错误)。您可以在这里放置数据集和`DataLoader`实例创建逻辑,因为它不需要在工作进程中重新执行。 + 确保任何自定义的`collate_fn`、`worker_init_fn`或`dataset`代码被声明为顶层定义,放在`__main__`检查之外。这样可以确保它们在工作进程中可用。(这是因为函数只被序列化为引用,而不是`bytecode`。) #### 多进程数据加载中的随机性 默认情况下,每个工作进程的 PyTorch 种子将设置为`base_seed + worker_id`,其中`base_seed`是由主进程使用其 RNG 生成的长整型(因此,强制消耗 RNG 状态)或指定的`generator`。然而,初始化工作进程时,其他库的种子可能会重复,导致每个工作进程返回相同的随机数。(请参阅 FAQ 中的此部分。) 在`worker_init_fn`中,您可以使用`torch.utils.data.get_worker_info().seed`或`torch.initial_seed()`访问为每个工作进程设置的 PyTorch 种子,并在数据加载之前用它来为其他库设置种子。 ## 内存固定 当从固定(页锁定)内存发起主机到 GPU 的复制时,复制速度要快得多。有关何时以及如何通常使用固定内存,请参阅使用固定内存缓冲区以获取更多详细信息。 对于数据加载,将`pin_memory=True`传递给`DataLoader`将自动将获取的数据张量放在固定内存中,从而实现更快的数据传输到支持 CUDA 的 GPU。 默认的内存固定逻辑仅识别张量和包含张量的映射和可迭代对象。默认情况下,如果固定逻辑看到一个批次是自定义类型(如果您有一个返回自定义批次类型的`collate_fn`,或者如果批次的每个元素是自定义类型),或者如果批次(或这些元素)是自定义类型,则固定逻辑将不会识别它们,并且将返回该批次(或这些元素)而不固定内存。要为自定义批次或数据类型启用内存固定,请在自定义类型上定义一个`pin_memory()`方法。 请参见下面的示例。 示例: ```py class SimpleCustomBatch: def __init__(self, data): transposed_data = list(zip(*data)) self.inp = torch.stack(transposed_data[0], 0) self.tgt = torch.stack(transposed_data[1], 0) # custom memory pinning method on custom type def pin_memory(self): self.inp = self.inp.pin_memory() self.tgt = self.tgt.pin_memory() return self def collate_wrapper(batch): return SimpleCustomBatch(batch) inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5) dataset = TensorDataset(inps, tgts) loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper, pin_memory=True) for batch_ndx, sample in enumerate(loader): print(sample.inp.is_pinned()) print(sample.tgt.is_pinned()) ``` ```py class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='') ``` 数据加载器结合了数据集和采样器,并提供了对给定数据集的可迭代对象。 `DataLoader`支持单进程或多进程加载的映射样式和可迭代样式数据集,自定义加载顺序以及可选的自动分批(整理)和内存固定。 请查看`torch.utils.data`文档页面以获取更多详细信息。 参数 + **dataset**(*Dataset*)- 要从中加载数据的数据集。 + **batch_size**(*int*,可选)- 每批加载多少个样本(默认值:`1`)。 + **shuffle**(*bool*,可选)- 设置为`True`以在每个时期重新洗牌数据(默认值:`False`)。 + **sampler**(*Sampler*或*Iterable*,可选)- 定义从数据集中抽取样本的策略。可以是任何实现了`__len__`的`Iterable`。如果指定了,必须不指定`shuffle`。 + **batch_sampler**(*Sampler*或*Iterable*,可选)- 类似于`sampler`,但一次返回一批索引。与`batch_size`、`shuffle`、`sampler`和`drop_last`互斥。 + **num_workers**(*int*,可选)- 用于数据加载的子进程数。`0`表示数据将在主进程中加载。(默认值:`0`) + **collate_fn**(可调用,可选)- 将样本列表合并为一个 Tensor(s)的小批量。在从映射样式数据集进行批量加载时使用。 + **pin_memory**(*bool*,可选)- 如果为`True`,数据加载器将在返回之前将张量复制到设备/CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的`collate_fn`返回一个自定义类型的批次,请参见下面的示例。 + **drop_last**(*bool*,可选)- 设置为`True`以丢弃最后一个不完整的批次,如果数据集大小不是批次大小的整数倍。如果为`False`且数据集大小不是批次大小的整数倍,则最后一个批次将更小。(默认值:`False`) + **timeout**(数值,可选)- 如果为正数,则为从工作进程收集一批数据的超时值。应始终为非负数。(默认值:`0`) + **worker_init_fn**(可调用,可选)- 如果不为`None`,则将在每个工作进程上调用此函数,输入为工作进程的工作 ID(一个在`[0,num_workers - 1]`中的整数),在种子和数据加载之后。 (默认值:`None`) + **multiprocessing_context**(*str*或*multiprocessing.context.BaseContext*,可选)- 如果为`None`,则将使用您操作系统的默认[多进程上下文](https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods)。(默认值:`None`) + **generator**(*torch.Generator*,可选)- 如果不为`None`,则 RandomSampler 将使用此 RNG 生成随机索引,多进程将生成工作进程的`base_seed`。(默认值:`None`) + **prefetch_factor**(*int*,可选,关键字参数)- 每个工作进程提前加载的批次数。`2`表示所有工作进程中将预取 2*num_workers 批次。 (默认值取决于 num_workers 的设置值。如果 num_workers=0,默认值为`None`。否则,如果`num_workers > 0`,默认值为`2`)。 + **persistent_workers**(*bool*,可选)- 如果为`True`,数据加载器将在数据集被消耗一次后不关闭工作进程。这允许保持工作进程的数据集实例处于活动状态。(默认值:`False`) + **pin_memory_device**(*str*,可选)- 如果`pin_memory`为`True`,则要将其`pin_memory`到的设备。 警告 如果使用`spawn`启动方法,则`worker_init_fn`不能是一个不可序列化的对象,例如 lambda 函数。有关 PyTorch 中多进程的更多详细信息,请参见多进程最佳实践。 警告 `len(dataloader)`的启发式方法基于使用的采样器的长度。当`dataset`是一个`IterableDataset`时,它将根据`len(dataset) / batch_size`估算一个值,根据`drop_last`进行适当的四舍五入,而不考虑多进程加载配置。这代表 PyTorch 可以做出的最佳猜测,因为 PyTorch 信任用户`dataset`代码正确处理多进程加载以避免重复数据。 然而,如果分片导致多个工作进程具有不完整的最后批次,这个估计仍然可能不准确,因为(1)一个否则完整的批次可以被分成多个批次,(2)当设置`drop_last`时,可以丢弃超过一个批次的样本。不幸的是,PyTorch 通常无法检测到这种情况。 有关这两种数据集类型以及`IterableDataset`如何与多进程数据加载交互的更多详细信息,请参见数据集类型。 警告 有关随机种子相关问题,请参见可重现性,以及我的数据加载器工作进程返回相同的随机数,以及多进程数据加载中的随机性注释。 ```py class torch.utils.data.Dataset(*args, **kwds) ``` 表示`Dataset`的抽象类。 所有表示从键到数据样本的映射的数据集都应该是它的子类。所有子类都应该重写`__getitem__()`,支持为给定键获取数据样本。子类还可以选择重写`__len__()`,许多`Sampler`实现和`DataLoader`的默认选项期望返回数据集的大小。子类还可以选择实现`__getitems__()`,以加快批量样本加载。此方法接受批次样本的索引列表,并返回样本列表。 注意 `DataLoader` 默认构造一个产生整数索引的索引采样器。要使其与具有非整数索引/键的映射样式数据集一起工作,必须提供自定义采样器。 ```py class torch.utils.data.IterableDataset(*args, **kwds) ``` 一个可迭代的数据集。 所有表示数据样本可迭代的数据集都应该是它的子类。当数据来自流时,这种形式的数据集特别有用。 所有子类都应该重写`__iter__()`,它将返回此数据集中样本的迭代器。 当子类与`DataLoader`一起使用时,数据集中的每个项目将从`DataLoader`迭代器中产生。当`num_workers > 0`时,每个工作进程将有数据集对象的不同副本,因此通常希望配置每个副本以避免从工作进程返回重复数据。当在工作进程中调用时,`get_worker_info()`返回有关工作进程的信息。它可以在数据集的`__iter__()`方法或`DataLoader`的`worker_init_fn`选项中使用,以修改每个副本的行为。 示例 1:在`__iter__()`中将工作负载分配给所有工作进程: ```py >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... worker_info = torch.utils.data.get_worker_info() ... if worker_info is None: # single-process data loading, return the full iterator ... iter_start = self.start ... iter_end = self.end ... else: # in a worker process ... # split workload ... per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... iter_start = self.start + worker_id * per_worker ... iter_end = min(iter_start + per_worker, self.end) ... return iter(range(iter_start, iter_end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [tensor([3]), tensor([4]), tensor([5]), tensor([6])] >>> # Mult-process loading with two worker processes >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12))) [tensor([3]), tensor([5]), tensor([4]), tensor([6])] ``` 示例 2:使用`worker_init_fn`在所有工作进程之间分配工作负载: ```py >>> class MyIterableDataset(torch.utils.data.IterableDataset): ... def __init__(self, start, end): ... super(MyIterableDataset).__init__() ... assert end > start, "this example code only works with end >= start" ... self.start = start ... self.end = end ... ... def __iter__(self): ... return iter(range(self.start, self.end)) ... >>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6]. >>> ds = MyIterableDataset(start=3, end=7) >>> # Single-process loading >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0))) [3, 4, 5, 6] >>> >>> # Directly doing multi-process loading yields duplicate data >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2))) [3, 3, 4, 4, 5, 5, 6, 6] >>> # Define a `worker_init_fn` that configures each dataset copy differently >>> def worker_init_fn(worker_id): ... worker_info = torch.utils.data.get_worker_info() ... dataset = worker_info.dataset # the dataset copy in this worker process ... overall_start = dataset.start ... overall_end = dataset.end ... # configure the dataset to only process the split workload ... per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers))) ... worker_id = worker_info.id ... dataset.start = overall_start + worker_id * per_worker ... dataset.end = min(dataset.start + per_worker, overall_end) ... >>> # Mult-process loading with the custom `worker_init_fn` >>> # Worker 0 fetched [3, 4]. Worker 1 fetched [5, 6]. >>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn))) [3, 5, 4, 6] >>> # With even more workers >>> print(list(torch.utils.data.DataLoader(ds, num_workers=12, worker_init_fn=worker_init_fn))) [3, 4, 5, 6] ``` ```py class torch.utils.data.TensorDataset(*tensors) ``` 包装张量的数据集。 每个样本将通过沿第一维度索引张量来检索。 参数 ***tensors** (*Tensor*) - 具有相同第一维度大小的张量。 ```py class torch.utils.data.StackDataset(*args, **kwargs) ``` 数据集作为多个数据集的堆叠。 这个类对于组装给定为数据集的复杂输入数据的不同部分很有用。 示例 ```py >>> images = ImageDataset() >>> texts = TextDataset() >>> tuple_stack = StackDataset(images, texts) >>> tuple_stack[0] == (images[0], texts[0]) >>> dict_stack = StackDataset(image=images, text=texts) >>> dict_stack[0] == {'image': images[0], 'text': texts[0]} ``` 参数 + ***args** (*Dataset*) - 作为元组返回的堆叠数据集。 + ****kwargs** (*Dataset*) - 作为字典返回的堆叠数据集。 ```py class torch.utils.data.ConcatDataset(datasets) ``` 数据集作为多个数据集的串联。 这个类对于组装不同的现有数据集很有用。 参数 **datasets** (*序列*) - 要连接的数据集列表 ```py class torch.utils.data.ChainDataset(datasets) ``` 用于链接多个`IterableDataset`的数据集。 这个类对于组装不同的现有数据集流很有用。链式操作是即时进行的,因此使用这个类连接大规模数据集将是高效的。 参数 **datasets** (*可迭代的* *IterableDataset*) - 要链接在一起的数据集 ```py class torch.utils.data.Subset(dataset, indices) ``` 指定索引处数据集的子集。 参数 + **dataset** (*Dataset*) - 整个数据集 + **indices** (*序列*) - 选择子集的整个集合中的索引 ```py torch.utils.data._utils.collate.collate(batch, *, collate_fn_map=None) ``` 处理每个批次中元素的集合类型的一般整理函数。 该函数还打开函数注册表以处理特定元素类型。default_collate_fn_map 为张量、numpy 数组、数字和字符串提供默认整理函数。 参数 + **batch** - 要整理的单个批次 + **collate_fn_map** ([*可选*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12 中)")*[*[*字典*](https://docs.python.org/3/library/typing.html#typing.Dict "(在 Python v3.12 中)")*[*[*联合*](https://docs.python.org/3/library/typing.html#typing.Union "(在 Python v3.12 中)")*[*[*类型*](https://docs.python.org/3/library/typing.html#typing.Type "(在 Python v3.12 中)")*,* [*元组*](https://docs.python.org/3/library/typing.html#typing.Tuple "(在 Python v3.12 中)")*[*[*类型*](https://docs.python.org/3/library/typing.html#typing.Type "(在 Python v3.12 中)")*,* *...**]**]**,* [*可调用*](https://docs.python.org/3/library/typing.html#typing.Callable "(在 Python v3.12 中)")*]**]*) – 可选字典,将元素类型映射到相应的聚合函数。如果元素类型不在此字典中,则此函数将按插入顺序遍历字典的每个键,以调用相应的聚合函数,如果元素类型是键的子类。 示例 ```py >>> def collate_tensor_fn(batch, *, collate_fn_map): >>> # Extend this function to handle batch of tensors ... return torch.stack(batch, 0) >>> def custom_collate(batch): ... collate_map = {torch.Tensor: collate_tensor_fn} ... return collate(batch, collate_fn_map=collate_map) >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map` >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn}) ``` 注意 每个聚合函数都需要一个用于批处理的位置参数和一个用于聚合函数字典的关键字参数作为 collate_fn_map。 ```py torch.utils.data.default_collate(batch) ``` 接收一批数据并将批次内的元素放入一个具有额外外部维度(批量大小)的张量中。 确切的输出类型可以是 `torch.Tensor`、`torch.Tensor` 序列、`torch.Tensor` 集合,或保持不变,具体取决于输入类型。当在 `DataLoader` 中定义了 batch_size 或 batch_sampler 时,这将用作默认的聚合函数。 这是一般输入类型(基于批次内元素类型)到输出类型的映射: > + `torch.Tensor` -> `torch.Tensor`(带有额外的外部维度批量大小) > + > + NumPy 数组 -> `torch.Tensor` > + > + 浮点数 -> `torch.Tensor` > + > + 整数 -> `torch.Tensor` > + > + 字符串 -> 字符串(保持不变) > + > + 字节 -> 字节(保持不变) > + > + Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])] > + > + NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …] > + > + Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …] 参数 **batch** – 要聚合的单个批次 示例 ```py >>> # Example with a batch of `int`s: >>> default_collate([0, 1, 2, 3]) tensor([0, 1, 2, 3]) >>> # Example with a batch of `str`s: >>> default_collate(['a', 'b', 'c']) ['a', 'b', 'c'] >>> # Example with `Map` inside the batch: >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}]) {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])} >>> # Example with `NamedTuple` inside the batch: >>> Point = namedtuple('Point', ['x', 'y']) >>> default_collate([Point(0, 0), Point(1, 1)]) Point(x=tensor([0, 1]), y=tensor([0, 1])) >>> # Example with `Tuple` inside the batch: >>> default_collate([(0, 1), (2, 3)]) [tensor([0, 2]), tensor([1, 3])] >>> # Example with `List` inside the batch: >>> default_collate([[0, 1], [2, 3]]) [tensor([0, 2]), tensor([1, 3])] >>> # Two options to extend `default_collate` to handle specific type >>> # Option 1: Write custom collate function and invoke `default_collate` >>> def custom_collate(batch): ... elem = batch[0] ... if isinstance(elem, CustomType): # Some custom condition ... return ... ... else: # Fall back to `default_collate` ... return default_collate(batch) >>> # Option 2: In-place modify `default_collate_fn_map` >>> def collate_customtype_fn(batch, *, collate_fn_map=None): ... return ... >>> default_collate_fn_map.update(CustoType, collate_customtype_fn) >>> default_collate(batch) # Handle `CustomType` automatically ``` ```py torch.utils.data.default_convert(data) ``` 将每个 NumPy 数组元素转换为 `torch.Tensor`。 如果输入是序列、集合或映射,则尝试将其中的每个元素转换为 `torch.Tensor`。如果输入不是 NumPy 数组,则保持不变。当 `DataLoader` 中未定义 batch_sampler 和 batch_size 时,这将用作默认的聚合函数。 一般的输入类型到输出类型的映射类似于 `default_collate()`。有关更多详细信息,请参阅那里的描述。 参数 **data** – 要转换的单个数据点 示例 ```py >>> # Example with `int` >>> default_convert(0) 0 >>> # Example with NumPy array >>> default_convert(np.array([0, 1])) tensor([0, 1]) >>> # Example with NamedTuple >>> Point = namedtuple('Point', ['x', 'y']) >>> default_convert(Point(0, 0)) Point(x=0, y=0) >>> default_convert(Point(np.array(0), np.array(0))) Point(x=tensor(0), y=tensor(0)) >>> # Example with List >>> default_convert([np.array([0, 1]), np.array([2, 3])]) [tensor([0, 1]), tensor([2, 3])] ``` ```py torch.utils.data.get_worker_info() ``` 返回有关当前 `DataLoader` 迭代器工作进程的信息。 在工作进程中调用时,此函数返回一个对象,保证具有以下属性: + `id`: 当前工作进程 ID。 + `num_workers`: 总工作人员数量。 + `seed`: 为当前工作进程设置的随机种子。此值由主进程 RNG 和工作进程 ID 确定。有关更多详细信息,请参阅 `DataLoader` 的文档。 + `dataset`:**此**进程中数据集对象的副本。请注意,这将与主进程中的对象不同。 在主进程中调用时,返回`None`。 注意 当在传递给`DataLoader`的`worker_init_fn`中使用时,此方法可用于设置每个工作进程的不同设置,例如,使用`worker_id`来配置`dataset`对象,以仅读取分片数据集的特定部分,或使用`seed`来为数据集代码中使用的其他库设置种子。 返回类型 [*可选*](https://docs.python.org/3/library/typing.html#typing.Optional "(在 Python v3.12 中)")[*WorkerInfo*] ```py torch.utils.data.random_split(dataset, lengths, generator=) ``` 将数据集随机拆分为给定长度的不重叠新数据集。 如果给定一组总和为 1 的分数列表,则长度将自动计算为每个提供的分数的 floor(frac * len(dataset))。 计算长度后,如果有任何余数,将一个计数以循环方式分配给长度,直到没有余数为止。 可选择地固定生成器以获得可重复的结果,例如: 示例 ```py >>> generator1 = torch.Generator().manual_seed(42) >>> generator2 = torch.Generator().manual_seed(42) >>> random_split(range(10), [3, 7], generator=generator1) >>> random_split(range(30), [0.3, 0.3, 0.4], generator=generator2) ``` 参数 + **dataset** (*数据集*) – 要拆分的数据集 + **lengths** (*序列*) – 要生成的拆分长度或分数 + **generator** (*生成器*) – 用于随机排列的生成器。 返回类型 [*列表*](https://docs.python.org/3/library/typing.html#typing.List "(在 Python v3.12 中)")[*子集*[*T*]] ```py class torch.utils.data.Sampler(data_source=None) ``` 所有采样器的基类。 每个采样器子类都必须提供一个`__iter__()`方法,提供一种迭代数据集元素索引或索引列表(批次)的方法,并且必须提供一个`__len__()`方法,返回返回的迭代器的长度。 参数 **data_source** (*数据集*) – 此参数未被使用,将在 2.2.0 中移除。您仍然可以有使用它的自定义实现。 示例 ```py >>> class AccedingSequenceLengthSampler(Sampler[int]): >>> def __init__(self, data: List[str]) -> None: >>> self.data = data >>> >>> def __len__(self) -> int: >>> return len(self.data) >>> >>> def __iter__(self) -> Iterator[int]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> yield from torch.argsort(sizes).tolist() >>> >>> class AccedingSequenceLengthBatchSampler(Sampler[List[int]]): >>> def __init__(self, data: List[str], batch_size: int) -> None: >>> self.data = data >>> self.batch_size = batch_size >>> >>> def __len__(self) -> int: >>> return (len(self.data) + self.batch_size - 1) // self.batch_size >>> >>> def __iter__(self) -> Iterator[List[int]]: >>> sizes = torch.tensor([len(x) for x in self.data]) >>> for batch in torch.chunk(torch.argsort(sizes), len(self)): >>> yield batch.tolist() ``` 注意 `__len__()`方法在`DataLoader`中并非严格要求,但在涉及计算`DataLoader`长度的任何计算中都是预期的。 ```py class torch.utils.data.SequentialSampler(data_source) ``` 按顺序抽取元素,始终按相同顺序。 参数 **data_source** (*数据集*) – 要从中抽样的数据集 ```py class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None) ``` 随机抽取元素。如果不进行替换,则从打乱的数据集中抽样。 如果进行替换,则用户可以指定`num_samples`来抽取。 参数 + **data_source** (*数据集*) – 要从中抽样的数据集 + **replacement** ([*布尔值*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")) – 如果为`True`,则使用替换方式按需抽取样本,默认为``False`` + **num_samples** ([*整数*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")) – 要抽取的样本数,默认为`len(dataset)`。 + **generator** (*生成器*) – 用于采样的生成器。 ```py class torch.utils.data.SubsetRandomSampler(indices, generator=None) ``` 从给定的索引列表中随机抽取元素,不进行替换。 参数 + **indices** (*序列*) – 一系列索引 + **generator** (*生成器*) – 用于采样的生成器。 ```py class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None) ``` 根据给定的概率(权重)从`[0,..,len(weights)-1]`中抽取元素。 参数 + **weights** (*序列*) – 一系列权重,不一定总和为一 + **num_samples** ([*整数*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")) – 要抽取的样本数 + **replacement** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")) – 如果为`True`,则使用替换抽取样本。如果不是,则不使用替换,这意味着当为一行抽取样本索引时,不能再次为该行抽取。 + **generator** (*Generator*) – 用于采样的生成器。 示例 ```py >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True)) [4, 4, 1, 4, 5] >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) [0, 1, 4, 3, 2] ``` ```py class torch.utils.data.BatchSampler(sampler, batch_size, drop_last) ``` 包装另一个采样器以产生索引的小批量。 参数 + **sampler** (*Sampler* *或* *Iterable*) – 基本采样器。可以是任何可迭代对象 + **batch_size** ([*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")*) – 小批量的大小。 + **drop_last** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*) – 如果为`True`,采样器将删除最后一个批次,如果其大小小于`batch_size` 示例 ```py >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) [[0, 1, 2], [3, 4, 5], [6, 7, 8]] ``` ```py class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True, seed=0, drop_last=False) ``` 限制数据加载到数据集的子集的采样器。 与`torch.nn.parallel.DistributedDataParallel`结合使用特别有用。在这种情况下,每个进程可以将`DistributedSampler`实例作为`DataLoader`采样器传递,并加载原始数据集的专属子集。 注意 假定数据集大小恒定,并且它的任何实例始终以相同的顺序返回相同的元素。 参数 + **dataset** – 用于采样的数据集。 + **num_replicas** ([*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")*,* *可选*) – 参与分布式训练的进程数量。默认情况下,`world_size`是从当前分布式组中检索的。 + **rank** ([*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")*,* *可选*) – `num_replicas`中当前进程的等级。默认情况下,`rank`是从当前分布式组中检索的。 + **shuffle** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*,* *可选*) – 如果为`True`(默认),采样器将对索引进行洗牌。 + **seed** ([*int*](https://docs.python.org/3/library/functions.html#int "(在 Python v3.12 中)")*,* *可选*) – 用于对采样器进行洗牌的随机种子,如果`shuffle=True`。这个数字应该在分布式组中的所有进程中是相同的。默认值:`0`。 + **drop_last** ([*bool*](https://docs.python.org/3/library/functions.html#bool "(在 Python v3.12 中)")*,* *可选*) – 如果为`True`,则采样器将删除数据的尾部,使其在副本数量上均匀可分。如果为`False`,则采样器将添加额外的索引,使数据在副本上均匀可分。默认值:`False`。 警告 在分布式模式下,在每个时期开始时调用`set_epoch()`方法 **之前** 创建`DataLoader`迭代器是必要的,以使在多个时期中正确地进行洗牌。否则,将始终使用相同的顺序。 示例: ```py >>> sampler = DistributedSampler(dataset) if is_distributed else None >>> loader = DataLoader(dataset, shuffle=(sampler is None), ... sampler=sampler) >>> for epoch in range(start_epoch, n_epochs): ... if is_distributed: ... sampler.set_epoch(epoch) ... train(loader) ```