提交 7a6a1c9b 编写于 作者: M minqiyang

Add dygraph models

上级 1fcde971
# MNIST
当我们学习编程的时候,编写的第一个程序一般是实现打印"Hello World"。而机器学习(或深度学习)的入门教程,一般都是 MNIST 数据库上的手写识别问题。原因是手写识别属于典型的图像分类问题,比较简单,同时MNIST数据集也很完备。
本页将介绍如何使用PaddlePaddle在DyGraph模式下实现MNIST,包括[安装](#installation)[训练](#training-a-model)[输出](#log)[参数保存](#save)[模型评估](#evaluation)
---
## 内容
- [安装](#installation)
- [训练](#training-a-model)
- [输出](#log)
## 安装
在当前目录下运行样例代码需要PadddlePaddle Fluid的v1.4.0或以上的版本。如果你的运行环境中的PaddlePaddle低于此版本,请根据安装文档中的说明来更新PaddlePaddle。
## 训练
教程中使用`paddle.dataset.mnist`数据集作为训练数据,可以通过如下的方式启动训练:
```
env CUDA_VISIBLE_DEVICES=0 python mnist_dygraph.py
```
## 输出
执行训练开始后,将得到类似如下的输出。
```
Loss at epoch 0 step 0: [2.3043773]
Loss at epoch 0 step 100: [0.20764539]
Loss at epoch 0 step 200: [0.18648806]
Loss at epoch 0 step 300: [0.10279777]
Loss at epoch 0 step 400: [0.03940877]
...
```
## 参数保存
调用`fluid.dygraph.save_persistables()`接口可以把模型的参数进行保存。
```python
fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
```
## 测试
执行`mnist.eval()`可以切换至评估状态,即不更新只使用参数进行训练,通过这种方式进行测试或者评估。
```python
mnist.eval()
```
## 模型评估
我们使用手写数据集中的一张图片来进行评估。为了区别训练模型,我们使用`with fluid.dygraph.guard()`来切换到一个新的参数空间,然后构建一个用于评估的网络`mnist_infer`,并通过`mnist_infer.load_dict()`来加载使用`fluid.dygraph.load_persistables`读取的参数。然后用`mnist_infer.eval()`切换到评估。
```python
with fluid.dygraph.guard():
mnist_infer = MNIST("mnist")
# load checkpoint
mnist_infer.load_dict(
fluid.dygraph.load_persistables("save_dir"))
# start evaluate mode
mnist_infer.eval()
```
如果无意外,将可以看到预测的结果:
```text
Inference result of image/infer_3.png is: 3
```
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
from PIL import Image
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
from paddle.fluid.dygraph.base import to_variable
class SimpleImgConvPool(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
pool_size,
pool_stride,
pool_padding=0,
pool_type='max',
global_pooling=False,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConvPool, self).__init__(name_scope)
self._conv2d = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
param_attr=None,
bias_attr=None,
use_cudnn=use_cudnn)
self._pool2d = Pool2D(
self.full_name(),
pool_size=pool_size,
pool_type=pool_type,
pool_stride=pool_stride,
pool_padding=pool_padding,
global_pooling=global_pooling,
use_cudnn=use_cudnn)
def forward(self, inputs):
x = self._conv2d(inputs)
x = self._pool2d(x)
return x
class MNIST(fluid.dygraph.Layer):
def __init__(self, name_scope):
super(MNIST, self).__init__(name_scope)
self._simple_img_conv_pool_1 = SimpleImgConvPool(
self.full_name(), 1, 20, 5, 2, 2, act="relu")
self._simple_img_conv_pool_2 = SimpleImgConvPool(
self.full_name(), 20, 50, 5, 2, 2, act="relu")
pool_2_shape = 50 * 4 * 4
SIZE = 10
scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
self._fc = FC(self.full_name(),
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")
def forward(self, inputs, label=None):
x = self._simple_img_conv_pool_1(inputs)
x = self._simple_img_conv_pool_2(x)
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x
def test_train(reader, model, batch_size):
acc_set = []
avg_loss_set = []
for batch_id, data in enumerate(reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
prediction, acc = model(img, label)
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
acc_set.append(float(acc.numpy()))
avg_loss_set.append(float(avg_loss.numpy()))
# get test acc and loss
acc_val_mean = np.array(acc_set).mean()
avg_loss_val_mean = np.array(avg_loss_set).mean()
return avg_loss_val_mean, acc_val_mean
def train_mnist():
epoch_num = 5
BATCH_SIZE = 64
with fluid.dygraph.guard():
mnist = MNIST("mnist")
adam = AdamOptimizer(learning_rate=0.001)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE, drop_last=True)
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(BATCH_SIZE, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
cost, acc = mnist(img, label)
loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
# save checkpoint
mnist.clear_gradients()
if batch_id % 100 == 0:
print("Loss at epoch {} step {}: {:}".format(epoch, batch_id, avg_loss.numpy()))
mnist.eval()
test_cost, test_acc = test_train(test_reader, mnist, BATCH_SIZE)
mnist.train()
print("Loss at epoch {} , Test avg_loss is: {}, acc is: {}".format(epoch, test_cost, test_acc))
fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
print("checkpoint saved")
with fluid.dygraph.guard():
mnist_infer = MNIST("mnist")
# load checkpoint
mnist_infer.load_dict(
fluid.dygraph.load_persistables("save_dir"))
print("checkpoint loaded")
# start evaluate mode
mnist_infer.eval()
def load_image(file):
im = Image.open(file).convert('L')
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = os.path.dirname(os.path.realpath(__file__))
tensor_img = load_image(cur_dir + '/image/infer_3.png')
results = mnist_infer(to_variable(tensor_img))
lab = np.argsort(results.numpy())
print("Inference result of image/infer_3.png is: %d" % lab[0][-1])
if __name__ == '__main__':
train_mnist()
DyGraph模式下Residual Network实现
========
简介
--------
Residual Network(ResNet)是常用的图像分类模型。我们实现了在paddlepaddle的DyGraph模式下相应的实现。可以对比原先静态图下实现([Residual Network](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification/models))来了解paddle中DyGraph模式。
运行本目录下的程序示例需要使用PaddlePaddle develop最新版本。如果您的PaddlePaddle安装版本低于此要求,请按照[安装文档](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html)中的说明更新PaddlePaddle安装版本。
## 代码结构
```
└── train.py # 训练脚本。
```
## 使用的数据
教程中使用`paddle.dataset.flowers`数据集作为训练数据,该数据集通过`paddle.dataset`模块自动下载到本地。
## 训练测试Residual Network
在GPU单卡上训练Residual Network:
```
env CUDA_VISIBLE_DEVICES=0 python train.py
```
这里`CUDA_VISIBLE_DEVICES=0`表示是执行在0号设备卡上,请根据自身情况修改这个参数。
## 输出
执行训练开始后,将得到类似如下的输出。每一轮`batch`训练将会打印当前epoch、step以及loss值。当前默认执行`epoch=10`, `batch_size=8`。您可以调整参数以得到更好的训练效果,同时也意味着消耗更多的内存(显存)以及需要花费更长的时间。
```text
epoch id: 0, batch step: 0, loss: 4.951202
epoch id: 0, batch step: 1, loss: 5.268410
epoch id: 0, batch step: 2, loss: 5.123999
```
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable
batch_size = 8
epoch = 10
def optimizer_setting():
return fluid.optimizer.SGD(learning_rate=0.01)
class ConvBNLayer(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
filter_size,
stride=1,
groups=1,
act=None):
super(ConvBNLayer, self).__init__(name_scope)
self._conv = Conv2D(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=stride,
padding=(filter_size - 1) // 2,
groups=groups,
act=None,
bias_attr=None)
self._batch_norm = BatchNorm(self.full_name(), num_filters, act=act)
def forward(self, inputs):
y = self._conv(inputs)
y = self._batch_norm(y)
return y
class BottleneckBlock(fluid.dygraph.Layer):
def __init__(self,
name_scope,
num_channels,
num_filters,
stride,
shortcut=True):
super(BottleneckBlock, self).__init__(name_scope)
self.conv0 = ConvBNLayer(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
act='relu')
self.conv1 = ConvBNLayer(
self.full_name(),
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act='relu')
self.conv2 = ConvBNLayer(
self.full_name(),
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None)
if not shortcut:
self.short = ConvBNLayer(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters * 4,
filter_size=1,
stride=stride)
self.shortcut = shortcut
self._num_channels_out = num_filters * 4
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
conv2 = self.conv2(conv1)
if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2)
layer_helper = LayerHelper(self.full_name(), act='relu')
return layer_helper.append_activation(y)
class ResNet(fluid.dygraph.Layer):
def __init__(self, name_scope, layers=50, class_dim=102):
super(ResNet, self).__init__(name_scope)
self.layers = layers
supported_layers = [50, 101, 152]
assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers)
if layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_filters = [64, 128, 256, 512]
self.conv = ConvBNLayer(
self.full_name(),
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act='relu')
self.pool2d_max = Pool2D(
self.full_name(),
pool_size=3,
pool_stride=2,
pool_padding=1,
pool_type='max')
self.bottleneck_block_list = []
num_channels = 64
for block in range(len(depth)):
shortcut = False
for i in range(depth[block]):
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
BottleneckBlock(
self.full_name(),
num_channels=num_channels,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut))
num_channels = bottleneck_block._num_channels_out
self.bottleneck_block_list.append(bottleneck_block)
shortcut = True
self.pool2d_avg = Pool2D(
self.full_name(), pool_size=7, pool_type='avg', global_pooling=True)
import math
stdv = 1.0 / math.sqrt(2048 * 1.0)
self.out = FC(self.full_name(),
size=class_dim,
act='softmax',
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
def forward(self, inputs):
y = self.conv(inputs)
y = self.pool2d_max(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = self.out(y)
return y
def train_resnet():
with fluid.dygraph.guard():
resnet = ResNet("resnet")
optimizer = optimizer_setting()
train_reader = paddle.batch(
paddle.dataset.flowers.train(),
batch_size=batch_size)
for eop in range(epoch):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(3, 224, 224) for x in data]).astype('float32')
if len(np.array([x[1] for x in data]).astype('int64')) != batch_size:
continue
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
batch_size, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label._stop_gradient = True
out = resnet(img)
loss = fluid.layers.cross_entropy(input=out, label=label)
avg_loss = fluid.layers.mean(x=loss)
dy_out = avg_loss.numpy()
avg_loss.backward()
optimizer.minimize(avg_loss)
resnet.clear_gradients()
print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out))
if __name__ == '__main__':
train_resnet()
## 简介
### 任务说明
机器翻译(machine translation, MT)是利用计算机将一种自然语言(源语言)转换为另一种自然语言(目标语言)的过程,输入为源语言句子,输出为相应的目标语言的句子。本示例是机器翻译主流模型 Transformer 的实现和相关介绍。
### 数据集说明
我们使用公开的 [WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)训练
可以将下载好的wmt16数据集放在`~/.cache/paddle/dataset/wmt16/`目录下
### 安装说明
1. paddle安装
本项目依赖于 Paddlepaddle Fluid 1.4.1,请参考安装指南进行安装。
2. 安装代码
3. 环境依赖
### 执行训练:
利用python解释器执行train.py即可
### 执行效果
W0422 13:25:53.853921 116144 device_context.cc:261] Please NOTE: device: 0, CUDA Capability: 35, Driver API Version: 9.0, Runtime API Version: 8.0
W0422 13:25:53.861614 116144 device_context.cc:269] device: 0, cuDNN Version: 7.0.
pass num : 0, batch_id: 10, dy_graph avg loss: [9.033163]
pass num : 0, batch_id: 20, dy_graph avg loss: [8.869838]
pass num : 0, batch_id: 30, dy_graph avg loss: [8.635877]
pass num : 0, batch_id: 40, dy_graph avg loss: [8.460026]
pass num : 0, batch_id: 50, dy_graph avg loss: [8.293438]
pass num : 0, batch_id: 60, dy_graph avg loss: [8.138791]
pass num : 0, batch_id: 70, dy_graph avg loss: [7.9594088]
pass num : 0, batch_id: 80, dy_graph avg loss: [7.7303553]
pass num : 0, batch_id: 90, dy_graph avg loss: [7.6716228]
pass num : 0, batch_id: 100, dy_graph avg loss: [7.611051]
pass num : 0, batch_id: 110, dy_graph avg loss: [7.4179897]
pass num : 0, batch_id: 120, dy_graph avg loss: [7.318419]
## 进阶使用
### 模型原理介绍
Transformer 是论文 [Attention Is All You Need](https://arxiv.org/abs/1706.03762) 中提出的用以完成机器翻译(machine translation, MT)等序列到序列(sequence to sequence, Seq2Seq)学习任务的一种全新网络结构。其同样使用了 Seq2Seq 任务中典型的编码器-解码器(Encoder-Decoder)的框架结构,但相较于此前广泛使用的循环神经网络(Recurrent Neural Network, RNN),其完全使用注意力(Attention)机制来实现序列到序列的建模,整体网络结构如图1所示。
<p align="center">
<img src="../../PaddleNLP/neural_machine_translation/transformer/images/transformer_network.png" height=400 hspace='10'/> <br />
图 1. Transformer 网络结构图
</p>
Encoder 由若干相同的 layer 堆叠组成,每个 layer 主要由多头注意力(Multi-Head Attention)和全连接的前馈(Feed-Forward)网络这两个 sub-layer 构成。
- Multi-Head Attention 在这里用于实现 Self-Attention,相比于简单的 Attention 机制,其将输入进行多路线性变换后分别计算 Attention 的结果,并将所有结果拼接后再次进行线性变换作为输出。参见图2,其中 Attention 使用的是点积(Dot-Product),并在点积后进行了 scale 的处理以避免因点积结果过大进入 softmax 的饱和区域。
- Feed-Forward 网络会对序列中的每个位置进行相同的计算(Position-wise),其采用的是两次线性变换中间加以 ReLU 激活的结构。
此外,每个 sub-layer 后还施以 [Residual Connection](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) 来促进梯度传播和模型收敛。
<p align="center">
<img src="../../PaddleNLP/neural_machine_translation/transformer/images/multi_head_attention.png" height=300 hspace='10'/> <br />
图 2. Multi-Head Attention
</p>
Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 layer ,在组成 Decoder 的 layer 中还多了一个 Multi-Head Attention 的 sub-layer 来实现对 Encoder 输出的 Attention,这个 Encoder-Decoder Attention 在其他 Seq2Seq 模型中也是存在的。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册