data.md 8.1 KB
Newer Older
W
wizardforcel 已提交
1 2 3
# torch.utils.data

```py
W
wizardforcel 已提交
4
class torch.utils.data.Dataset
W
wizardforcel 已提交
5
```
B
BXuan694 已提交
6
表示数据集的抽象类。
W
wizardforcel 已提交
7

B
BXuan694 已提交
8
所有用到的数据集都必须是其子类。这些子类都必须重写以下方法:`__len__`:定义了数据集的规模;`__getitem__`:支持0到len(self)范围内的整数索引。
W
wizardforcel 已提交
9 10

```py
W
wizardforcel 已提交
11
class torch.utils.data.TensorDataset(*tensors)
W
wizardforcel 已提交
12 13
```

B
BXuan694 已提交
14
用于张量封装的Dataset类。
W
wizardforcel 已提交
15

B
BXuan694 已提交
16
张量可以沿第一个维度划分为样例之后进行检索。
W
wizardforcel 已提交
17

B
BXuan694 已提交
18
| 参数: | ***tensors** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – 第一个维度相同的张量。 |
W
wizardforcel 已提交
19 20 21
| --- | --- |

```py
W
wizardforcel 已提交
22
class torch.utils.data.ConcatDataset(datasets)
W
wizardforcel 已提交
23 24
```

B
BXuan694 已提交
25
用于融合不同数据集的Dataset类。目的:组合不同的现有数据集,鉴于融合操作是同时执行的,数据集规模可以很大。
W
wizardforcel 已提交
26

B
BXuan694 已提交
27
| 参数: | **datasets**(_序列_)– 要融合的数据集列表。 |
W
wizardforcel 已提交
28 29 30
| --- | --- |

```py
W
wizardforcel 已提交
31
class torch.utils.data.Subset(dataset, indices)
W
wizardforcel 已提交
32 33
```

B
BXuan694 已提交
34
用索引指定的数据集子集。
W
wizardforcel 已提交
35

B
BXuan694 已提交
36
参数: 
W
wizardforcel 已提交
37

