未验证 提交 f11de918 编写于 作者: 飞龙 提交者: GitHub

Update data.md

上级 99b6d8cc
# torch.utils.data
> 译者:[BXuan694](https://github.com/BXuan694)
```py
class torch.utils.data.Dataset
```
表示数据集的抽象类。
所有用到的数据集都必须是其子类。这些子类都必须重写以下方法:`__len__`:定义了数据集的规模;`__getitem__`:支持0到len(self)范围内的整数索引。
所有用到的数据集都必须是其子类。这些子类都必须重写以下方法:`__len__`:定义了数据集的规模;`__getitem__`:支持0到len(self)范围内的整数索引。
```py
class torch.utils.data.TensorDataset(*tensors)
......@@ -31,7 +33,7 @@ class torch.utils.data.ConcatDataset(datasets)
class torch.utils.data.Subset(dataset, indices)
```
用索引指定的数据集子集。
用索引指定的数据集子集。
参数:
......@@ -52,7 +54,7 @@ class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=
* **num_workers**[_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")_,_ _可选_) – 加载数据的子进程数量。0表示主进程加载数据(默认:`0`)。
* **collate_fn**(_可调用_ _,_ _可选_)– 归并样例列表来组成小批。
* **pin_memory**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _可选_)– 如果设置为`True`,数据加载器会在返回前将张量拷贝到CUDA锁页内存。
* **drop_last**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _可选_)– 如果数据集的大小不能不能被批大小整除,该选项设为`True`后不会把最后的残缺批作为输入;如果设置为`False`,最后一个批将会稍微小一点。(默认:`False`
* **drop_last**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")_,_ _可选_)– 如果数据集的大小不能不能被批大小整除,该选项设为`True`后不会把最后的残缺批作为输入;如果设置为`False`,最后一个批将会稍微小一点。(默认:`False`
* **timeout**(_数值_ _,_ _可选_) – 如果是正数,即为收集一个批数据的时间限制。必须非负。(默认:`0`
* **worker_init_fn**(_可调用_ _,_ _可选_)– 如果不是`None`,每个worker子进程都会使用worker id(在`[0, num_workers - 1]`内的整数)进行调用作为输入,这一过程发生在设置种子之后、加载数据之前。(默认:`None`
......@@ -98,7 +100,7 @@ class torch.utils.data.SequentialSampler(data_source)
class torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
```
随机采样元素。如果replacement不设置,则从打乱之后的数据集采样。如果replacement设置了,那么用户可以指定`num_samples`来采样。
随机采样元素。如果replacement不设置,则从打乱之后的数据集采样。如果replacement设置了,那么用户可以指定`num_samples`来采样。
参数:
......@@ -123,7 +125,7 @@ class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=T
参数:
* **weights**(_序列_) – 权重序列,不需要和为1。
* **weights**(_序列_) – 权重序列,不需要和为1。
* **num_samples** ([_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)")) – 采样数。
* **replacement** ([_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)")) – 如果是`True`,替换采样。否则不替换,即:如果某个样本索引已经采过了,那么不会继续被采。
......@@ -137,7 +139,7 @@ class torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
* **sampler**[_Sampler_](#torch.utils.data.Sampler "torch.utils.data.Sampler"))– 基采样器。
* **batch_size**[_int_](https://docs.python.org/3/library/functions.html#int "(in Python v3.7)"))– 小批的规模。
* **drop_last**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)"))– 如果设置为`True`,采样器会丢弃最后一个不够`batch_size`的小批(如果存在的话)。
* **drop_last**[_bool_](https://docs.python.org/3/library/functions.html#bool "(in Python v3.7)"))– 如果设置为`True`,采样器会丢弃最后一个不够`batch_size`的小批(如果存在的话)。
示例
......@@ -164,4 +166,4 @@ class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None
* **dataset** – 采样的数据集。
* **num_replicas**(_可选_)– 参与分布式训练的进程数。
* **rank**(_可选_)– num_replicas中当前进程的等级。
\ No newline at end of file
* **rank**(_可选_)– num_replicas中当前进程的等级。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册