96.md 32.3 KB
Newer Older
片刻小哥哥's avatar
片刻小哥哥 已提交
1 2
# torch.utils.data

3
> 原文: [https://pytorch.org/docs/1.4.0/data.html](https://pytorch.org/docs/1.4.0/data.html)
片刻小哥哥's avatar
片刻小哥哥 已提交
4 5 6

PyTorch 数据加载实用程序的核心是 [`torch.utils.data.DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 类。 它表示可在数据集上迭代的 Python,并支持

7
*   [映射式和迭代式的数据集](#dataset-types)
片刻小哥哥's avatar
片刻小哥哥 已提交
8 9 10

*   [自定义数据加载顺序](#data-loading-order-and-sampler)

11
*   [自动批次](#loading-batched-and-non-batched-data)
片刻小哥哥's avatar
片刻小哥哥 已提交
12 13 14 15 16

*   [单进程和多进程数据加载](#single-and-multi-process-data-loading)

*   [自动内存固定](#memory-pinning)

17
这些选项由 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 的构造函数参数配置,构造函数的签名如下:
片刻小哥哥's avatar
片刻小哥哥 已提交
18 19 20 21 22 23 24 25 26

```
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

```

27
以下各节详细介绍了这些参数选项的作用和用法。
片刻小哥哥's avatar
片刻小哥哥 已提交
28 29 30

## 数据集类型

31
[`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 构造函数最重要的参数是`dataset`,它指定了用于加载数据的数据集对象。PyTorch 支持两种不同类型的数据集:
片刻小哥哥's avatar
片刻小哥哥 已提交
32

33
*   [映射式数据集](#map-style-datasets)
片刻小哥哥's avatar
片刻小哥哥 已提交
34 35 36

*   [迭代式数据集](#iterable-style-datasets)

37
### 映射式数据集
片刻小哥哥's avatar
片刻小哥哥 已提交
38

39
映射式数据集是一种实现`__getitem__()``__len__()`两个协议函数的数据集,它表示从(可能是非整数)索引/键到数据样本的映射。
片刻小哥哥's avatar
片刻小哥哥 已提交
40

41
例如,当使用`dataset[idx]`访问此类数据集时,可以从磁盘上的文件夹中读取第`idx`张图像及其对应的标签。
片刻小哥哥's avatar
片刻小哥哥 已提交
42 43 44 45 46

有关更多详细信息,请参见 [`Dataset`](#torch.utils.data.Dataset "torch.utils.data.Dataset")

### 迭代式数据集

47
迭代式的数据集是 [`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset") 子类的实例,该子类实现了`__iter__()`协议函数,并表示可在数据样本上进行迭代。当随机读取代价较高或不可能实现时,以及批处理大小取决于所获取数据的情况时,使用这种数据集是相当合适的。
片刻小哥哥's avatar
片刻小哥哥 已提交
48

49
例如,当这种数据集被`iter(dataset)`调用时,它可以从数据库、远程服务器甚至日志中实时读取返回数据流。
片刻小哥哥's avatar
片刻小哥哥 已提交
50 51 52 53 54

有关更多详细信息,请参见 [`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset")

注意

55
[`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset")[多进程数据加载](#multi-process-data-loading)一起使用时。 每个工作进程都会复制相同的数据集对象,因此必须对副本进行不同的配置,以避免重复的数据。 有关如何实现此功能的信息,请参见 [`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset") 文档。
片刻小哥哥's avatar
片刻小哥哥 已提交
56 57 58

## 数据加载顺序和 [`Sampler`](#torch.utils.data.Sampler "torch.utils.data.Sampler")

59
对于[迭代式数据集](#iterable-style-datasets),数据加载顺序完全由用户定义的迭代器控制。这样可以轻松实现块读取和动态批次大小(例如,每次生成一个批次的样本)。
片刻小哥哥's avatar
片刻小哥哥 已提交
60

61
本节的其余部分主要关注[映射式数据集](#map-style-datasets)[`torch.utils.data.Sampler`](#torch.utils.data.Sampler "torch.utils.data.Sampler") 类用于指定数据加载中使用的索引/键的顺序。 它们代表数据集索引上的可迭代对象。 以一个常见情况举例,使用随机梯度下降(SGD)时,一个[`Sampler`](#torch.utils.data.Sampler "torch.utils.data.Sampler") 每次可以随机生成一列索引,或者为小批量SGD生成少量索引。
片刻小哥哥's avatar
片刻小哥哥 已提交
62

63
基于 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader")`shuffle`参数,将自动构建顺序采样或打乱的采样器。 或者,用户可以使用`sampler`参数指定一个自定义的 [`Sampler`](#torch.utils.data.Sampler "torch.utils.data.Sampler") 对象,该对象每次都会生成下一个要提取的索引/关键字。
片刻小哥哥's avatar
片刻小哥哥 已提交
64

65
一次生成一个批次索引的自定义 [`Sampler`](#torch.utils.data.Sampler "torch.utils.data.Sampler") 可以赋值给`batch_sampler`参数,传递到`Dataloader`的构造函数中。 也可以通过`batch_size``drop_last`两个参数启用自动批处理。 有关更多详细信息,请参见[下一部分](#loading-batched-and-non-batched-data)
片刻小哥哥's avatar
片刻小哥哥 已提交
66 67 68

Note

69
`sampler``batch_sampler`与迭代式数据集均不兼容,因为此类数据集没有键或索引的概念。
片刻小哥哥's avatar
片刻小哥哥 已提交
70 71 72

## 加载批处理和非批处理数据

73
[`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 支持通过参数`batch_size``drop_last``batch_sampler`将各个提取的数据样本自动整理为批次数据。
片刻小哥哥's avatar
片刻小哥哥 已提交
74 75 76

### 自动批处理(默认)

77
这是最常见的情况,对应于获取一小批数据并将其整理为批处理的样本,即数据加载后生成一个张量,张量中的某一维度代表批处理数据量的数目(通常是第一维)。
片刻小哥哥's avatar
片刻小哥哥 已提交
78

79
`batch_size`(默认`1`)不是`None`时,数据加载器将生成批处理的样本,而不是单个样本。 `batch_size``drop_last`参数用于指定数据加载器如何获取一个批次的数据集的键。 对于映射式数据集,用户可以选择指定`batch_sampler`,它一次生成一个键列表。
片刻小哥哥's avatar
片刻小哥哥 已提交
80 81 82

Note

83
`batch_size``drop_last`参数本质上是用`sampler`构造`batch_sampler`。 对于映射式数据集,`sampler`由用户提供,或基于`shuffle`参数构造。 对于迭代式数据集,`sampler`是一个虚拟且无限的采样器。 有关采样器的更多详细信息,请参见[本节](#data-loading-order-and-sampler)
片刻小哥哥's avatar
片刻小哥哥 已提交
84 85 86

Note

87
当用[多进程数据加载](#multi-process-data-loading)[迭代式数据集](#iterable-style-datasets)获取数据时,`drop_last`参数的应用,会丢弃每个进程中,数据集副本最后一个非完整批次。
片刻小哥哥's avatar
片刻小哥哥 已提交
88

89
使用采样器生成索引,并获取索引对应样本列表后,该样本列表将由`collate_fn`参数传递的函数整理为批次。
片刻小哥哥's avatar
片刻小哥哥 已提交
90

91
在这种情况下,从映射式数据集加载数据大致等效于:
片刻小哥哥's avatar
片刻小哥哥 已提交
92 93 94 95 96 97 98

```
for indices in batch_sampler:
    yield collate_fn([dataset[i] for i in indices])

```

99
从迭代式数据集加载数据大致等效于:
片刻小哥哥's avatar
片刻小哥哥 已提交
100 101 102 103 104 105 106 107

```
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

```

108
自定义`collate_fn`可用于自定义排序规则,例如,将顺序数据填充至批处理的最大长度。 有关`collate_fn`的更多信息,请参见[本部分](#dataloader-collate-fn)
片刻小哥哥's avatar
片刻小哥哥 已提交
109 110 111

### 禁用自动批处理

112
在某些情况下,用户可能希望用数据集代码手动处理批处理,或仅加载单个样本。 例如,直接加载批处理数据所需代价更低(例如,从数据库中批量读取或读取连续的内存块),或者批处理大小取决于数据,或者该程序设计为只处理单个样本。 在这些情况下,最好不要使用自动批处理(其中`collate_fn`用于整理样本),而应让数据加载器直接返回`dataset`对象的每个成员。
片刻小哥哥's avatar
片刻小哥哥 已提交
113

114
`batch_size``batch_sampler`均为`None`时(其中`batch_sampler`的默认值为`None`),自动批处理被禁用。 `collate_fn`参数传递的函数对从`dataset`获得的每个样本进行处理。
片刻小哥哥's avatar
片刻小哥哥 已提交
115

116
**禁用自动批处理**时,`collate_fn`的默认函数仅将 NumPy 数组转换为 PyTorch 张量,而其他所有内容均保持不变。
片刻小哥哥's avatar
片刻小哥哥 已提交
117

118
在这种情况下,从映射式数据集中读取样本基本相当于:
片刻小哥哥's avatar
片刻小哥哥 已提交
119 120 121 122 123 124 125

```
for index in sampler:
    yield collate_fn(dataset[index])

```

126
从迭代式数据集中读取样本基本相当于:
片刻小哥哥's avatar
片刻小哥哥 已提交
127 128 129 130 131 132 133 134 135 136 137

```
for data in iter(dataset):
    yield collate_fn(data)

```

有关`collate_fn`的更多信息,请参见[本部分](#dataloader-collate-fn)

### 使用`collate_fn`

138
启用或禁用自动批处理时,`collate_fn`的用法略有不同。
片刻小哥哥's avatar
片刻小哥哥 已提交
139

140
**禁用自动批处理**时,对每个数据样本分别调用`collate_fn`,并且从数据加载器迭代器产生输出。 在这种情况下,默认的`collate_fn`仅转换 PyTorch 张量中的 NumPy 数组。
片刻小哥哥's avatar
片刻小哥哥 已提交
141

142
**启用自动批处理**时,对数据样本列表每次调用`collate_fn`,这是为了将输入样本整理为一个批次,或从数据加在迭代器中生成一个批次的数据。本节的其余部分描述了这种情况下默认`collate_fn`的行为。
片刻小哥哥's avatar
片刻小哥哥 已提交
143

144
例如,如果每个数据样本都包含一个3通道图像和一个整体类标签,即数据集的每个元素返回一个元组`(image, class_index)`,则默认的`collate_fn`将此类元组的列表整理为一个单独的元组 该元组包括批处理化的图像张量和批处理化的类标签张量。特别地,默认`collate_fn`具有以下属性:
片刻小哥哥's avatar
片刻小哥哥 已提交
145

146
*   它始终将新维度前置为批次维度。
片刻小哥哥's avatar
片刻小哥哥 已提交
147 148 149

*   它会自动将 NumPy 数组和 Python 数值转换为 PyTorch 张量。

150
*   它保留了数据结构,例如,如果每个样本都是一个字典,它将输出一个字典,该字典具有相同的键集合,但值将被批处理化为张量(或者,当值无法被转换为张量时,将被转换为列表)。`list``tuple``namedtuple`等相同,数据结构不变。
片刻小哥哥's avatar
片刻小哥哥 已提交
151

152
用户可以使用自定义的`collate_fn`来实现自定义批处理,例如,沿除第一个维度之外的其他维度进行数据整理,将序列填充至不同长度,或添加对自定义数据类型的支持。
片刻小哥哥's avatar
片刻小哥哥 已提交
153 154 155 156 157

## 单进程和多进程数据加载

默认情况下, [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 使用单进程数据加载。

158
在 Python 进程中,[全局解释器锁定(GIL)](https://wiki.python.org/moin/GlobalInterpreterLock)阻止了跨线程时Python代码真正的完全并行化。为了避免在加载数据时系统阻塞计算代码,PyTorch 提供了一个简单的开关,只需将参数`num_workers`设置为正整数,即可执行多进程数据加载。
片刻小哥哥's avatar
片刻小哥哥 已提交
159 160 161

### 单进程数据加载(默认)

162
在此模式下,获取数据的进程与 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader")初始化的进程一致。因此,数据加载时计算可能会被阻止。然而,以下情况可优先选择本模式:如进程之间用于共享数据的资源(例如共享内存,文件描述符等)有限,或者整个数据集很小,可以被完全加载到内存中。此外,单进程加载通常能够显示更多可读的错误跟踪,因此有利于代码调试。
片刻小哥哥's avatar
片刻小哥哥 已提交
163 164 165

### 多进程数据加载

166
将参数`num_workers`设置为正整数,将启动多进程数据加载,工作进程数目为参数指定值。
片刻小哥哥's avatar
片刻小哥哥 已提交
167

168
在此模式下,每次创建 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 的迭代器时(例如,当您调用`enumerate(dataloader)`时),都会创建`num_workers`工作进程。 此时,`dataset``collate_fn``worker_init_fn`被传递给每个工作进程,在这些工作进程中,它们被用来初始化和获取数据。 这意味着数据集访问及其内部IO、数据变换(包括`collate_fn`)在工作进程中运行。
片刻小哥哥's avatar
片刻小哥哥 已提交
169

170
[`torch.utils.data.get_worker_info()`](#torch.utils.data.get_worker_info "torch.utils.data.get_worker_info") 返回工作进程的各种有用信息(包括工作进程的ID,数据集副本,初始化种子等),而主进程返回`None`。 用户可以在数据集代码和/或`worker_init_fn`中使用此函数来分别配置每个数据集副本,并确定代码是否正在工作进程中运行。例如,这个函数在对数据集切片时特别有用。
片刻小哥哥's avatar
片刻小哥哥 已提交
171

172
对于映射式数据集,主过程使用`sampler`生成索引并将其发送给工作进程。 因此,任何随机打乱都是在主进程中完成的,该进程通过为加载分配索引来引导加载过程。
片刻小哥哥's avatar
片刻小哥哥 已提交
173

174
对于迭代式数据集,由于每个工作进程都获得`dataset`对象的副本,因此简易的多进程加载通常会导致数据重复。 用户可以使用 [`torch.utils.data.get_worker_info()`](#torch.utils.data.get_worker_info "torch.utils.data.get_worker_info") 和/或`worker_init_fn`独立配置每个副本。(有关如何实现此操作的信息,请参见 [`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset") 文档。)出于类似的原因,在多进程加载中,`drop_last`参数会删除每个工作进程中迭代式数据集副本的最后一个非完整批次。
片刻小哥哥's avatar
片刻小哥哥 已提交
175

176
一旦迭代结束或迭代器被垃圾回收,工作进程将关闭。
片刻小哥哥's avatar
片刻小哥哥 已提交
177 178 179

警告

180
通常不建议在多进程加载中返回 CUDA 张量,因为在使用 CUDA 和并行处理共享 CUDA 张量时,存在很多微妙之处(请参见[在并行处理中的 CUDA](notes/multiprocessing.html#multiprocessing-cuda-note))。相反,我们建议使用[自动内存固定](#memory-pinning)(即,设置`pin_memory=True`),该功能可以将数据快速传输到支持CUDA的GPU中。
片刻小哥哥's avatar
片刻小哥哥 已提交
181 182 183

#### 平台特定的行为

184
由于工作进程依赖于 Python `multiprocessing`,因此与 Unix 相比,Windows 上的工作进程启动行为有所不同。
片刻小哥哥's avatar
片刻小哥哥 已提交
185

186
*   在 Unix 上,`fork()``multiprocessing`默认的启动方法。 使用`fork()`,子进程通常可以通过克隆的地址空间,直接访问`dataset`和 Python 参数函数。
片刻小哥哥's avatar
片刻小哥哥 已提交
187

188
*   在 Windows 上,`spawn()``multiprocessing`的默认启动方法。 使用`spawn()`启动另一个解释器,该解释器将运行您的主脚本,然后运行内部的工作程序函数,该函数通过`pickle`序列化获取`dataset``collate_fn`和其他参数。
片刻小哥哥's avatar
片刻小哥哥 已提交
189

190
这种独立的序列化意味着,应采取两个步骤确保多进程数据与 Windows 兼容:
片刻小哥哥's avatar
片刻小哥哥 已提交
191

192
*   大部分主脚本代码应在`if __name__ == '__main__':`块中,以确保在启动每个工作进程时,该脚本不会再次运行(很可能会产生错误)。 可以在此处放置数据集和 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 实例创建逻辑,因为实例创建不需要在工作进程中重新执行。
片刻小哥哥's avatar
片刻小哥哥 已提交
193

194
*   确保在`__main__`检查之外将任意自定义`collate_fn``worker_init_fn``dataset`代码声明为顶级定义,以确保它们在工作进程中可用。 (这是必需的,因为这些函数被序列化为引用,而不是`bytecode`。)
片刻小哥哥's avatar
片刻小哥哥 已提交
195 196 197

#### 多进程数据加载中的随机性

198
默认情况下,每个工作进程的 PyTorch 种子将设置为`base_seed + worker_id`,其中`base_seed`是主进程使用其 RNG 生成的长整数(因此,强制使用 RNG 状态)。 但是,初始化工作进程时,可能会复制其他库的种子(例如 NumPy),导致每个工作进程返回相同的随机数。 (请参阅 FAQ 中的[本部分](notes/faq.html#dataloader-workers-random-seed)。)
片刻小哥哥's avatar
片刻小哥哥 已提交
199

200
`worker_init_fn`中,可以用 [`torch.utils.data.get_worker_info().seed`](#torch.utils.data.get_worker_info "torch.utils.data.get_worker_info")[`torch.initial_seed()`](torch.html#torch.initial_seed "torch.initial_seed") 访问每个工作进程的 PyTorch 种子集合,并在加载数据之前使用它为其他库提供种子。
片刻小哥哥's avatar
片刻小哥哥 已提交
201 202 203

## 内存固定

204
当GPU副本来自固定(页锁定)内存时,主机访问GPU副本速度会快得多。有关通常何时以及如何使用固定内存的详细信息,请参见[使用固定内存缓冲区](notes/cuda.html#cuda-memory-pinning)
片刻小哥哥's avatar
片刻小哥哥 已提交
205

206
对于数据加载,将`pin_memory=True`传递到 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 后,获取的数据张量将自动放置在固定内存中,此时数据可以更快传输到启用 CUDA 的 GPU中。
片刻小哥哥's avatar
片刻小哥哥 已提交
207

208
默认的内存固定逻辑,仅识别张量以及包含张量的映射和可迭代对象。 默认情况下,如果固定逻辑看到一个自定义类型的批处理(例如您有一个`collate_fn`,并返回自定义批处理类型),或者该批处理的每个元素都是自定义类型,则固定逻辑将无法识别它们,它将返回该批处理(或那些元素)而不固定内存。 要为自定义批处理或数据类型启用内存固定,请在自定义类型上定义`pin_memory()`方法。
片刻小哥哥's avatar
片刻小哥哥 已提交
209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248

请参见下面的示例。

例:

```
class SimpleCustomBatch:
    def __init__(self, data):
        transposed_data = list(zip(*data))
        self.inp = torch.stack(transposed_data[0], 0)
        self.tgt = torch.stack(transposed_data[1], 0)

    # custom memory pinning method on custom type
    def pin_memory(self):
        self.inp = self.inp.pin_memory()
        self.tgt = self.tgt.pin_memory()
        return self

def collate_wrapper(batch):
    return SimpleCustomBatch(batch)

inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
dataset = TensorDataset(inps, tgts)

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                    pin_memory=True)

for batch_ndx, sample in enumerate(loader):
    print(sample.inp.is_pinned())
    print(sample.tgt.is_pinned())

```

* * *

```
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)¶
```

249
数据加载器。 合并了数据集和采样器,并为给定数据集上提供迭代器。
片刻小哥哥's avatar
片刻小哥哥 已提交
250

251
[`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 支持映射式和迭代式数据集,可进行单进程或多进程数据加载,可自定义加载顺序,自动批处理(整理)和内存固定为可选参数。
片刻小哥哥's avatar
片刻小哥哥 已提交
252 253 254 255 256

有关更多详细信息,请参见 [`torch.utils.data`](#module-torch.utils.data "torch.utils.data") 文档页面。

参数

257
*   **dataset** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset"))–要从中加载数据的数据集。
片刻小哥哥's avatar
片刻小哥哥 已提交
258

259
*   **batch_size**  (_python:int_ _,_ _可选_)–每批次要加载多少个样本。(默认值:`1`
片刻小哥哥's avatar
片刻小哥哥 已提交
260

261
*   **shuffle** (_bool_ _,_ _可选_)–设置为`True`时,数据每轮会被重新打乱。(默认值:`False` )
片刻小哥哥's avatar
片刻小哥哥 已提交
262

263
*   **sampler** ([_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler") _,_ _可选_)–采样器定义了从数据集中抽取样本的策略。 如果指定了采样器,则`shuffle`必须为`False`
片刻小哥哥's avatar
片刻小哥哥 已提交
264

265
*   **batch_sampler**  ([_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler") _,_ _可选_)–类似`sampler`,但每次只返回一个批次的索引。 与`batch_size``shuffle``sampler``drop_last`互斥。
片刻小哥哥's avatar
片刻小哥哥 已提交
266

267
*   **num_workers**  (_python:int_ _,_ _可选_)–数据加载需要的子进程数目。 `0`表示将在主进程中加载​​数据。 (默认:`0`
片刻小哥哥's avatar
片刻小哥哥 已提交
268

269
*   **collat​​e_fn** (可调用的_,_ _可选_)–合并样本列表以形成一个小批次的张量。 从映射式数据集中使用批量加载时使用。
片刻小哥哥's avatar
片刻小哥哥 已提交
270

271
*   **pin_memory**  (_bool_ _,_ _可选_)–如果值为`True`,则数据加载器将把张量复制到 CUDA 固定的内存中,然后返回。 如果您的数据元素是自定义类型,或者您的`collate_fn`返回的是自定义类型的批次,请参见下面的示例。
片刻小哥哥's avatar
片刻小哥哥 已提交
272

273
*   **drop_last** (_bool_ _,_ _可选_)–当数据集大小无法被批次大小整除时,若该参数设置为`True`,则最后一个不完整的批次将被删除;设置为`False`,则最后一个批次的大小将比设定的批次大小要小。(默认:`False`
片刻小哥哥's avatar
片刻小哥哥 已提交
274

275
*   **timeout**(_数字_ _,_ _可选_)–如果为正,则表示从工作进程中收集批次的超时值。 应始终为非负数。 (默认:`0`
片刻小哥哥's avatar
片刻小哥哥 已提交
276

277
*   **worker_init_fn** (_可调用_ _,_ _可选_)–如果不是`None`,则该函数将在生成种子之后和数据加载之前,在每个工作进程中被调用,其中工作进程的id(`[0, num_workers - 1]`范围内的整数值)是它的输入。(默认:`None`
片刻小哥哥's avatar
片刻小哥哥 已提交
278 279 280

Warning

281
如果使用`spawn`启动方法,则`worker_init_fn`不能是不可序列化对象,例如 lambda 函数。 有关 PyTorch 中与并行处理有关的更多详细信息,请参见[并行处理最佳实践](notes/multiprocessing.html#multiprocessing-best-practices)
片刻小哥哥's avatar
片刻小哥哥 已提交
282 283 284

Note

285
`len(dataloader)`启发式方法基于所用采样器的长度。 当`dataset`[`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset") 时,无论多进程加载配置如何,都将返回`len(dataset)`(如果实现),因为 PyTorch 信任用户的`dataset`代码可以正确处理多进程加载,避免重复数据。 有关这两种类型的数据集以及 [`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset") 如何与[多进程数据加载](#multi-process-data-loading)交互的更多详细信息,请参见[数据集类型](#dataset-types)
片刻小哥哥's avatar
片刻小哥哥 已提交
286 287 288 289 290 291 292 293 294

* * *

```
class torch.utils.data.Dataset¶
```

表示 [`Dataset`](#torch.utils.data.Dataset "torch.utils.data.Dataset") 的抽象类。

295
所有从键映射到数据样本的数据集都应该是它的子类。所有子类都应该重写`__getitem__()`,实现键值给定,返回对应数据样本的功能。 子类还可以选择重写`__len__()`,它的预计功能为,通过许多 [`Sampler`](#torch.utils.data.Sampler "torch.utils.data.Sampler") 实现以及 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 的默认选项,返回数据集的大小。
片刻小哥哥's avatar
片刻小哥哥 已提交
296 297 298

Note

299
默认情况下, [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 构造一个索引采样器,该采样器产生整数索引。 要使`DataLoader`与具有非整数索引/键的映射式数据集一起使用,必须提供自定义采样器。
片刻小哥哥's avatar
片刻小哥哥 已提交
300 301 302 303 304 305 306

* * *

```
class torch.utils.data.IterableDataset¶
```

307
迭代式数据集。
片刻小哥哥's avatar
片刻小哥哥 已提交
308

309
所有代表可迭代数据样本的数据集都应该是它的子类。当数据来自流时,这种形式的数据集特别有用。
片刻小哥哥's avatar
片刻小哥哥 已提交
310

311
所有子类都应重写`__iter__()`,该函数将返回此数据集中的样本迭代器。
片刻小哥哥's avatar
片刻小哥哥 已提交
312

313
当子类与 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 一起使用时,数据集中的每条数据都将由 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 迭代器产生。 当`num_workers > 0`时,每个工作进程将具有数据集对象的不同副本,因此通常需要独立配置每个副本,以避免从工作进程返回重复的数据。在工作程序进程中调用[`get_worker_info()`](#torch.utils.data.get_worker_info "torch.utils.data.get_worker_info")时,该函数将返回与工作进程有关的信息。可以在数据集的`__iter__()`方法或 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader")`worker_init_fn`选项中应用`get_worker_info()`来修改每个副本的行为。
片刻小哥哥's avatar
片刻小哥哥 已提交
314

315
示例 1:在`__iter__()`中将工作负载分配给所有工作进程:
片刻小哥哥's avatar
片刻小哥哥 已提交
316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355

```
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]

>>> # Mult-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20)))
[3, 4, 5, 6]

```

356
示例 2:使用`worker_init_fn`将工作负载分配给所有工作进程:
片刻小哥哥's avatar
片刻小哥哥 已提交
357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409

```
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super(MyIterableDataset).__init__()
...         assert end > start, "this example code only works with end >= start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         return iter(range(self.start, self.end))
...
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[3, 4, 5, 6]
>>>
>>> # Directly doing multi-process loading yields duplicate data
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[3, 3, 4, 4, 5, 5, 6, 6]

>>> # Define a `worker_init_fn` that configures each dataset copy differently
>>> def worker_init_fn(worker_id):
...     worker_info = torch.utils.data.get_worker_info()
...     dataset = worker_info.dataset  # the dataset copy in this worker process
...     overall_start = dataset.start
...     overall_end = dataset.end
...     # configure the dataset to only process the split workload
...     per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
...     worker_id = worker_info.id
...     dataset.start = overall_start + worker_id * per_worker
...     dataset.end = min(dataset.start + per_worker, overall_end)
...

>>> # Mult-process loading with the custom `worker_init_fn`
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2, worker_init_fn=worker_init_fn)))
[3, 5, 4, 6]

>>> # With even more workers
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=20, worker_init_fn=worker_init_fn)))
[3, 4, 5, 6]

```

* * *

```
class torch.utils.data.TensorDataset(*tensors)¶
```

410
包装张量的数据集。
片刻小哥哥's avatar
片刻小哥哥 已提交
411

412
每个样本将沿第一维度索引张量,以完成检索。
片刻小哥哥's avatar
片刻小哥哥 已提交
413 414 415

Parameters

416
***tensors** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor"))–张量,它的尺寸的与第一维大小相同。
片刻小哥哥's avatar
片刻小哥哥 已提交
417 418 419 420 421 422 423

* * *

```
class torch.utils.data.ConcatDataset(datasets)¶
```

424
该数据集是多个数据集的串联。
片刻小哥哥's avatar
片刻小哥哥 已提交
425

426
此类对于合并不同的现有数据集很有用。
片刻小哥哥's avatar
片刻小哥哥 已提交
427 428 429

Parameters

430
**datasets**(_序列_)–要串联的数据集列表
片刻小哥哥's avatar
片刻小哥哥 已提交
431 432 433 434 435 436 437 438 439

* * *

```
class torch.utils.data.ChainDataset(datasets)¶
```

用于链接多个 [`IterableDataset`](#torch.utils.data.IterableDataset "torch.utils.data.IterableDataset") 的数据集。

440
此类对于合并不同的现有数据集流很有用。 链接操作是即时完成的,因此将大型数据集与此类连接起来使用将非常有效。
片刻小哥哥's avatar
片刻小哥哥 已提交
441 442 443

Parameters

444
**datasets**(_IterableDataset 的可迭代对象_)– 要链接在一起的数据集
片刻小哥哥's avatar
片刻小哥哥 已提交
445 446 447 448 449 450 451 452 453 454 455

* * *

```
class torch.utils.data.Subset(dataset, indices)¶
```

指定索引处的数据集子集。

Parameters

456
*   **dataset** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset"))–整个数据集
片刻小哥哥's avatar
片刻小哥哥 已提交
457

458
*   **indices**(_序列_)–在整个数据集中为子集选择的索引序列
片刻小哥哥's avatar
片刻小哥哥 已提交
459 460 461 462 463 464 465

* * *

```
torch.utils.data.get_worker_info()¶
```

466
返回当前 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 迭代器工作进程的相关信息。
片刻小哥哥's avatar
片刻小哥哥 已提交
467

468
在工作进程中调用时,此方法返回一个对象,该对象保证具有以下属性:
片刻小哥哥's avatar
片刻小哥哥 已提交
469

470
*   `id`:当前工作进程 ID。
片刻小哥哥's avatar
片刻小哥哥 已提交
471

472
*   `num_workers`:工作进程总数。
片刻小哥哥's avatar
片刻小哥哥 已提交
473

474
*   `seed`:当前工作进程的随机种子集。 该值由主进程 RNG 和工作进程 ID 确定。 有关更多详细信息,请参见 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 的文档。
片刻小哥哥's avatar
片刻小哥哥 已提交
475

476
*   `dataset`:在**当前**进程中的数据集对象副本。请注意,在不同的进程中,该数据集对象与主进程中的对象将不同。
片刻小哥哥's avatar
片刻小哥哥 已提交
477 478 479 480 481

在主进程中调用时,将返回`None`

Note

482
当此方法用作`worker_init_fn`传递给 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 使用时,它可用于对每个工作进程进行不同设置,例如,使用`worker_id`配置`dataset`对象,只读共享数据集的某一特定部分,或在数据集代码中用`seed`为其他库(例如 NumPy)做种。
片刻小哥哥's avatar
片刻小哥哥 已提交
483 484 485 486 487 488 489

* * *

```
torch.utils.data.random_split(dataset, lengths)¶
```

490
将数据集随机切片为给定长度的不重叠的新数据集。
片刻小哥哥's avatar
片刻小哥哥 已提交
491 492 493

Parameters

494
*   **dataset** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset"))–要切片的数据集
片刻小哥哥's avatar
片刻小哥哥 已提交
495

496
*   **lengths**(_序列_)–要产生的切片的长度
片刻小哥哥's avatar
片刻小哥哥 已提交
497 498 499 500 501 502 503 504 505

* * *

```
class torch.utils.data.Sampler(data_source)¶
```

所有采样器的基类。

506
每个 Sampler 子类都必须提供`__iter__()`方法和`__len__()`方法,前者提供一种对数据集元素的索引进行迭代的方法,后者返回已返回的迭代器的长度。
片刻小哥哥's avatar
片刻小哥哥 已提交
507 508 509

Note

510
[`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 并未严格要求`__len__()`方法,但涉及 [`DataLoader`](#torch.utils.data.DataLoader "torch.utils.data.DataLoader") 长度的任意计算时,该方法都可能被调用。
片刻小哥哥's avatar
片刻小哥哥 已提交
511 512 513 514 515 516 517

* * *

```
class torch.utils.data.SequentialSampler(data_source)¶
```

518
始终以相同顺序序列化采样元素。
片刻小哥哥's avatar
片刻小哥哥 已提交
519 520 521

Parameters

522
**data_source**  ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset"))–要从中采样的数据集
片刻小哥哥's avatar
片刻小哥哥 已提交
523 524 525 526 527 528 529

* * *

```
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)¶
```

530
随机采样元素。 如果参数`replacement`等于`False`,则从乱序数据集中采样。 如果`replacement`等于`True`,则用户可以指定`num_samples`抽取数据。
片刻小哥哥's avatar
片刻小哥哥 已提交
531 532 533

Parameters

534
*   **data_source** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – 用于采样的数据集
片刻小哥哥's avatar
片刻小哥哥 已提交
535

536
*   **replacement** (_bool_ )–当值为`True`时,样本抽取有替换,`True`为默认值
片刻小哥哥's avatar
片刻小哥哥 已提交
537

538
*   **num_samples**  (_python:int_ )–要抽取的样本数,默认值为`len(dataset)`。 仅当`replacement``True`时才可指定此参数。
片刻小哥哥's avatar
片刻小哥哥 已提交
539 540 541 542 543 544 545

* * *

```
class torch.utils.data.SubsetRandomSampler(indices)¶
```

546
从给定的索引列表中随机抽样元素,无替换。
片刻小哥哥's avatar
片刻小哥哥 已提交
547 548 549

Parameters

550
**indices**(_序列_)–索引序列
片刻小哥哥's avatar
片刻小哥哥 已提交
551 552 553 554 555 556 557 558 559 560 561

* * *

```
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)¶
```

以给定的概率(权重)从`[0,..,len(weights)-1]`中采样元素。

Parameters

562
*   **weights**(_序列_)–权重序列,权重序列和不必为1
片刻小哥哥's avatar
片刻小哥哥 已提交
563

564
*   **num_samples**  (_python:int_ )–要抽取的样本数
片刻小哥哥's avatar
片刻小哥哥 已提交
565

566
*   **replacement** (_bool_ )–如果值为`True`,则抽取样本时可以有替代。 否则,抽取样本时无替换,这意味着当为一行抽取样本索引时,无法为该行再次抽取样本。
片刻小哥哥's avatar
片刻小哥哥 已提交
567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583



```
>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[0, 0, 0, 1, 0]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

```

* * *

```
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)¶
```

584
包装另一个采样器以生成一个小批次的索引。
片刻小哥哥's avatar
片刻小哥哥 已提交
585 586 587

Parameters

588
*   **sampler** ([_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler"))–基本采样器。
片刻小哥哥's avatar
片刻小哥哥 已提交
589

590
*   **batch_size**  (_python:int_ )–小批次的大小。
片刻小哥哥's avatar
片刻小哥哥 已提交
591

592
*   **drop_last**  (_bool_ )–如果值为`True`,则采样器将丢弃最后一批大小小于`batch_size`的数据
片刻小哥哥's avatar
片刻小哥哥 已提交
593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619

Example

```
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

```

* * *

```
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None, shuffle=True)¶
```

将数据加载限制为数据集子集的采样器。

[`torch.nn.parallel.DistributedDataParallel`](nn.html#torch.nn.parallel.DistributedDataParallel "torch.nn.parallel.DistributedDataParallel") 结合使用时特别有用。 在这种情况下,每个进程都可以将 DistributedSampler 实例作为 DataLoader 采样器传递,并加载原始数据集的专有子集。

Note

假定数据集大小恒定。

Parameters

620
*   **dataset** –用于采样的数据集。
片刻小哥哥's avatar
片刻小哥哥 已提交
621 622 623

*   **num_replicas** (_可选_)–参与分布式训练的进程数。

624
*   **rank**(_可选_)–当前进程在 num_replicas 中的等级。
片刻小哥哥's avatar
片刻小哥哥 已提交
625

626
*   **shuffle**(_可选_)–如果值为 true(默认值),采样器将随机打乱索引