B
BXuan694 已提交
38 39
*   **dataset**[_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset"))– 原数据集。
*   **indices**(_序列_)– 全集中选择作为子集的索引。
W
wizardforcel 已提交
40 41

```py
W
wizardforcel 已提交
42
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
W
wizardforcel 已提交
43
```
B
BXuan694 已提交
44
数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
W
wizardforcel 已提交
45

B
BXuan694 已提交
46 47 48 49 50 51 52 53 54 55 56 57
参数: 
*   **dataset**[_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – 要加载数据的数据集。
*   **batch_size**[_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")_,_ _可选_) – 每一批要加载多少数据(默认:`1`)。
*   **shuffle**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _可选_) – 如果每一个epoch内要打乱数据,就设置为`True`(默认:`False`)。
*   **sampler**[_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler")_,_ _可选_)– 定义了从数据集采数据的策略。如果这一选项指定了,`shuffle`必须是False。
*   **batch_sampler**[_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler")_,_ _可选_)– 类似于sampler,但是每次返回一批索引。和`batch_size``shuffle``sampler``drop_last`互相冲突。
*   **num_workers**[_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")_,_ _可选_) – 加载数据的子进程数量。0表示主进程加载数据(默认:`0`)。
*   **collate_fn**(_可调用_ _,_ _可选_)– 归并样例列表来组成小批。
*   **pin_memory**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _可选_)– 如果设置为`True`,数据加载器会在返回前将张量拷贝到CUDA锁页内存。
*   **drop_last**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _可选_)– 如果数据集的大小不能不能被批大小整除,该选项设为`True`后不会把最后的残缺批作为输入;如果设置为`False`,最后一个批将会稍微小一点。(默认:`False`
*   **timeout**(_数值_ _,_ _可选_) – 如果是正数,即为收集一个批数据的时间限制。必须非负。(默认:`0`
*   **worker_init_fn**(_可调用_ _,_ _可选_)– 如果不是`None`,每个worker子进程都会使用worker id(在`[0, num_workers - 1]`内的整数)进行调用作为输入,这一过程发生在设置种子之后、加载数据之前。(默认:`None`
W
wizardforcel 已提交
58 59 60



B
BXuan694 已提交
61
注意:
W
wizardforcel 已提交
62

B
BXuan694 已提交
63
默认地,每个worker都会有各自的PyTorch种子,设置方法是`base_seed + worker_id`,其中`base_seed`是主进程通过随机数生成器生成的long型数。而其它库(如NumPy)的种子可能由初始worker复制得到, 使得每一个worker返回相同的种子。(见FAQ中的[My data loader workers return identical random numbers](notes/faq.html#dataloader-workers-random-seed)部分。)你可以用[`torch.initial_seed()`](torch.html#torch.initial_seed "torch.initial_seed")查看`worker_init_fn`中每个worker的PyTorch种子,也可以在加载数据之前设置其他种子。
W
wizardforcel 已提交
64

B
BXuan694 已提交
65
警告:
W
wizardforcel 已提交
66

B
BXuan694 已提交
67
如果使用了`spawn`方法,那么`worker_init_fn`不能是不可序列化对象,如lambda函数。
W
wizardforcel 已提交
68 69

```py
W
wizardforcel 已提交
70
torch.utils.data.random_split(dataset, lengths)
W
wizardforcel 已提交
71 72
```

B
BXuan694 已提交
73
以给定的长度将数据集随机划分为不重叠的子数据集。
W
wizardforcel 已提交
74

B
BXuan694 已提交
75 76 77
参数:
*   **dataset** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – 要划分的数据集。
*   **lengths**(_序列_)– 要划分的长度。
W
wizardforcel 已提交
78

W
wizardforcel 已提交
79

W
wizardforcel 已提交
80 81

```py
W
wizardforcel 已提交
82
class torch.utils.data.Sampler(data_source)
W
wizardforcel 已提交
83 84
```

B
BXuan694 已提交
85
所有采样器的基类。
W
wizardforcel 已提交
86

B
BXuan694 已提交
87
每个Sampler子类必须提供__iter__方法,以便基于索引迭代数据集元素,同时__len__方法可以返回数据集大小。
W
wizardforcel 已提交
88 89

```py
W
wizardforcel 已提交
90
class torch.utils.data.SequentialSampler(data_source)
W
wizardforcel 已提交
91
```
B
BXuan694 已提交
92
以相同的顺序依次采样。
W
wizardforcel 已提交
93

B
BXuan694 已提交
94
| 参数: | **data_source** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – 要从中采样的数据集。 |
W
wizardforcel 已提交
95 96 97
| --- | --- |

```py
W
wizardforcel 已提交
98
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
W
wizardforcel 已提交
99 100
```

B
BXuan694 已提交
101
随机采样元素。如果replacement不设置,则从打乱之后的数据集采样。如果replacement设置了,那么用户可以指定`num_samples`来采样。
W
wizardforcel 已提交
102

B
BXuan694 已提交
103
参数:
W
wizardforcel 已提交
104

B
BXuan694 已提交
105 106 107
*   **data_source** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – 要从中采样的数据集。
*   **num_samples** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – 采样的样本数,默认为len(dataset)。
*   **replacement** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – 如果设置为`True`,替换采样。默认False。
W
wizardforcel 已提交
108 109

```py
W
wizardforcel 已提交
110
class torch.utils.data.SubsetRandomSampler(indices)
W
wizardforcel 已提交
111 112
```

B
BXuan694 已提交
113
从给定的索引列表中采样,不替换。
W
wizardforcel 已提交
114

B
BXuan694 已提交
115
| 参数: | **indices**(_序列_)– 索引序列 |
W
wizardforcel 已提交
116 117 118
| --- | --- |

```py
W
wizardforcel 已提交
119
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
W
wizardforcel 已提交
120 121
```

B
BXuan694 已提交
122
样本元素来自[0,..,len(weights)-1],,给定概率(权重)。
W
wizardforcel 已提交
123

B
BXuan694 已提交
124
参数:
W
wizardforcel 已提交
125

B
BXuan694 已提交
126 127 128
*   **weights**(_序列_) – 权重序列,不需要和为1。
*   **num_samples** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – 采样数。
*   **replacement** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – 如果是`True`,替换采样。否则不替换,即:如果某个样本索引已经采过了,那么不会继续被采。
W
wizardforcel 已提交
129 130

```py
W
wizardforcel 已提交
131
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
W
wizardforcel 已提交
132 133
```

B
BXuan694 已提交
134
打包采样器来获得小批。
W
wizardforcel 已提交
135

B
BXuan694 已提交
136
参数: 
W
wizardforcel 已提交
137

B
BXuan694 已提交
138 139 140
*   **sampler**[_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler"))– 基采样器。
*   **batch_size**[_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)"))– 小批的规模。
*   **drop_last**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)"))– 如果设置为`True`,采样器会丢弃最后一个不够`batch_size`的小批(如果存在的话)。
W
wizardforcel 已提交
141

B
BXuan694 已提交
142
示例
W
wizardforcel 已提交
143 144 145 146 147 148 149 150 151

```py
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

```py
W
wizardforcel 已提交
152
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None)
W
wizardforcel 已提交
153 154
```

B
BXuan694 已提交
155
将数据加载限制到数据集子集的采样器。
W
wizardforcel 已提交
156

B
BXuan694 已提交
157
[`torch.nn.parallel.DistributedDataParallel`](nn.html#torch.nn.parallel.DistributedDataParallel "torch.nn.parallel.DistributedDataParallel")同时使用时尤其有效。在这中情况下,每个进程会传递一个DistributedSampler实例作为DataLoader采样器,并加载独占的原始数据集的子集。
W
wizardforcel 已提交
158

B
BXuan694 已提交
159
注意:
W
wizardforcel 已提交
160

B
BXuan694 已提交
161
假设数据集的大小不变。
W
wizardforcel 已提交
162

B
BXuan694 已提交
163
参数: 
W
wizardforcel 已提交
164

B
BXuan694 已提交
165 166 167
*   **dataset** – 采样的数据集。
*   **num_replicas**(_可选_)– 参与分布式训练的进程数。
*   **rank**(_可选_)– num_replicas中当前进程的等级。