未验证 提交 4c52674e 编写于 作者: P Paul 提交者: GitHub

Merge branch 'apachecn:master' into runzhi214-save_and_load_the_model

# 数据集和数据加载器
> 译者:[Daydaylight](https://github.com/Daydaylight)
>
> 项目地址:<https://pytorch.apachecn.org/2.0/tutorials/datasets_dataloaders>
>
> 原始地址:<https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>
处理数据样本的代码可能会变得杂乱无章,难以维护;我们希望我们的数据集代码与我们的模型训练代码分段,以提高可读性和模块化。
PyTorch提供了两个数据基元: ``torch.utils.data.DataLoader````torch.utils.data.Dataset``。允许你使用预先加载的数据集以及你自己的数据。
``Dataset``存储样本和它们相应的标签,``DataLoader````Dataset``基础上添加了一个迭代器,迭代器可以迭代数据集,以便能够轻松地访问``Dataset``中的样本。
PyTorch领域库提供了一些预加载的数据集(如FashionMNIST),这些数据集是``torch.utils.data.Dataset``的子类,并实现特定数据的功能。它们可以被用来为你的模型制作原型和基准。你可以找到它们这里:[Image Datasets](https://pytorch.org/vision/stable/datasets.html)[Text Datasets](https://pytorch.org/text/stable/datasets.html),和[Audio Datasets](https://pytorch.org/audio/stable/datasets.html)
##加载一个数据集
下面是一个如何从TorchVision加载[Fashion-MNIST](https://research.zalando.com/project/fashion_mnist/fashion_mnist/)数据集的例子。
Fashion-MNIST是一个由60,000个训练实例和10,000个测试实例组成的Zalando的文章图像数据集。
每个例子包括一个28×28的灰度图像和10个类别中的一个相关标签。
我们加载 [FashionMNIST Dataset](https://pytorch.org/vision/stable/datasets.html#fashion-mnist) ,参数如下:
- ``root`` 是存储训练/测试数据的路径,
- ``train`` 指定训练或测试数据集,
- ``download=True`` 如果``root``没有数据,就从网上下载数据。
- ``transform````target_transform`` 指定特征和标签的转换。
```py
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
```
## 迭代和可视化数据集
我们可以像列表一样手动索引``Datasets````training_data[index]``
我们使用``matplotlib``来可视化我们训练数据中的一些样本。
```py
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```
![fashion_mnist.png](../../img/fashion_mnist.png)
## 为你的文件创建一个自定义数据集
一个自定义的数据集类必须实现三个函数: `__init__`, `__len__`, 和 `__getitem__`
看看这个实现;FashionMNIST的图片被存储在一个`img_dir'`的目录中,而它们的标签则分别存储在一个CSV文件``annotations_file`中。
在接下来的章节中,我们将分解这些函数中的每一个发生了什么。
```py
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
```
### __init__
在实例化数据集对象时,__init__函数被运行一次。我们初始化包含图像的目录、注释文件和两种转换(下一节将详细介绍)。
标签.csv文件看起来像:
```py
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
```
```py
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
```
### __len__
函数__len__返回我们数据集中的样本数。
Example:
```py
def __len__(self):
return len(self.img_labels)
```
### __getitem__
函数 __getitem__ 从数据集中给定的索引``idx``处加载并返回一个样本。根据索引,它确定图像在磁盘上的位置,用``read_image``将其转换为张量,从``self.img_labels``的csv数据中获取相应的标签。从``self.img_labels``中的csv数据中获取相应的标签,对它们调用transform函数(如果适用),并返回张量图像和相应的标签的元组。
```py
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
```
## 用DataLoaders准备你的数据进行训练
``Dataset``每次检索一个我们数据集的特征和标签样本。在训练一个模型时,我们通常希望以 "小批量 "的方式传递样本,在每个周期重新洗牌数据以减少模型的过拟合,并使用Python的``multiprocessing``来加快数据的检索速度。
`DataLoader'是一个可迭代的,它用一个简单的API为我们抽象出这种复杂性。
```py
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
## 遍历DataLoader
我们已经将该数据集加载到``DataLoader``中,并可以根据需要迭代该数据集。下面的每次迭代都会返回一批``train_features``和``train_labels``(分别包含``batch_size=64``的特征和标签)。因为我们指定了``shuffle=True``,在我们遍历所有批次后,数据会被洗牌(为了更精细地控制数据加载顺序的精细控制,请看[Samplers](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler))。
```py
# 显示图像和标签。
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
```
输出:
![fashion_mnist2.png](../../img/fashion_mnist2.png)
```py
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 5
```
## 阅读更多
- [torch.utils.data API](https://pytorch.org/docs/stable/data.html)
# TRANSFORMS
> 译者:[Daydaylight](https://github.com/Daydaylight)
>
> 项目地址:<https://pytorch.apachecn.org/2.0/tutorials/transforms>
>
> 原始地址:<https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html>
数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用变换来对数据进行一些处理,使其适合训练。
所有的TorchVision数据集都有两个参数-``transform``用于修改特征和``target_transform``用于修改标签,它们接受包含转换逻辑的callables。[torchvision.transforms](https://pytorch.org/vision/stable/transforms.html)模块提供了几个常用的转换,开箱即用。
FashionMNIST的特征是PIL图像格式,而标签是整数。对于训练,我们需要将特征作为归一化的张量,将标签作为一热编码的张量。
为了进行这些转换,我们使用 "ToTensor "和 "Lambda"。
```py
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
```
输出:
```py
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
0%| | 0/26421880 [00:00<?, ?it/s]
0%| | 65536/26421880 [00:00<01:12, 361690.02it/s]
1%| | 229376/26421880 [00:00<00:38, 679756.53it/s]
2%|2 | 655360/26421880 [00:00<00:14, 1775435.30it/s]
7%|6 | 1736704/26421880 [00:00<00:06, 3785228.35it/s]
15%|#4 | 3833856/26421880 [00:00<00:02, 8223694.86it/s]
21%|##1 | 5570560/26421880 [00:00<00:02, 9088903.43it/s]
32%|###1 | 8454144/26421880 [00:01<00:01, 13772389.09it/s]
39%|###9 | 10420224/26421880 [00:01<00:01, 13068367.31it/s]
50%|##### | 13238272/26421880 [00:01<00:00, 16440554.97it/s]
58%|#####7 | 15269888/26421880 [00:01<00:00, 14938744.03it/s]
68%|######8 | 18055168/26421880 [00:01<00:00, 17703674.30it/s]
76%|#######6 | 20119552/26421880 [00:01<00:00, 15854480.37it/s]
87%|########6 | 22904832/26421880 [00:01<00:00, 18366169.37it/s]
95%|#########4| 25034752/26421880 [00:01<00:00, 16404116.31it/s]
100%|##########| 26421880/26421880 [00:02<00:00, 13106029.06it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
0%| | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 326257.67it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
0%| | 0/4422102 [00:00<?, ?it/s]
1%|1 | 65536/4422102 [00:00<00:12, 362747.74it/s]
5%|5 | 229376/4422102 [00:00<00:06, 681864.40it/s]
15%|#4 | 655360/4422102 [00:00<00:02, 1798436.42it/s]
40%|#### | 1769472/4422102 [00:00<00:00, 3872995.18it/s]
79%|#######9 | 3506176/4422102 [00:00<00:00, 7404355.18it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 5422111.79it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
0%| | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 35867569.75it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
```
## ToTensor()
[ToTensor](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor)将PIL图像或NumPy的``ndarray``转换为``FloatTensor``。图像的像素强度值在[0., 1.]范围内缩放。
## Lambda Transforms
Lambda transforms 应用任何用户定义的lambda函数。在这里,我们定义了一个函数来把整数变成一个单热编码的张量。
它首先创建一个大小为10(我们数据集中的标签数量)的零张量,然后调用[scatter_](https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html) ,指定了一个``value=1``在标签``y``所给的索引上。
```py
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
```
### 阅读更多
- [torchvision.transforms API](https://pytorch.org/vision/stable/transforms.html)
......@@ -42,6 +42,8 @@ nav:
- "学习基本知识": "2.0/tutorials/Introduction_to_PyTorch/learn_the_basics.md"
- "快速入门": "2.0/tutorials/Introduction_to_PyTorch/quickstart_tutorial.md"
- "Tensors介绍": "2.0/tutorials/Introduction_to_PyTorch/tensors.md"
- "数据集和数据加载器": "2.0/tutorials/Introduction_to_PyTorch/datasets_dataloaders.md"
- "Transforms介绍": "2.0/tutorials/Introduction_to_PyTorch/transforms.md"
- "构建神经网络模型": "2.0/tutorials/Introduction_to_PyTorch/build_the_neural_network.md"
- "自动微分运算 - Torch.AutoGrad": "2.0/tutorials/Introduction_to_PyTorch/automatic_differentation_with_torch_autograd.md"
- "优化模型参数": "2.0/tutorials/Introduction_to_PyTorch/optimizing_model_parameters.md"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册