

# torch.utils.data

```py
class torch.utils.data.Dataset
```

An abstract class representing a Dataset.

All other datasets should subclass it. All subclasses should override `__len__`, which provides the size of the dataset, and `__getitem__`, which supports integer indexing in the range from 0 to `len(self)` exclusive.
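
As an illustration, a minimal custom dataset might look like the following sketch (the class name and the synthetic data are hypothetical):

```py
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Hypothetical dataset returning (x, x**2) pairs."""

    def __init__(self, n):
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        # size of the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # integer indexing in the range [0, len(self))
        x = self.data[idx]
        return x, x ** 2
```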

```py
class torch.utils.data.TensorDataset(*tensors)
```

Dataset wrapping tensors.

Each sample will be retrieved by indexing tensors along the first dimension.

| Parameters: | ***tensors** ([_Tensor_](tensors.html#torch.Tensor "torch.Tensor")) – tensors that have the same size in the first dimension. |
| --- | --- |
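
A short usage sketch (the tensor shapes are illustrative):

```py
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> features = torch.randn(100, 3)          # 100 samples, 3 features each
>>> labels = torch.randint(0, 2, (100,))    # 100 binary labels
>>> ds = TensorDataset(features, labels)
>>> len(ds)
100
>>> x, y = ds[0]   # indexes both tensors along the first dimension
```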

```py
class torch.utils.data.ConcatDataset(datasets)
```

Dataset that concatenates multiple datasets. It is useful for assembling different existing datasets, possibly large-scale ones, since the concatenation is performed on the fly.

| Parameters: | **datasets** (_sequence_) – List of datasets to be concatenated |
| --- | --- |
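
For example (assuming `ds1` and `ds2` are any existing datasets):

```py
>>> from torch.utils.data import ConcatDataset
>>> combined = ConcatDataset([ds1, ds2])
>>> len(combined) == len(ds1) + len(ds2)
True
```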

```py
class torch.utils.data.Subset(dataset, indices)
```

Subset of a dataset at specified indices.

| Parameters: | 

*   **dataset** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – The whole Dataset
*   **indices** (_sequence_) – Indices in the whole set selected for subset

 |
| --- | --- |
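
A brief sketch, reusing a dataset `ds` and an arbitrary index list:

```py
>>> from torch.utils.data import Subset
>>> val_set = Subset(ds, [0, 2, 4, 6, 8])
>>> len(val_set)
5
>>> val_set[1]   # returns ds[2], the element at the second selected index
```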

```py
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
```

Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.

| Parameters: | 

*   **dataset** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – dataset from which to load the data.
*   **batch_size** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")_,_ _optional_) – how many samples per batch to load (default: `1`).
*   **shuffle** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _optional_) – set to `True` to have the data reshuffled at every epoch (default: `False`).
*   **sampler** ([_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler")_,_ _optional_) – defines the strategy to draw samples from the dataset. If specified, `shuffle` must be False.
*   **batch_sampler** ([_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler")_,_ _optional_) – like sampler, but returns a batch of indices at a time. Mutually exclusive with `batch_size`, `shuffle`, `sampler`, and `drop_last`.
*   **num_workers** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")_,_ _optional_) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: `0`)
*   **collate_fn** (_callable__,_ _optional_) – merges a list of samples to form a mini-batch.
*   **pin_memory** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _optional_) – If `True`, the data loader will copy tensors into CUDA pinned memory before returning them.
*   **drop_last** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _optional_) – set to `True` to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If `False` and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: `False`)
*   **timeout** (_numeric__,_ _optional_) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: `0`)
*   **worker_init_fn** (_callable__,_ _optional_) – If not `None`, this will be called on each worker subprocess with the worker id (an int in `[0, num_workers - 1]`) as input, after seeding and before data loading. (default: `None`)

 |
| --- | --- |
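
A typical usage sketch; `train_dataset`, the batch size, and the worker count are illustrative, and the dataset is assumed to yield `(input, target)` pairs:

```py
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                    num_workers=4, pin_memory=True, drop_last=True)

for inputs, targets in loader:   # each iteration yields one mini-batch
    # pinned memory allows asynchronous host-to-device copies
    inputs = inputs.to("cuda", non_blocking=True)
    targets = targets.to("cuda", non_blocking=True)
    ...  # forward / backward / optimizer step
```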

Note

By default, each worker will have its PyTorch seed set to `base_seed + worker_id`, where `base_seed` is a long generated by the main process using its RNG. However, seeds for other libraries (e.g., NumPy) may be duplicated upon initializing workers, causing each worker to return identical random numbers. (See the [My data loader workers return identical random numbers](notes/faq.html#dataloader-workers-random-seed) section in the FAQ.) You may use [`torch.initial_seed()`](torch.html#torch.initial_seed "torch.initial_seed") to access the PyTorch seed for each worker in `worker_init_fn`, and use it to set other seeds before data loading.
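
For instance, a `worker_init_fn` along these lines can derive a distinct NumPy seed per worker (a sketch based on the note above; `dataset` is a placeholder):

```py
import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_numpy_worker(worker_id):
    # torch.initial_seed() already differs per worker (base_seed + worker_id);
    # reuse it so each worker gets a distinct NumPy seed as well.
    np.random.seed(torch.initial_seed() % 2**32)

loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_numpy_worker)
```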

Warning

If `spawn` start method is used, `worker_init_fn` cannot be an unpicklable object, e.g., a lambda function.

```py
torch.utils.data.random_split(dataset, lengths)
```

Randomly split a dataset into non-overlapping new datasets of given lengths.

| Parameters: | 

*   **dataset** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – Dataset to be split
*   **lengths** (_sequence_) – lengths of splits to be produced

 |
| --- | --- |
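
For example, assuming `full_ds` has 100 elements (the lengths must sum to the dataset size):

```py
>>> from torch.utils.data import random_split
>>> train_ds, val_ds = random_split(full_ds, [80, 20])
>>> len(train_ds), len(val_ds)
(80, 20)
```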

```py
class torch.utils.data.Sampler(data_source)
```

Base class for all Samplers.

Every Sampler subclass has to provide an `__iter__` method, providing a way to iterate over indices of dataset elements, and a `__len__` method that returns the length of the returned iterator.
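
As an illustration, a custom sampler that iterates over a dataset in reverse order could be sketched as follows (the class is hypothetical):

```py
from torch.utils.data import Sampler

class ReverseSampler(Sampler):
    """Hypothetical sampler yielding indices from last to first."""

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # iterate over dataset indices in reverse order
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)
```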

```py
class torch.utils.data.SequentialSampler(data_source)
```

Samples elements sequentially, always in the same order.

| Parameters: | **data_source** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – dataset to sample from |
| --- | --- |

```py
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
```

Samples elements randomly. If sampling without replacement, samples are drawn from a shuffled dataset. If sampling with replacement, the user can specify `num_samples` to draw.

| Parameters: | 

*   **data_source** ([_Dataset_](#torch.utils.data.Dataset "torch.utils.data.Dataset")) – dataset to sample from
*   **num_samples** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – number of samples to draw, default=len(dataset)
*   **replacement** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – samples are drawn with replacement if `True`, default=False

 |
| --- | --- |

```py
class torch.utils.data.SubsetRandomSampler(indices)
```

Samples elements randomly from a given list of indices, without replacement.

| Parameters: | **indices** (_sequence_) – a sequence of indices |
| --- | --- |
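
A common pattern is to use two such samplers for a train/validation split (the index ranges and dataset `ds` are illustrative):

```py
>>> from torch.utils.data import DataLoader, SubsetRandomSampler
>>> train_loader = DataLoader(ds, batch_size=16, sampler=SubsetRandomSampler(range(0, 80)))
>>> val_loader = DataLoader(ds, batch_size=16, sampler=SubsetRandomSampler(range(80, 100)))
```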

```py
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
```

Samples elements from `[0, ..., len(weights) - 1]` with given probabilities (weights).

| Parameters: | 

*   **weights** (_sequence_) – a sequence of weights, not necessarily summing up to one
*   **num_samples** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – number of samples to draw
*   **replacement** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – if `True`, samples are drawn with replacement. If not, they are drawn without replacement, which means that when a sample index is drawn for a row, it cannot be drawn again for that row.

 |
| --- | --- |
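
For example, the weights can make some indices far more likely than others (the weights are illustrative, and the output is random):

```py
>>> from torch.utils.data import WeightedRandomSampler
>>> sampler = WeightedRandomSampler(weights=[0.9, 0.9, 0.05, 0.05, 0.05],
...                                 num_samples=5, replacement=True)
>>> list(sampler)   # indices 0 and 1 dominate; one possible outcome:
[1, 0, 0, 1, 0]
```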

```py
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
```

Wraps another sampler to yield a mini-batch of indices.

| Parameters: | 

*   **sampler** ([_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler")) – Base sampler.
*   **batch_size** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – Size of mini-batch.
*   **drop_last** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – If `True`, the sampler will drop the last batch if its size would be less than `batch_size`.

 |
| --- | --- |

Example

```py
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]

```

```py
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None)
```

Sampler that restricts data loading to a subset of the dataset.

It is especially useful in conjunction with [`torch.nn.parallel.DistributedDataParallel`](nn.html#torch.nn.parallel.DistributedDataParallel "torch.nn.parallel.DistributedDataParallel"). In such a case, each process can pass a DistributedSampler instance as a DataLoader sampler, and load a subset of the original dataset that is exclusive to it.

Note

Dataset is assumed to be of constant size.

| Parameters: | 

*   **dataset** – Dataset used for sampling.
*   **num_replicas** (_optional_) – Number of processes participating in distributed training.
*   **rank** (_optional_) – Rank of the current process within num_replicas.

 |
| --- | --- |
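
A typical per-process usage sketch; `dataset` and `num_epochs` are placeholders, the process group is assumed to be initialized already, and `set_epoch` is assumed to be available to vary the shuffling between epochs:

```py
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# world size and rank are taken from the default process group when omitted
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)  # shuffle must stay False

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)   # reshuffles differently each epoch
    for batch in loader:
        ...  # training step
```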