## torch.utils.data ```py class torch.utils.data.Dataset ``` 表示数据集的抽象类。 所有其他数据集都应该进行子类化。所有子类应该覆盖`__len__`和`__getitem__`,`__len__`提供了数据集的大小,`__getitem__`支持整数索引,范围从 0 到 len(self)。 ```py class torch.utils.data.TensorDataset(data_tensor, target_tensor) ``` 包装数据和目标张量的数据集。 通过沿着第一个维度索引两个张量来恢复每个样本。 参数: 1. data_tensor (Tensor) - 包含样本数据 2. target_tensor (Tensor) - 包含样本目标(标签) 例子: ```py x = torch.linspace(1, 10, 10) # x data (torch tensor) y = torch.linspace(10, 1, 10) # y data (torch tensor) # 先转换成 torch 能识别的 Dataset torch_dataset = torch.utils.data.TensorDataset(data_tensor=x, target_tensor=y) ``` ```py class torch.utils.data.ConcatDataset(datasets) ``` 连接多个数据集。 目的:组合不同的现有数据集,可能是大规模数据集,因为连续操作是随意连接的。 * datasets 的参数:要连接的数据集列表 * datasets 样式:iterable ```py class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False) ``` 数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。 参数: 1. dataset (Dataset) – 从中加载数据的数据集。 2. batch_size (int, optional) – 批训练的数据个数(默认: 1)。 3. shuffle (bool, optional) – 设置为 True 在每个 epoch 重新排列数据(默认值:False,一般打乱比较好)。 4. sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略 shuffle 参数。 5. batch_sampler(sampler,可选) - 和 sampler 一样,但一次返回一批索引。与 batch_size,shuffle,sampler 和 drop_last 相互排斥。 6. num_workers (int, 可选) – 用于数据加载的子进程数。0 表示数据将在主进程中加载(默认值:0) 7. collate_fn (callable, optional) – 合并样本列表以形成小批量。 8. pin_memory (bool, optional) – 如果为 True,数据加载器在返回前将张量复制到 CUDA 固定内存中。 9. drop_last (bool, optional) – 如果数据集大小不能被 batch_size 整除,设置为 True 可删除最后一个不完整的批处理。如果设为 False 并且数据集的大小不能被 batch_size 整除,则最后一个 batch 将更小。(默认: False) ```py class torch.utils.data.sampler.Sampler(data_source) ``` 所有采样器的基础类。 每个采样器子类必须提供一个`__iter__`方法,提供一种迭代数据集元素的索引的方法,以及返回迭代器长度的`__len__`方法。 ```py class torch.utils.data.sampler.SequentialSampler(data_source) ``` 始终以相同的顺序将样本元素按顺序排列。 参数: * `data_source (Dataset)` – 采样的数据集。 ```py class torch.utils.data.sampler.RandomSampler(data_source) ``` 样本元素随机排列,并没有替换。 参数: - `data_source (Dataset)` – 采样的数据集。 ```py class torch.utils.data.sampler.SubsetRandomSampler(indices) ``` 样本元素从指定的索引列表中随机抽取,并没有替换。 参数: - `indices (list)` – 索引的列表 ```py class torch.utils.data.sampler.WeightedRandomSampler(weights, num_samples, replacement=True) ``` 样本元素来自[0,..,len(weights)-1],给定概率(权重)。 参数: * `weights (list)` – 权重列表。不需要加起来为 1 * `num_samples (int)` – 要绘制的样本数 ```py class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None) ``` 将数据加载限制到数据集的子集的采样器。 在`torch.nn.parallel.DistributedDataParallel`中是特别有用的。在这种情况下,每个进程都可以作为 DataLoader 采样器传递一个 DistributedSampler 实例,并加载独占的原始数据集的子集。 > 注意 假设数据集的大小不变。 参数: * dataset - 采样的数据集。 * num_replicas(可选) - 参与分布式训练的进程数。 * rank(可选) - num_replicas 中当前进程的等级。 ### 译者署名 | 用户名 | 头像 | 职能 | 签名 | | --- | --- | --- | --- | | [Song](https://ptorch.com) | ![](img/2018033000352689884.jpeg) | 翻译 | 人生总要追求点什么 |