Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OpenDocCN
pytorch-doc-zh
提交
4c52674e
P
pytorch-doc-zh
项目概览
OpenDocCN
/
pytorch-doc-zh
通知
122
Star
3932
Fork
992
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
pytorch-doc-zh
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
4c52674e
编写于
6月 30, 2023
作者:
P
Paul
提交者:
GitHub
6月 30, 2023
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'apachecn:master' into runzhi214-save_and_load_the_model
上级
893c1d56
8b2ae399
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
299 addition
and
0 deletion
+299
-0
docs/2.0/img/fashion_mnist.png
docs/2.0/img/fashion_mnist.png
+0
-0
docs/2.0/img/fashion_mnist2.png
docs/2.0/img/fashion_mnist2.png
+0
-0
docs/2.0/tutorials/Introduction_to_PyTorch/datasets_dataloaders.md
...tutorials/Introduction_to_PyTorch/datasets_dataloaders.md
+202
-0
docs/2.0/tutorials/Introduction_to_PyTorch/transforms.md
docs/2.0/tutorials/Introduction_to_PyTorch/transforms.md
+95
-0
mkdocs.yml
mkdocs.yml
+2
-0
未找到文件。
docs/2.0/img/fashion_mnist.png
0 → 100644
浏览文件 @
4c52674e
30.7 KB
docs/2.0/img/fashion_mnist2.png
0 → 100644
浏览文件 @
4c52674e
9.5 KB
docs/2.0/tutorials/Introduction_to_PyTorch/datasets_dataloaders.md
0 → 100644
浏览文件 @
4c52674e
# 数据集和数据加载器
> 译者:[Daydaylight](https://github.com/Daydaylight)
>
> 项目地址:<https://pytorch.apachecn.org/2.0/tutorials/datasets_dataloaders>
>
> 原始地址:<https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>
处理数据样本的代码可能会变得杂乱无章,难以维护;我们希望我们的数据集代码与我们的模型训练代码分段,以提高可读性和模块化。
PyTorch提供了两个数据基元:
``torch.utils.data.DataLoader``
和
``torch.utils.data.Dataset``
。允许你使用预先加载的数据集以及你自己的数据。
``Dataset``
存储样本和它们相应的标签,
``DataLoader``
在
``Dataset``
基础上添加了一个迭代器,迭代器可以迭代数据集,以便能够轻松地访问
``Dataset``
中的样本。
PyTorch领域库提供了一些预加载的数据集(如FashionMNIST),这些数据集是
``torch.utils.data.Dataset``
的子类,并实现特定数据的功能。它们可以被用来为你的模型制作原型和基准。你可以找到它们这里:
[
Image Datasets
](
https://pytorch.org/vision/stable/datasets.html
)
、
[
Text Datasets
](
https://pytorch.org/text/stable/datasets.html
)
,和
[
Audio Datasets
](
https://pytorch.org/audio/stable/datasets.html
)
##加载一个数据集
下面是一个如何从TorchVision加载
[
Fashion-MNIST
](
https://research.zalando.com/project/fashion_mnist/fashion_mnist/
)
数据集的例子。
Fashion-MNIST是一个由60,000个训练实例和10,000个测试实例组成的Zalando的文章图像数据集。
每个例子包括一个28×28的灰度图像和10个类别中的一个相关标签。
我们加载
[
FashionMNIST Dataset
](
https://pytorch.org/vision/stable/datasets.html#fashion-mnist
)
,参数如下:
-
``root``
是存储训练/测试数据的路径,
-
``train``
指定训练或测试数据集,
-
``download=True``
如果
``root``
没有数据,就从网上下载数据。
-
``transform``
和
``target_transform``
指定特征和标签的转换。
```
py
import
torch
from
torch.utils.data
import
Dataset
from
torchvision
import
datasets
from
torchvision.transforms
import
ToTensor
import
matplotlib.pyplot
as
plt
training_data
=
datasets
.
FashionMNIST
(
root
=
"data"
,
train
=
True
,
download
=
True
,
transform
=
ToTensor
()
)
test_data
=
datasets
.
FashionMNIST
(
root
=
"data"
,
train
=
False
,
download
=
True
,
transform
=
ToTensor
()
)
```
## 迭代和可视化数据集
我们可以像列表一样手动索引
``Datasets``
:
``training_data[index]``
。
我们使用
``matplotlib``
来可视化我们训练数据中的一些样本。
```
py
labels_map
=
{
0
:
"T-Shirt"
,
1
:
"Trouser"
,
2
:
"Pullover"
,
3
:
"Dress"
,
4
:
"Coat"
,
5
:
"Sandal"
,
6
:
"Shirt"
,
7
:
"Sneaker"
,
8
:
"Bag"
,
9
:
"Ankle Boot"
,
}
figure
=
plt
.
figure
(
figsize
=
(
8
,
8
))
cols
,
rows
=
3
,
3
for
i
in
range
(
1
,
cols
*
rows
+
1
):
sample_idx
=
torch
.
randint
(
len
(
training_data
),
size
=
(
1
,)).
item
()
img
,
label
=
training_data
[
sample_idx
]
figure
.
add_subplot
(
rows
,
cols
,
i
)
plt
.
title
(
labels_map
[
label
])
plt
.
axis
(
"off"
)
plt
.
imshow
(
img
.
squeeze
(),
cmap
=
"gray"
)
plt
.
show
()
```
![
fashion_mnist.png
](
../../img/fashion_mnist.png
)
## 为你的文件创建一个自定义数据集
一个自定义的数据集类必须实现三个函数:
`__init__`
,
`__len__`
, 和
`__getitem__`
。
看看这个实现;FashionMNIST的图片被存储在一个
`img_dir'`
的目录中,而它们的标签则分别存储在一个CSV文件
``annotations_file`中。
在接下来的章节中,我们将分解这些函数中的每一个发生了什么。
```py
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
```
### __init__
在实例化数据集对象时,__init__函数被运行一次。我们初始化包含图像的目录、注释文件和两种转换(下一节将详细介绍)。
标签.csv文件看起来像:
```py
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
```
```py
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
```
### __len__
函数__len__返回我们数据集中的样本数。
Example:
```py
def __len__(self):
return len(self.img_labels)
```
### __getitem__
函数 __getitem__ 从数据集中给定的索引``
idx
``处加载并返回一个样本。根据索引,它确定图像在磁盘上的位置,用``
read_image
``将其转换为张量,从``
self.img_labels
``的csv数据中获取相应的标签。从``
self.img_labels
``中的csv数据中获取相应的标签,对它们调用transform函数(如果适用),并返回张量图像和相应的标签的元组。
```py
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
```
## 用DataLoaders准备你的数据进行训练
``
Dataset
``每次检索一个我们数据集的特征和标签样本。在训练一个模型时,我们通常希望以 "小批量 "的方式传递样本,在每个周期重新洗牌数据以减少模型的过拟合,并使用Python的``
multiprocessing
``来加快数据的检索速度。
`DataLoader'是一个可迭代的,它用一个简单的API为我们抽象出这种复杂性。
```py
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
## 遍历DataLoader
我们已经将该数据集加载到``
DataLoader
``中,并可以根据需要迭代该数据集。下面的每次迭代都会返回一批``
train_features
``和``
train_labels
``(分别包含``
batch_size=64
``的特征和标签)。因为我们指定了``
shuffle=True
``,在我们遍历所有批次后,数据会被洗牌(为了更精细地控制数据加载顺序的精细控制,请看[Samplers](https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler))。
```py
# 显示图像和标签。
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
```
输出:
![fashion_mnist2.png](../../img/fashion_mnist2.png)
```py
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 5
`
``
## 阅读更多
-
[
torch.utils.data API
](
https://pytorch.org/docs/stable/data.html
)
docs/2.0/tutorials/Introduction_to_PyTorch/transforms.md
0 → 100644
浏览文件 @
4c52674e
# TRANSFORMS
> 译者:[Daydaylight](https://github.com/Daydaylight)
>
> 项目地址:<https://pytorch.apachecn.org/2.0/tutorials/transforms>
>
> 原始地址:<https://pytorch.org/tutorials/beginner/basics/transforms_tutorial.html>
数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用变换来对数据进行一些处理,使其适合训练。
所有的TorchVision数据集都有两个参数-
``transform``
用于修改特征和
``target_transform``
用于修改标签,它们接受包含转换逻辑的callables。
[
torchvision.transforms
](
https://pytorch.org/vision/stable/transforms.html
)
模块提供了几个常用的转换,开箱即用。
FashionMNIST的特征是PIL图像格式,而标签是整数。对于训练,我们需要将特征作为归一化的张量,将标签作为一热编码的张量。
为了进行这些转换,我们使用 "ToTensor "和 "Lambda"。
```
py
import
torch
from
torchvision
import
datasets
from
torchvision.transforms
import
ToTensor
,
Lambda
ds
=
datasets
.
FashionMNIST
(
root
=
"data"
,
train
=
True
,
download
=
True
,
transform
=
ToTensor
(),
target_transform
=
Lambda
(
lambda
y
:
torch
.
zeros
(
10
,
dtype
=
torch
.
float
).
scatter_
(
0
,
torch
.
tensor
(
y
),
value
=
1
))
)
```
输出:
```
py
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
train
-
images
-
idx3
-
ubyte
.
gz
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
train
-
images
-
idx3
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
/
train
-
images
-
idx3
-
ubyte
.
gz
0
%|
|
0
/
26421880
[
00
:
00
<
?
,
?
it
/
s
]
0
%|
|
65536
/
26421880
[
00
:
00
<
01
:
12
,
361690.02
it
/
s
]
1
%|
|
229376
/
26421880
[
00
:
00
<
00
:
38
,
679756.53
it
/
s
]
2
%|
2
|
655360
/
26421880
[
00
:
00
<
00
:
14
,
1775435.30
it
/
s
]
7
%|
6
|
1736704
/
26421880
[
00
:
00
<
00
:
06
,
3785228.35
it
/
s
]
15
%|
#4 | 3833856/26421880 [00:00<00:02, 8223694.86it/s]
21
%|
##1 | 5570560/26421880 [00:00<00:02, 9088903.43it/s]
32
%|
###1 | 8454144/26421880 [00:01<00:01, 13772389.09it/s]
39
%|
###9 | 10420224/26421880 [00:01<00:01, 13068367.31it/s]
50
%|
##### | 13238272/26421880 [00:01<00:00, 16440554.97it/s]
58
%|
#####7 | 15269888/26421880 [00:01<00:00, 14938744.03it/s]
68
%|
######8 | 18055168/26421880 [00:01<00:00, 17703674.30it/s]
76
%|
#######6 | 20119552/26421880 [00:01<00:00, 15854480.37it/s]
87
%|
########6 | 22904832/26421880 [00:01<00:00, 18366169.37it/s]
95
%|
#########4| 25034752/26421880 [00:01<00:00, 16404116.31it/s]
100
%|
##########| 26421880/26421880 [00:02<00:00, 13106029.06it/s]
Extracting
data
/
FashionMNIST
/
raw
/
train
-
images
-
idx3
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
train
-
labels
-
idx1
-
ubyte
.
gz
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
train
-
labels
-
idx1
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
/
train
-
labels
-
idx1
-
ubyte
.
gz
0
%|
|
0
/
29515
[
00
:
00
<
?
,
?
it
/
s
]
100
%|
##########| 29515/29515 [00:00<00:00, 326257.67it/s]
Extracting
data
/
FashionMNIST
/
raw
/
train
-
labels
-
idx1
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
t10k
-
images
-
idx3
-
ubyte
.
gz
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
t10k
-
images
-
idx3
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
/
t10k
-
images
-
idx3
-
ubyte
.
gz
0
%|
|
0
/
4422102
[
00
:
00
<
?
,
?
it
/
s
]
1
%|
1
|
65536
/
4422102
[
00
:
00
<
00
:
12
,
362747.74
it
/
s
]
5
%|
5
|
229376
/
4422102
[
00
:
00
<
00
:
06
,
681864.40
it
/
s
]
15
%|
#4 | 655360/4422102 [00:00<00:02, 1798436.42it/s]
40
%|
#### | 1769472/4422102 [00:00<00:00, 3872995.18it/s]
79
%|
#######9 | 3506176/4422102 [00:00<00:00, 7404355.18it/s]
100
%|
##########| 4422102/4422102 [00:00<00:00, 5422111.79it/s]
Extracting
data
/
FashionMNIST
/
raw
/
t10k
-
images
-
idx3
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
t10k
-
labels
-
idx1
-
ubyte
.
gz
Downloading
http
:
//
fashion
-
mnist
.
s3
-
website
.
eu
-
central
-
1.
amazonaws
.
com
/
t10k
-
labels
-
idx1
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
/
t10k
-
labels
-
idx1
-
ubyte
.
gz
0
%|
|
0
/
5148
[
00
:
00
<
?
,
?
it
/
s
]
100
%|
##########| 5148/5148 [00:00<00:00, 35867569.75it/s]
Extracting
data
/
FashionMNIST
/
raw
/
t10k
-
labels
-
idx1
-
ubyte
.
gz
to
data
/
FashionMNIST
/
raw
```
## ToTensor()
[
ToTensor
](
https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.ToTensor
)
将PIL图像或NumPy的
``ndarray``
转换为
``FloatTensor``
。图像的像素强度值在[0., 1.]范围内缩放。
## Lambda Transforms
Lambda transforms 应用任何用户定义的lambda函数。在这里,我们定义了一个函数来把整数变成一个单热编码的张量。
它首先创建一个大小为10(我们数据集中的标签数量)的零张量,然后调用
[
scatter_
](
https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
)
,指定了一个
``value=1``
在标签
``y``
所给的索引上。
```
py
target_transform
=
Lambda
(
lambda
y
:
torch
.
zeros
(
10
,
dtype
=
torch
.
float
).
scatter_
(
dim
=
0
,
index
=
torch
.
tensor
(
y
),
value
=
1
))
```
### 阅读更多
-
[
torchvision.transforms API
](
https://pytorch.org/vision/stable/transforms.html
)
mkdocs.yml
浏览文件 @
4c52674e
...
...
@@ -42,6 +42,8 @@ nav:
-
"
学习基本知识"
:
"
2.0/tutorials/Introduction_to_PyTorch/learn_the_basics.md"
-
"
快速入门"
:
"
2.0/tutorials/Introduction_to_PyTorch/quickstart_tutorial.md"
-
"
Tensors介绍"
:
"
2.0/tutorials/Introduction_to_PyTorch/tensors.md"
-
"
数据集和数据加载器"
:
"
2.0/tutorials/Introduction_to_PyTorch/datasets_dataloaders.md"
-
"
Transforms介绍"
:
"
2.0/tutorials/Introduction_to_PyTorch/transforms.md"
-
"
构建神经网络模型"
:
"
2.0/tutorials/Introduction_to_PyTorch/build_the_neural_network.md"
-
"
自动微分运算
-
Torch.AutoGrad"
:
"
2.0/tutorials/Introduction_to_PyTorch/automatic_differentation_with_torch_autograd.md"
-
"
优化模型参数"
:
"
2.0/tutorials/Introduction_to_PyTorch/optimizing_model_parameters.md"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录