diff --git a/dygraph/mnist/README.md b/dygraph/mnist/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..867f75189f0c8a2da6a87269b5563ec4685eae45
--- /dev/null
+++ b/dygraph/mnist/README.md
@@ -0,0 +1,62 @@
+# MNIST
+When learning to program, the first program we write is usually one that prints "Hello World". Likewise, the introductory tutorial for machine learning (or deep learning) is usually handwritten digit recognition on the MNIST database. This is because handwriting recognition is a typical image classification problem, it is relatively simple, and the MNIST dataset is complete and well studied.
+This page shows how to implement MNIST with PaddlePaddle in DyGraph mode, covering [installation](#installation), [training](#training), [output](#output), [saving parameters](#saving-parameters), [testing](#testing), and [model evaluation](#model-evaluation).
+
+---
+## Contents
+- [Installation](#installation)
+- [Training](#training)
+- [Output](#output)
+- [Saving Parameters](#saving-parameters)
+- [Testing](#testing)
+- [Model Evaluation](#model-evaluation)
+
+## Installation
+
+Running the sample code in this directory requires PaddlePaddle Fluid v1.4.0 or above. If the PaddlePaddle version in your environment is lower, please update it following the instructions in the installation documentation.
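+
+For example, in a pip-based environment (an assumption about your setup; use the `paddlepaddle-gpu` package instead if you need GPU support):
+```
+pip install -U paddlepaddle
+```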
+
+## Training
+This tutorial uses the `paddle.dataset.mnist` dataset as training data. Training can be launched as follows:
+```
+env CUDA_VISIBLE_DEVICES=0 python mnist_dygraph.py
+```
+
+## Output
+Once training starts, you will see output similar to the following.
+```
+Loss at epoch 0 step 0: [2.3043773]
+Loss at epoch 0 step 100: [0.20764539]
+Loss at epoch 0 step 200: [0.18648806]
+Loss at epoch 0 step 300: [0.10279777]
+Loss at epoch 0 step 400: [0.03940877]
+...
+```
+
+## Saving Parameters
+Call the `fluid.dygraph.save_persistables()` interface to save the model parameters.
+```python
+fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
+```
+
+## Testing
+Calling `mnist.eval()` switches the model to evaluation mode, in which the parameters are no longer updated and are only used for forward computation; test or evaluate the model in this mode.
+```python
+mnist.eval()
+```
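+
+To switch back to training mode, call `mnist.train()`.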
+
+## Model Evaluation
+We use an image from the handwritten-digit dataset for evaluation. To keep this separate from the trained model, we use `with fluid.dygraph.guard()` to switch to a new parameter space, build a network `mnist_infer` for evaluation, and call `mnist_infer.load_dict()` to load the parameters read by `fluid.dygraph.load_persistables`. Finally, `mnist_infer.eval()` switches to evaluation mode.
+```python
+with fluid.dygraph.guard():
+
+ mnist_infer = MNIST("mnist")
+ # load checkpoint
+ mnist_infer.load_dict(
+ fluid.dygraph.load_persistables("save_dir"))
+
+ # start evaluate mode
+ mnist_infer.eval()
+```
+If everything goes well, you will see the predicted result:
+```text
+Inference result of image/infer_3.png is: 3
+```
+
+
diff --git a/dygraph/mnist/image/infer_3.png b/dygraph/mnist/image/infer_3.png
new file mode 100644
index 0000000000000000000000000000000000000000..030cd60d3b4af9aecd4941204da4ad15f6e1189f
Binary files /dev/null and b/dygraph/mnist/image/infer_3.png differ
diff --git a/dygraph/mnist/mnist_dygraph.py b/dygraph/mnist/mnist_dygraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed1a0824bda88755665daa0f0c0df9ca2760ff12
--- /dev/null
+++ b/dygraph/mnist/mnist_dygraph.py
@@ -0,0 +1,202 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+from PIL import Image
+import os
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.optimizer import AdamOptimizer
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, FC
+from paddle.fluid.dygraph.base import to_variable
+
+
+class SimpleImgConvPool(fluid.dygraph.Layer):
+ def __init__(self,
+ name_scope,
+ num_channels,
+ num_filters,
+ filter_size,
+ pool_size,
+ pool_stride,
+ pool_padding=0,
+ pool_type='max',
+ global_pooling=False,
+ conv_stride=1,
+ conv_padding=0,
+ conv_dilation=1,
+ conv_groups=1,
+ act=None,
+ use_cudnn=False,
+ param_attr=None,
+ bias_attr=None):
+ super(SimpleImgConvPool, self).__init__(name_scope)
+
+ self._conv2d = Conv2D(
+ self.full_name(),
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=conv_stride,
+ padding=conv_padding,
+ dilation=conv_dilation,
+ groups=conv_groups,
+            param_attr=param_attr,
+            bias_attr=bias_attr,
+            act=act,
+            use_cudnn=use_cudnn)
+
+ self._pool2d = Pool2D(
+ self.full_name(),
+ pool_size=pool_size,
+ pool_type=pool_type,
+ pool_stride=pool_stride,
+ pool_padding=pool_padding,
+ global_pooling=global_pooling,
+ use_cudnn=use_cudnn)
+
+ def forward(self, inputs):
+ x = self._conv2d(inputs)
+ x = self._pool2d(x)
+ return x
+
+
+class MNIST(fluid.dygraph.Layer):
+ def __init__(self, name_scope):
+ super(MNIST, self).__init__(name_scope)
+
+ self._simple_img_conv_pool_1 = SimpleImgConvPool(
+ self.full_name(), 1, 20, 5, 2, 2, act="relu")
+
+ self._simple_img_conv_pool_2 = SimpleImgConvPool(
+ self.full_name(), 20, 50, 5, 2, 2, act="relu")
+
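+        # the second conv-pool outputs 50 channels of 4x4 feature maps,
+        # so the final FC layer sees 50 * 4 * 4 input features per image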
+ pool_2_shape = 50 * 4 * 4
+ SIZE = 10
+ scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
+ self._fc = FC(self.full_name(),
+ 10,
+ param_attr=fluid.param_attr.ParamAttr(
+ initializer=fluid.initializer.NormalInitializer(
+ loc=0.0, scale=scale)),
+ act="softmax")
+
+ def forward(self, inputs, label=None):
+ x = self._simple_img_conv_pool_1(inputs)
+ x = self._simple_img_conv_pool_2(x)
+ x = self._fc(x)
+ if label is not None:
+ acc = fluid.layers.accuracy(input=x, label=label)
+ return x, acc
+ else:
+ return x
+
+
+def test_train(reader, model, batch_size):
+ acc_set = []
+ avg_loss_set = []
+ for batch_id, data in enumerate(reader()):
+ dy_x_data = np.array(
+ [x[0].reshape(1, 28, 28)
+ for x in data]).astype('float32')
+ y_data = np.array(
+ [x[1] for x in data]).astype('int64').reshape(batch_size, 1)
+
+ img = to_variable(dy_x_data)
+ label = to_variable(y_data)
+ label.stop_gradient = True
+ prediction, acc = model(img, label)
+ loss = fluid.layers.cross_entropy(input=prediction, label=label)
+ avg_loss = fluid.layers.mean(loss)
+ acc_set.append(float(acc.numpy()))
+ avg_loss_set.append(float(avg_loss.numpy()))
+
+ # get test acc and loss
+ acc_val_mean = np.array(acc_set).mean()
+ avg_loss_val_mean = np.array(avg_loss_set).mean()
+
+ return avg_loss_val_mean, acc_val_mean
+
+
+def train_mnist():
+ epoch_num = 5
+ BATCH_SIZE = 64
+ with fluid.dygraph.guard():
+
+ mnist = MNIST("mnist")
+ adam = AdamOptimizer(learning_rate=0.001)
+ train_reader = paddle.batch(
+ paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
+ test_reader = paddle.batch(
+ paddle.dataset.mnist.test(), batch_size=BATCH_SIZE, drop_last=True)
+ for epoch in range(epoch_num):
+ for batch_id, data in enumerate(train_reader()):
+ dy_x_data = np.array(
+ [x[0].reshape(1, 28, 28)
+ for x in data]).astype('float32')
+ y_data = np.array(
+ [x[1] for x in data]).astype('int64').reshape(BATCH_SIZE, 1)
+
+ img = to_variable(dy_x_data)
+ label = to_variable(y_data)
+ label.stop_gradient = True
+
+ cost, acc = mnist(img, label)
+
+ loss = fluid.layers.cross_entropy(cost, label)
+ avg_loss = fluid.layers.mean(loss)
+ avg_loss.backward()
+ adam.minimize(avg_loss)
+                # clear gradients for the next batch
+                mnist.clear_gradients()
+ if batch_id % 100 == 0:
+ print("Loss at epoch {} step {}: {:}".format(epoch, batch_id, avg_loss.numpy()))
+
+ mnist.eval()
+ test_cost, test_acc = test_train(test_reader, mnist, BATCH_SIZE)
+ mnist.train()
+ print("Loss at epoch {} , Test avg_loss is: {}, acc is: {}".format(epoch, test_cost, test_acc))
+
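+        # save the trained parameters as a checkpoint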
+ fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
+ print("checkpoint saved")
+
+ with fluid.dygraph.guard():
+
+ mnist_infer = MNIST("mnist")
+ # load checkpoint
+ mnist_infer.load_dict(
+ fluid.dygraph.load_persistables("save_dir"))
+ print("checkpoint loaded")
+
+ # start evaluate mode
+ mnist_infer.eval()
+
+ def load_image(file):
+ im = Image.open(file).convert('L')
+ im = im.resize((28, 28), Image.ANTIALIAS)
+ im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
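+            # scale pixels to [-1, 1], matching the preprocessing in paddle.dataset.mnist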
+ im = im / 255.0 * 2.0 - 1.0
+ return im
+
+ cur_dir = os.path.dirname(os.path.realpath(__file__))
+ tensor_img = load_image(cur_dir + '/image/infer_3.png')
+
+ results = mnist_infer(to_variable(tensor_img))
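+        # np.argsort returns indices in ascending order, so the last one is the predicted digit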
+ lab = np.argsort(results.numpy())
+ print("Inference result of image/infer_3.png is: %d" % lab[0][-1])
+
+
+if __name__ == '__main__':
+ train_mnist()
diff --git a/dygraph/resnet/README.md b/dygraph/resnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0cadfdd2fd56a39ecfffe8cac01e37ad0440c399
--- /dev/null
+++ b/dygraph/resnet/README.md
@@ -0,0 +1,35 @@
+Residual Network Implementation in DyGraph Mode
+========
+
+Introduction
+--------
+Residual Network (ResNet) is a widely used image classification model. We provide the corresponding implementation under PaddlePaddle's DyGraph mode; you can compare it against the earlier static-graph implementation ([Residual Network](https://github.com/PaddlePaddle/models/tree/develop/PaddleCV/image_classification/models)) to understand DyGraph mode in Paddle.
+Running the program samples in this directory requires the latest develop version of PaddlePaddle. If your installed version is lower than this requirement, please update it following the instructions in the [installation documentation](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
+
+
+## Code Structure
+```
+└── train.py # training script
+```
+
+## Data
+
+This tutorial uses the `paddle.dataset.flowers` dataset as training data; it is downloaded automatically by the `paddle.dataset` module.
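+
+The reader built in `train.py` can be sketched as follows (a minimal sketch; `batch_size=8` is the script's default):
+
+```python
+import paddle
+
+# each record is an (image, label) pair; train.py reshapes images to (3, 224, 224)
+train_reader = paddle.batch(
+    paddle.dataset.flowers.train(), batch_size=8)
+```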
+
+## Training and Testing Residual Network
+
+Train Residual Network on a single GPU card:
+
+```
+env CUDA_VISIBLE_DEVICES=0 python train.py
+```
+
+Here `CUDA_VISIBLE_DEVICES=0` means the program runs on GPU card 0; please adjust this to your own setup.
+
+## Output
+Once training starts, you will see output similar to the following. Each `batch` step prints the current epoch, step, and loss value. The defaults are `epoch=10` and `batch_size=8`. You can tune these parameters for better training results, which also means consuming more (GPU) memory and taking more time.
+```text
+epoch id: 0, batch step: 0, loss: 4.951202
+epoch id: 0, batch step: 1, loss: 5.268410
+epoch id: 0, batch step: 2, loss: 5.123999
+```
diff --git a/dygraph/resnet/train.py b/dygraph/resnet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c4cdb00acc7dfed7135b126aa67cde150d5bb8
--- /dev/null
+++ b/dygraph/resnet/train.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, FC
+from paddle.fluid.dygraph.base import to_variable
+
+batch_size = 8
+epoch = 10
+
+
+def optimizer_setting():
+ return fluid.optimizer.SGD(learning_rate=0.01)
+
+
+class ConvBNLayer(fluid.dygraph.Layer):
+ def __init__(self,
+ name_scope,
+ num_channels,
+ num_filters,
+ filter_size,
+ stride=1,
+ groups=1,
+ act=None):
+ super(ConvBNLayer, self).__init__(name_scope)
+
+ self._conv = Conv2D(
+ self.full_name(),
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=filter_size,
+ stride=stride,
+ padding=(filter_size - 1) // 2,
+ groups=groups,
+ act=None,
+ bias_attr=None)
+
+ self._batch_norm = BatchNorm(self.full_name(), num_filters, act=act)
+
+ def forward(self, inputs):
+ y = self._conv(inputs)
+ y = self._batch_norm(y)
+
+ return y
+
+
+class BottleneckBlock(fluid.dygraph.Layer):
+ def __init__(self,
+ name_scope,
+ num_channels,
+ num_filters,
+ stride,
+ shortcut=True):
+ super(BottleneckBlock, self).__init__(name_scope)
+
+ self.conv0 = ConvBNLayer(
+ self.full_name(),
+ num_channels=num_channels,
+ num_filters=num_filters,
+ filter_size=1,
+ act='relu')
+ self.conv1 = ConvBNLayer(
+ self.full_name(),
+ num_channels=num_filters,
+ num_filters=num_filters,
+ filter_size=3,
+ stride=stride,
+ act='relu')
+ self.conv2 = ConvBNLayer(
+ self.full_name(),
+ num_channels=num_filters,
+ num_filters=num_filters * 4,
+ filter_size=1,
+ act=None)
+
+ if not shortcut:
+ self.short = ConvBNLayer(
+ self.full_name(),
+ num_channels=num_channels,
+ num_filters=num_filters * 4,
+ filter_size=1,
+ stride=stride)
+
+ self.shortcut = shortcut
+
+ self._num_channels_out = num_filters * 4
+
+ def forward(self, inputs):
+ y = self.conv0(inputs)
+ conv1 = self.conv1(y)
+ conv2 = self.conv2(conv1)
+
+ if self.shortcut:
+ short = inputs
+ else:
+ short = self.short(inputs)
+
+ y = fluid.layers.elementwise_add(x=short, y=conv2)
+
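+        # apply the final ReLU after the shortcut addition, as in the original ResNet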
+ layer_helper = LayerHelper(self.full_name(), act='relu')
+ return layer_helper.append_activation(y)
+
+
+class ResNet(fluid.dygraph.Layer):
+ def __init__(self, name_scope, layers=50, class_dim=102):
+ super(ResNet, self).__init__(name_scope)
+
+ self.layers = layers
+ supported_layers = [50, 101, 152]
+ assert layers in supported_layers, \
+ "supported layers are {} but input layer is {}".format(supported_layers, layers)
+
+ if layers == 50:
+ depth = [3, 4, 6, 3]
+ elif layers == 101:
+ depth = [3, 4, 23, 3]
+ elif layers == 152:
+ depth = [3, 8, 36, 3]
+ num_filters = [64, 128, 256, 512]
+
+ self.conv = ConvBNLayer(
+ self.full_name(),
+ num_channels=3,
+ num_filters=64,
+ filter_size=7,
+ stride=2,
+ act='relu')
+ self.pool2d_max = Pool2D(
+ self.full_name(),
+ pool_size=3,
+ pool_stride=2,
+ pool_padding=1,
+ pool_type='max')
+
+ self.bottleneck_block_list = []
+ num_channels = 64
+ for block in range(len(depth)):
+ shortcut = False
+ for i in range(depth[block]):
+ bottleneck_block = self.add_sublayer(
+ 'bb_%d_%d' % (block, i),
+ BottleneckBlock(
+ self.full_name(),
+ num_channels=num_channels,
+ num_filters=num_filters[block],
+ stride=2 if i == 0 and block != 0 else 1,
+ shortcut=shortcut))
+ num_channels = bottleneck_block._num_channels_out
+ self.bottleneck_block_list.append(bottleneck_block)
+ shortcut = True
+
+ self.pool2d_avg = Pool2D(
+ self.full_name(), pool_size=7, pool_type='avg', global_pooling=True)
+
+        stdv = 1.0 / math.sqrt(2048 * 1.0)
+
+ self.out = FC(self.full_name(),
+ size=class_dim,
+ act='softmax',
+ param_attr=fluid.param_attr.ParamAttr(
+ initializer=fluid.initializer.Uniform(-stdv, stdv)))
+
+ def forward(self, inputs):
+ y = self.conv(inputs)
+ y = self.pool2d_max(y)
+ for bottleneck_block in self.bottleneck_block_list:
+ y = bottleneck_block(y)
+ y = self.pool2d_avg(y)
+ y = self.out(y)
+ return y
+
+
+def train_resnet():
+ with fluid.dygraph.guard():
+ resnet = ResNet("resnet")
+ optimizer = optimizer_setting()
+ train_reader = paddle.batch(
+ paddle.dataset.flowers.train(),
+ batch_size=batch_size)
+
+ for eop in range(epoch):
+ for batch_id, data in enumerate(train_reader()):
+ dy_x_data = np.array(
+ [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
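+                # skip the last incomplete batch, since the reader was not built with drop_last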
+                if len(data) != batch_size:
+                    continue
+ y_data = np.array([x[1] for x in data]).astype('int64').reshape(
+ batch_size, 1)
+
+ img = to_variable(dy_x_data)
+ label = to_variable(y_data)
+                label.stop_gradient = True
+
+ out = resnet(img)
+ loss = fluid.layers.cross_entropy(input=out, label=label)
+ avg_loss = fluid.layers.mean(x=loss)
+
+ dy_out = avg_loss.numpy()
+ avg_loss.backward()
+
+ optimizer.minimize(avg_loss)
+ resnet.clear_gradients()
+
+ print("epoch id: %d, batch step: %d, loss: %f" % (eop, batch_id, dy_out))
+
+
+if __name__ == '__main__':
+ train_resnet()
diff --git a/dygraph/transformer/README.md b/dygraph/transformer/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0381ec50b5d158a9fee59aa4b5069c73ed348d4b
--- /dev/null
+++ b/dygraph/transformer/README.md
@@ -0,0 +1,66 @@
+## Introduction
+
+### Task Description
+
+Machine translation (MT) is the process of using a computer to convert a sentence in one natural language (the source language) into the corresponding sentence in another natural language (the target language). This example is an implementation and introduction of Transformer, the mainstream machine translation model.
+
+### Dataset
+
+We train on the public [WMT'16 EN-DE dataset](http://www.statmt.org/wmt16/translation-task.html).
+
+
+Place the downloaded wmt16 dataset under the `~/.cache/paddle/dataset/wmt16/` directory.
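+
+`train.py` reads it through `paddle.dataset.wmt16`; a minimal sketch of the reader it builds (the vocabulary sizes and batch size below follow `ModelHyperParams` and `TrainTaskConfig` in the script):
+
+    import paddle
+    import paddle.dataset.wmt16 as wmt16
+
+    reader = paddle.batch(
+        wmt16.train(10000, 10000), batch_size=32)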
+
+### Installation
+
+1. Install PaddlePaddle
+
+    This project depends on PaddlePaddle Fluid 1.4.1; please refer to the installation guide for installation instructions.
+
+2. Get the code
+3. Environment dependencies
+
+### Run Training
+Run train.py with the python interpreter:
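+
+    python train.py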
+
+### Sample Output
+
+ W0422 13:25:53.853921 116144 device_context.cc:261] Please NOTE: device: 0, CUDA Capability: 35, Driver API Version: 9.0, Runtime API Version: 8.0
+ W0422 13:25:53.861614 116144 device_context.cc:269] device: 0, cuDNN Version: 7.0.
+
+ pass num : 0, batch_id: 10, dy_graph avg loss: [9.033163]
+ pass num : 0, batch_id: 20, dy_graph avg loss: [8.869838]
+ pass num : 0, batch_id: 30, dy_graph avg loss: [8.635877]
+ pass num : 0, batch_id: 40, dy_graph avg loss: [8.460026]
+ pass num : 0, batch_id: 50, dy_graph avg loss: [8.293438]
+ pass num : 0, batch_id: 60, dy_graph avg loss: [8.138791]
+ pass num : 0, batch_id: 70, dy_graph avg loss: [7.9594088]
+ pass num : 0, batch_id: 80, dy_graph avg loss: [7.7303553]
+ pass num : 0, batch_id: 90, dy_graph avg loss: [7.6716228]
+ pass num : 0, batch_id: 100, dy_graph avg loss: [7.611051]
+ pass num : 0, batch_id: 110, dy_graph avg loss: [7.4179897]
+ pass num : 0, batch_id: 120, dy_graph avg loss: [7.318419]
+
+## Advanced Usage
+
+### Model Overview
+
+Transformer is a novel network architecture proposed in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) for sequence-to-sequence (Seq2Seq) learning tasks such as machine translation (MT). Like earlier Seq2Seq models it adopts the typical encoder-decoder framework, but in contrast to the previously widespread recurrent neural network (RNN), it relies entirely on the attention mechanism to model the sequence-to-sequence mapping. The overall network architecture is shown in Figure 1.
+
+
+
+Figure 1. Transformer network architecture
+
+
+The Encoder is composed of a stack of identical layers, each of which consists mainly of two sub-layers: multi-head attention (Multi-Head Attention) and a fully connected feed-forward (Feed-Forward) network.
+- Multi-Head Attention is used here to implement Self-Attention. Compared with a simple attention mechanism, it applies several different linear projections to the input, computes attention separately for each projection, concatenates all the results, and applies one more linear projection to produce the output; see Figure 2 and the sketch after this list. The attention used is dot-product (Dot-Product) attention, with the product scaled afterwards so that large values do not push the softmax into its saturated region.
+- The Feed-Forward network applies the same computation to every position of the sequence (Position-wise), using two linear transformations with a ReLU activation in between.
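+
+As a minimal numpy sketch of the scaled dot-product attention described above (illustrative only, not the training code; following the paper, the scale uses the key dimension):
+
+    import numpy as np
+
+    def scaled_dot_product_attention(q, k, v):
+        # q, k: [seq_len, d_key], v: [seq_len, d_value]
+        scores = q.dot(k.T) / np.sqrt(k.shape[-1])  # scale before softmax
+        # row-wise softmax, shifted for numerical stability
+        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
+        weights = weights / weights.sum(axis=-1, keepdims=True)
+        return weights.dot(v)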
+
+In addition, each sub-layer is followed by a [Residual Connection](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf) and [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf) to ease gradient propagation and model convergence.
+
+
+
+Figure 2. Multi-Head Attention
+
+
+The Decoder has a structure similar to the Encoder's, except that each of its layers contains one extra Multi-Head Attention sub-layer that attends over the Encoder output; this Encoder-Decoder Attention is also present in other Seq2Seq models.
diff --git a/dygraph/transformer/train.py b/dygraph/transformer/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..63ff4addccfc7549a96002ea3a30192ef7f27a3c
--- /dev/null
+++ b/dygraph/transformer/train.py
@@ -0,0 +1,1142 @@
+from __future__ import print_function
+
+import contextlib
+import paddle.fluid as fluid
+from paddle.fluid.dygraph import Embedding, LayerNorm, FC, to_variable, Layer, guard
+import numpy as np
+import paddle
+import paddle.dataset.wmt16 as wmt16
+np.set_printoptions(suppress=True)
+
+
+@contextlib.contextmanager
+def new_program_scope(main=None, startup=None, scope=None):
+ """
+ base program
+ :param main:
+ :param startup:
+ :param scope:
+ :return:
+ """
+ prog = main if main else fluid.Program()
+ startup_prog = startup if startup else fluid.Program()
+ scope = scope if scope else fluid.core.Scope()
+ with fluid.scope_guard(scope):
+ with fluid.program_guard(prog, startup_prog):
+ with fluid.unique_name.guard():
+ yield
+
+
+# Copied from models
+class TrainTaskConfig(object):
+ """
+ TrainTaskConfig
+ """
+ # support both CPU and GPU now.
+ use_gpu = True
+ # the epoch number to train.
+ pass_num = 30
+ # the number of sequences contained in a mini-batch.
+ # deprecated, set batch_size in args.
+ batch_size = 32
+ # the hyper parameters for Adam optimizer.
+    # This static learning_rate will be multiplied by the learning rate
+    # derived from the LearningRateScheduler to get the final learning rate.
+ learning_rate = 2.0
+ beta1 = 0.9
+ beta2 = 0.997
+ eps = 1e-9
+ # the parameters for learning rate scheduling.
+ warmup_steps = 8000
+ # the weight used to mix up the ground-truth distribution and the fixed
+ # uniform distribution in label smoothing when training.
+ # Set this as zero if label smoothing is not wanted.
+ label_smooth_eps = 0.1
+
+
+class ModelHyperParams(object):
+ """
+ ModelHyperParams
+ """
+    # The following vocabulary-related configurations will be set
+    # automatically according to the passed vocabulary path and special tokens.
+    # size of source word dictionary.
+    src_vocab_size = 10000
+    # size of target word dictionary.
+    trg_vocab_size = 10000
+    # # index for <bos> token
+    # bos_idx = 0
+    # # index for <eos> token
+    # eos_idx = 1
+    # # index for <unk> token
+    # unk_idx = 2
+
+    # index for <pad> token in source language.
+    src_pad_idx = 0
+
+    # index for <pad> token in target language.
+    trg_pad_idx = 1
+
+ # max length of sequences deciding the size of position encoding table.
+ max_length = 50
+ # the dimension for word embeddings, which is also the last dimension of
+ # the input and output of multi-head attention, position-wise feed-forward
+ # networks, encoder and decoder.
+ d_model = 512
+ # size of the hidden layer in position-wise feed-forward networks.
+ d_inner_hid = 2048
+ # the dimension that keys are projected to for dot-product attention.
+ d_key = 64
+ # the dimension that values are projected to for dot-product attention.
+ d_value = 64
+    # number of heads used in multi-head attention.
+ n_head = 8
+ # number of sub-layers to be stacked in the encoder and decoder.
+ n_layer = 6
+ # dropout rates of different modules.
+ prepostprocess_dropout = 0.1
+ attention_dropout = 0.1
+ relu_dropout = 0.1
+ # to process before each sub-layer
+ preprocess_cmd = "n" # layer normalization
+ # to process after each sub-layer
+ postprocess_cmd = "da" # dropout + residual connection
+ # random seed used in dropout for CE.
+ dropout_seed = None
+ # the flag indicating whether to share embedding and softmax weights.
+ # vocabularies in source and target should be same for weight sharing.
+ weight_sharing = False
+
+
+# The placeholder for batch_size in compile time. Must be -1 currently to be
+# consistent with some ops' infer-shape output in compile time, such as the
+# sequence_expand op used in beamsearch decoder.
+batch_size = -1
+# The placeholder for sequence length in compile time.
+seq_len = ModelHyperParams.max_length
+# Here list the data shapes and data types of all inputs.
+# The shapes here act as placeholder and are set to pass the infer-shape in
+# compile time.
+input_descs = {
+ # The actual data shape of src_word is:
+ # [batch_size, max_src_len_in_batch, 1]
+ "src_word": [(batch_size, seq_len, 1), "int64", 2],
+ # The actual data shape of src_pos is:
+ # [batch_size, max_src_len_in_batch, 1]
+ "src_pos": [(batch_size, seq_len, 1), "int64"],
+ # This input is used to remove attention weights on paddings in the
+ # encoder.
+ # The actual data shape of src_slf_attn_bias is:
+ # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
+ "src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+ seq_len), "float32"],
+ # The actual data shape of trg_word is:
+ # [batch_size, max_trg_len_in_batch, 1]
+ "trg_word": [(batch_size, seq_len, 1), "int64",
+ 2], # lod_level is only used in fast decoder.
+ # The actual data shape of trg_pos is:
+ # [batch_size, max_trg_len_in_batch, 1]
+ "trg_pos": [(batch_size, seq_len, 1), "int64"],
+ # This input is used to remove attention weights on paddings and
+ # subsequent words in the decoder.
+ # The actual data shape of trg_slf_attn_bias is:
+ # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
+ "trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+ seq_len), "float32"],
+ # This input is used to remove attention weights on paddings of the source
+ # input in the encoder-decoder attention.
+ # The actual data shape of trg_src_attn_bias is:
+ # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
+ "trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
+ seq_len), "float32"],
+ # This input is used in independent decoder program for inference.
+ # The actual data shape of enc_output is:
+ # [batch_size, max_src_len_in_batch, d_model]
+ "enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
+ # The actual data shape of label_word is:
+ # [batch_size * max_trg_len_in_batch, 1]
+ "lbl_word": [(batch_size * seq_len, 1), "int64"],
+    # This input is used to mask out the loss of padding tokens.
+ # The actual data shape of label_weight is:
+ # [batch_size * max_trg_len_in_batch, 1]
+ "lbl_weight": [(batch_size * seq_len, 1), "float32"],
+ # This input is used in beam-search decoder.
+ "init_score": [(batch_size, 1), "float32", 2],
+ # This input is used in beam-search decoder for the first gather
+    # (cell state update)
+ "init_idx": [(batch_size, ), "int32"],
+}
+
+# Names of word embedding table which might be reused for weight sharing.
+word_emb_param_names = (
+ "src_word_emb_table",
+ "trg_word_emb_table", )
+# Names of position encoding table which will be initialized externally.
+pos_enc_param_names = (
+ "src_pos_enc_table",
+ "trg_pos_enc_table", )
+# separated inputs for different usages.
+encoder_data_input_fields = (
+ "src_word",
+ "src_pos",
+ "src_slf_attn_bias", )
+decoder_data_input_fields = (
+ "trg_word",
+ "trg_pos",
+ "trg_slf_attn_bias",
+ "trg_src_attn_bias",
+ "enc_output", )
+label_data_input_fields = (
+ "lbl_word",
+ "lbl_weight", )
+# In fast decoder, trg_pos (only containing the current time step) is generated
+# by ops and trg_slf_attn_bias is not needed.
+fast_decoder_data_input_fields = (
+ "trg_word",
+ "init_score",
+ "init_idx",
+ "trg_src_attn_bias", )
+
+
+def merge_cfg_from_list(cfg_list, g_cfgs):
+ """
+ Set the above global configurations using the cfg_list.
+ """
+ assert len(cfg_list) % 2 == 0
+ for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
+ for g_cfg in g_cfgs:
+ if hasattr(g_cfg, key):
+ try:
+ value = eval(value)
+ except Exception: # for file path
+ pass
+ setattr(g_cfg, key, value)
+ break
+
+
+def position_encoding_init(n_position, d_pos_vec):
+ """
+ Generate the initial values for the sinusoid position encoding table.
+ """
+ channels = d_pos_vec
+ position = np.arange(n_position)
+ num_timescales = channels // 2
+ log_timescale_increment = (np.log(float(1e4) / float(1)) /
+ (num_timescales - 1))
+    inv_timescales = np.exp(
+        np.arange(num_timescales) * -log_timescale_increment)
+ scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
+ 0)
+ signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
+ signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
+ position_enc = signal
+ return position_enc.astype("float32")
+
+
+def create_data(np_values, is_static=False):
+ """
+ create_data
+ :param np_values:
+ :param is_static:
+ :return:
+ """
+ [
+ src_word_np, src_pos_np, trg_word_np, trg_pos_np, src_slf_attn_bias_np,
+ trg_slf_attn_bias_np, trg_src_attn_bias_np, lbl_word_np, lbl_weight_np
+ ] = np_values
+
+ if is_static:
+ return [
+ src_word_np, src_pos_np, src_slf_attn_bias_np, trg_word_np,
+ trg_pos_np, trg_slf_attn_bias_np, trg_src_attn_bias_np, lbl_word_np,
+ lbl_weight_np
+ ]
+ else:
+ enc_inputs = [
+ to_variable(
+ src_word_np, name='src_word'), to_variable(
+ src_pos_np, name='src_pos'), to_variable(
+ src_slf_attn_bias_np, name='src_slf_attn_bias')
+ ]
+ dec_inputs = [
+ to_variable(
+ trg_word_np, name='trg_word'), to_variable(
+ trg_pos_np, name='trg_pos'), to_variable(
+ trg_slf_attn_bias_np, name='trg_slf_attn_bias'),
+ to_variable(
+ trg_src_attn_bias_np, name='trg_src_attn_bias')
+ ]
+ label = to_variable(lbl_word_np, name='lbl_word')
+ weight = to_variable(lbl_weight_np, name='lbl_weight')
+ return enc_inputs, dec_inputs, label, weight
+
+
+def create_feed_dict_list(data, init=False):
+ """
+ create_feed_dict_list
+ :param data:
+ :param init:
+ :return:
+ """
+ if init:
+ data_input_names = encoder_data_input_fields + \
+ decoder_data_input_fields[:-1] + label_data_input_fields + pos_enc_param_names
+ else:
+ data_input_names = encoder_data_input_fields + \
+ decoder_data_input_fields[:-1] + label_data_input_fields
+ feed_dict_list = dict()
+ for i in range(len(data_input_names)):
+ feed_dict_list[data_input_names[i]] = data[i]
+ return feed_dict_list
+
+
+def make_all_inputs(input_fields):
+ """
+ Define the input data layers for the transformer model.
+ """
+ inputs = []
+ for input_field in input_fields:
+ input_var = fluid.layers.data(
+ name=input_field,
+ shape=input_descs[input_field][0],
+ dtype=input_descs[input_field][1],
+ lod_level=input_descs[input_field][2]
+ if len(input_descs[input_field]) == 3 else 0,
+ append_batch_size=False)
+ inputs.append(input_var)
+ return inputs
+
+
+def prepare_batch_input(insts, src_pad_idx, trg_pad_idx, n_head):
+ """
+ Pad the instances to the max sequence length in batch, and generate the
+    corresponding position data and attention bias, and return the padded
+    numpy arrays.
+ """
+
+ def __pad_batch_data(insts,
+ pad_idx,
+ n_head,
+ is_target=False,
+ is_label=False,
+ return_attn_bias=True,
+ return_max_len=True,
+ return_num_token=False):
+ """
+ Pad the instances to the max sequence length in batch, and generate the
+ corresponding position data and attention bias.
+ """
+ return_list = []
+ max_len = max(len(inst) for inst in insts)
+        # Any token included in the dict can be used for padding, since the
+        # paddings' loss will be masked out by weights and have no effect on
+        # parameter gradients.
+ inst_data = np.array(
+ [inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
+ return_list += [inst_data.astype("int64").reshape([-1, 1])]
+ if is_label: # label weight
+ inst_weight = np.array([[1.] * len(inst) + [0.] *
+ (max_len - len(inst)) for inst in insts])
+ return_list += [inst_weight.astype("float32").reshape([-1, 1])]
+ else: # position data
+ inst_pos = np.array([
+ list(range(0, len(inst))) + [0] * (max_len - len(inst))
+ for inst in insts
+ ])
+ return_list += [inst_pos.astype("int64").reshape([-1, 1])]
+ if return_attn_bias:
+ if is_target:
+ # This is used to avoid attention on paddings and subsequent
+ # words.
+ slf_attn_bias_data = np.ones(
+ (inst_data.shape[0], max_len, max_len))
+ slf_attn_bias_data = np.triu(
+ slf_attn_bias_data, 1).reshape([-1, 1, max_len, max_len])
+ slf_attn_bias_data = np.tile(slf_attn_bias_data,
+ [1, n_head, 1, 1]) * [-1e9]
+ else:
+ # This is used to avoid attention on paddings.
+ slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
+ (max_len - len(inst))
+ for inst in insts])
+ slf_attn_bias_data = np.tile(
+ slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
+ [1, n_head, max_len, 1])
+ return_list += [slf_attn_bias_data.astype("float32")]
+ if return_max_len:
+ return_list += [max_len]
+ if return_num_token:
+ num_token = 0
+ for inst in insts:
+ num_token += len(inst)
+ return_list += [num_token]
+ return return_list if len(return_list) > 1 else return_list[0]
+
+ src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
+ [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+ src_word = src_word.reshape(-1, src_max_len, 1)
+ src_pos = src_pos.reshape(-1, src_max_len, 1)
+ trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
+ [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
+ trg_word = trg_word.reshape(-1, trg_max_len, 1)
+ trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
+
+ trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
+ [1, 1, trg_max_len, 1]).astype("float32")
+
+ lbl_word, lbl_weight, num_token = __pad_batch_data(
+ [inst[2] for inst in insts],
+ trg_pad_idx,
+ n_head,
+ is_target=False,
+ is_label=True,
+ return_attn_bias=False,
+ return_max_len=False,
+ return_num_token=True)
+
+ return [
+ src_word, src_pos, trg_word, trg_pos, src_slf_attn_bias,
+ trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
+ ]
+
+
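+# precomputed sinusoid position encoding tables for the source and target sides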
+pos_inp1 = position_encoding_init(ModelHyperParams.max_length + 1,
+ ModelHyperParams.d_model)
+pos_inp2 = position_encoding_init(ModelHyperParams.max_length + 1,
+ ModelHyperParams.d_model)
+
+
+class PrePostProcessLayer(Layer):
+ """
+ PrePostProcessLayer
+ """
+
+ def __init__(self, name_scope, process_cmd, shape_len=None):
+ super(PrePostProcessLayer, self).__init__(name_scope)
+ for cmd in process_cmd:
+ if cmd == "n":
+ self._layer_norm = LayerNorm(
+ name_scope=self.full_name(),
+ begin_norm_axis=shape_len - 1,
+ param_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(1.)),
+ bias_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(0.)))
+
+ def forward(self, prev_out, out, process_cmd, dropout_rate=0.):
+ """
+ forward
+ :param prev_out:
+ :param out:
+ :param process_cmd:
+ :param dropout_rate:
+ :return:
+ """
+ for cmd in process_cmd:
+ if cmd == "a": # add residual connection
+                out = out + prev_out if prev_out is not None else out
+ elif cmd == "n": # add layer normalization
+ out = self._layer_norm(out)
+ elif cmd == "d": # add dropout
+ if dropout_rate:
+ out = fluid.layers.dropout(
+ out,
+ dropout_prob=dropout_rate,
+ seed=ModelHyperParams.dropout_seed,
+ is_test=False)
+ return out
+
+
+class PositionwiseFeedForwardLayer(Layer):
+ """
+ PositionwiseFeedForwardLayer
+ """
+
+ def __init__(self, name_scope, d_inner_hid, d_hid, dropout_rate):
+ super(PositionwiseFeedForwardLayer, self).__init__(name_scope)
+ self._i2h = FC(name_scope=self.full_name(),
+ size=d_inner_hid,
+ num_flatten_dims=2,
+ act="relu")
+ self._h2o = FC(name_scope=self.full_name(),
+ size=d_hid,
+ num_flatten_dims=2)
+ self._dropout_rate = dropout_rate
+
+ def forward(self, x):
+ """
+ forward
+ :param x:
+ :return:
+ """
+ hidden = self._i2h(x)
+ if self._dropout_rate:
+ hidden = fluid.layers.dropout(
+ hidden,
+ dropout_prob=self._dropout_rate,
+ seed=ModelHyperParams.dropout_seed,
+ is_test=False)
+ out = self._h2o(hidden)
+ return out
+
+
+class MultiHeadAttentionLayer(Layer):
+ """
+ MultiHeadAttentionLayer
+ """
+
+ def __init__(self,
+ name_scope,
+ d_key,
+ d_value,
+ d_model,
+ n_head=1,
+ dropout_rate=0.,
+ cache=None,
+ gather_idx=None,
+ static_kv=False):
+ super(MultiHeadAttentionLayer, self).__init__(name_scope)
+ self._n_head = n_head
+ self._d_key = d_key
+ self._d_value = d_value
+ self._d_model = d_model
+ self._dropout_rate = dropout_rate
+ self._q_fc = FC(name_scope=self.full_name(),
+ size=d_key * n_head,
+ bias_attr=False,
+ num_flatten_dims=2)
+ self._k_fc = FC(name_scope=self.full_name(),
+ size=d_key * n_head,
+ bias_attr=False,
+ num_flatten_dims=2)
+ self._v_fc = FC(name_scope=self.full_name(),
+ size=d_value * n_head,
+ bias_attr=False,
+ num_flatten_dims=2)
+ self._proj_fc = FC(name_scope=self.full_name(),
+ size=self._d_model,
+ bias_attr=False,
+ num_flatten_dims=2)
+
+ def forward(self, queries, keys, values, attn_bias):
+ """
+ forward
+ :param queries:
+ :param keys:
+ :param values:
+ :param attn_bias:
+ :return:
+ """
+        # compute q, k, v
+ keys = queries if keys is None else keys
+ values = keys if values is None else values
+
+ q = self._q_fc(queries)
+ k = self._k_fc(keys)
+ v = self._v_fc(values)
+
+ # split head
+ reshaped_q = fluid.layers.reshape(
+ x=q, shape=[0, 0, self._n_head, self._d_key], inplace=False)
+ transpose_q = fluid.layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
+ reshaped_k = fluid.layers.reshape(
+ x=k, shape=[0, 0, self._n_head, self._d_key], inplace=False)
+ transpose_k = fluid.layers.transpose(x=reshaped_k, perm=[0, 2, 1, 3])
+ reshaped_v = fluid.layers.reshape(
+ x=v, shape=[0, 0, self._n_head, self._d_value], inplace=False)
+ transpose_v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
+
+        # scaled dot-product attention
+ product = fluid.layers.matmul(
+ x=transpose_q,
+ y=transpose_k,
+ transpose_y=True,
+ alpha=self._d_model**-0.5)
+        if attn_bias is not None:
+ product += attn_bias
+ weights = fluid.layers.softmax(product)
+ if self._dropout_rate:
+ weights_droped = fluid.layers.dropout(
+ weights,
+ dropout_prob=self._dropout_rate,
+ seed=ModelHyperParams.dropout_seed,
+ is_test=False)
+ out = fluid.layers.matmul(weights_droped, transpose_v)
+ else:
+ out = fluid.layers.matmul(weights, transpose_v)
+
+ # combine heads
+ if len(out.shape) != 4:
+ raise ValueError("Input(x) should be a 4-D Tensor.")
+ trans_x = fluid.layers.transpose(out, perm=[0, 2, 1, 3])
+ final_out = fluid.layers.reshape(
+ x=trans_x,
+ shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
+ inplace=False)
+
+ # fc to output
+ proj_out = self._proj_fc(final_out)
+ return proj_out
+
+
+class EncoderSubLayer(Layer):
+ """
+ EncoderSubLayer
+ """
+
+ def __init__(self,
+ name_scope,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(EncoderSubLayer, self).__init__(name_scope)
+ self._preprocess_cmd = preprocess_cmd
+ self._postprocess_cmd = postprocess_cmd
+ self._prepostprocess_dropout = prepostprocess_dropout
+
+ self._preprocess_layer = PrePostProcessLayer(self.full_name(),
+ self._preprocess_cmd, 3)
+ self._multihead_attention_layer = MultiHeadAttentionLayer(
+ self.full_name(), d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self._postprocess_layer = PrePostProcessLayer(
+ self.full_name(), self._postprocess_cmd, None)
+ self._preprocess_layer2 = PrePostProcessLayer(self.full_name(),
+ self._preprocess_cmd, 3)
+ self._positionwise_feed_forward = PositionwiseFeedForwardLayer(
+ self.full_name(), d_inner_hid, d_model, relu_dropout)
+ self._postprocess_layer2 = PrePostProcessLayer(
+ self.full_name(), self._postprocess_cmd, None)
+
+ def forward(self, enc_input, attn_bias):
+ """
+ forward
+ :param enc_input:
+ :param attn_bias:
+ :return:
+ """
+ pre_process_multihead = self._preprocess_layer(
+ None, enc_input, self._preprocess_cmd, self._prepostprocess_dropout)
+ attn_output = self._multihead_attention_layer(pre_process_multihead,
+ None, None, attn_bias)
+ attn_output = self._postprocess_layer(enc_input, attn_output,
+ self._postprocess_cmd,
+ self._prepostprocess_dropout)
+ pre_process2_output = self._preprocess_layer2(
+ None, attn_output, self._preprocess_cmd,
+ self._prepostprocess_dropout)
+ ffd_output = self._positionwise_feed_forward(pre_process2_output)
+ return self._postprocess_layer2(attn_output, ffd_output,
+ self._postprocess_cmd,
+ self._prepostprocess_dropout)
+
+
+class EncoderLayer(Layer):
+ """
+ encoder
+ """
+
+ def __init__(self,
+ name_scope,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(EncoderLayer, self).__init__(name_scope)
+ self._preprocess_cmd = preprocess_cmd
+ self._encoder_sublayers = list()
+ self._prepostprocess_dropout = prepostprocess_dropout
+ self._n_layer = n_layer
+ self._preprocess_layer = PrePostProcessLayer(self.full_name(),
+ self._preprocess_cmd, 3)
+ for i in range(n_layer):
+ self._encoder_sublayers.append(
+ self.add_sublayer(
+ 'esl_%d' % i,
+ EncoderSubLayer(
+ self.full_name(), n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd, postprocess_cmd)))
+
+ def forward(self, enc_input, attn_bias):
+ """
+ forward
+ :param enc_input:
+ :param attn_bias:
+ :return:
+ """
+ for i in range(self._n_layer):
+ enc_output = self._encoder_sublayers[i](enc_input, attn_bias)
+ enc_input = enc_output
+
+ return self._preprocess_layer(None, enc_output, self._preprocess_cmd,
+ self._prepostprocess_dropout)
+
+
+class PrepareEncoderDecoderLayer(Layer):
+ """
+ PrepareEncoderDecoderLayer
+ """
+
+ def __init__(self,
+ name_scope,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate,
+ word_emb_param_name=None,
+ pos_enc_param_name=None):
+ super(PrepareEncoderDecoderLayer, self).__init__(name_scope)
+ self._src_max_len = src_max_len
+ self._src_emb_dim = src_emb_dim
+ self._src_vocab_size = src_vocab_size
+ self._dropout_rate = dropout_rate
+ self._input_emb = Embedding(
+ name_scope=self.full_name(),
+ size=[src_vocab_size, src_emb_dim],
+ padding_idx=0,
+ param_attr=fluid.ParamAttr(
+ name=word_emb_param_name,
+ initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
+
+        if pos_enc_param_name == pos_enc_param_names[0]:
+ pos_inp = pos_inp1
+ else:
+ pos_inp = pos_inp2
+ self._pos_emb = Embedding(
+ name_scope=self.full_name(),
+ size=[self._src_max_len, src_emb_dim],
+ param_attr=fluid.ParamAttr(
+ name=pos_enc_param_name,
+ initializer=fluid.initializer.NumpyArrayInitializer(pos_inp),
+ trainable=False))
+
+ # use in dygraph_mode to fit different length batch
+ # self._pos_emb._w = to_variable(
+ # position_encoding_init(self._src_max_len, self._src_emb_dim))
+
+ def forward(self, src_word, src_pos):
+ """
+ forward
+ :param src_word:
+ :param src_pos:
+ :return:
+ """
+ # print("here")
+ # print(self._input_emb._w._numpy().shape)
+ src_word_emb = self._input_emb(src_word)
+
+ src_word_emb = fluid.layers.scale(
+ x=src_word_emb, scale=self._src_emb_dim**0.5)
+        # TODO: change this to fit dynamic-length input
+ src_pos_emb = self._pos_emb(src_pos)
+ src_pos_emb.stop_gradient = True
+ enc_input = src_word_emb + src_pos_emb
+ return fluid.layers.dropout(
+ enc_input,
+ dropout_prob=self._dropout_rate,
+ seed=ModelHyperParams.dropout_seed,
+ is_test=False) if self._dropout_rate else enc_input
+
+
+class WrapEncoderLayer(Layer):
+ """
+ encoderlayer
+ """
+
+    def __init__(self, name_scope, src_vocab_size, max_length, n_layer, n_head,
+ d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd, weight_sharing):
+ """
+ The wrapper assembles together all needed layers for the encoder.
+ """
+        super(WrapEncoderLayer, self).__init__(name_scope)
+
+ self._prepare_encoder_layer = PrepareEncoderDecoderLayer(
+ self.full_name(),
+ src_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ word_emb_param_name=word_emb_param_names[0],
+ pos_enc_param_name=pos_enc_param_names[0])
+ self._encoder = EncoderLayer(
+ self.full_name(), n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd, postprocess_cmd)
+
+ def forward(self, enc_inputs):
+ """forward"""
+ src_word, src_pos, src_slf_attn_bias = enc_inputs
+ enc_input = self._prepare_encoder_layer(src_word, src_pos)
+ enc_output = self._encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class DecoderSubLayer(Layer):
+ """
+ decoder
+ """
+
+ def __init__(self,
+ name_scope,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ cache=None,
+ gather_idx=None):
+ super(DecoderSubLayer, self).__init__(name_scope)
+ self._postprocess_cmd = postprocess_cmd
+ self._preprocess_cmd = preprocess_cmd
+        self._prepostprocess_dropout = prepostprocess_dropout
+ self._pre_process_layer = PrePostProcessLayer(self.full_name(),
+ preprocess_cmd, 3)
+ self._multihead_attention_layer = MultiHeadAttentionLayer(
+ self.full_name(),
+ d_key,
+ d_value,
+ d_model,
+ n_head,
+ attention_dropout,
+ cache=cache,
+ gather_idx=gather_idx)
+ self._post_process_layer = PrePostProcessLayer(self.full_name(),
+ postprocess_cmd, None)
+ self._pre_process_layer2 = PrePostProcessLayer(self.full_name(),
+ preprocess_cmd, 3)
+ self._multihead_attention_layer2 = MultiHeadAttentionLayer(
+ self.full_name(),
+ d_key,
+ d_value,
+ d_model,
+ n_head,
+ attention_dropout,
+ cache=cache,
+ gather_idx=gather_idx,
+ static_kv=True)
+ self._post_process_layer2 = PrePostProcessLayer(self.full_name(),
+ postprocess_cmd, None)
+ self._pre_process_layer3 = PrePostProcessLayer(self.full_name(),
+ preprocess_cmd, 3)
+ self._positionwise_feed_forward_layer = PositionwiseFeedForwardLayer(
+ self.full_name(), d_inner_hid, d_model, relu_dropout)
+ self._post_process_layer3 = PrePostProcessLayer(self.full_name(),
+ postprocess_cmd, None)
+
+ def forward(self, dec_input, enc_output, slf_attn_bias, dec_enc_attn_bias):
+ """
+ forward
+ :param dec_input:
+ :param enc_output:
+ :param slf_attn_bias:
+ :param dec_enc_attn_bias:
+ :return:
+ """
+        pre_process_rlt = self._pre_process_layer(
+            None, dec_input, self._preprocess_cmd, self._prepostprocess_dropout)
+        slf_attn_output = self._multihead_attention_layer(
+            pre_process_rlt, None, None, slf_attn_bias)
+        slf_attn_output_pp = self._post_process_layer(
+            dec_input, slf_attn_output, self._postprocess_cmd,
+            self._prepostprocess_dropout)
+        pre_process_rlt2 = self._pre_process_layer2(
+            None, slf_attn_output_pp, self._preprocess_cmd,
+            self._prepostprocess_dropout)
+        enc_attn_output_pp = self._multihead_attention_layer2(
+            pre_process_rlt2, enc_output, enc_output, dec_enc_attn_bias)
+        enc_attn_output = self._post_process_layer2(
+            slf_attn_output_pp, enc_attn_output_pp, self._postprocess_cmd,
+            self._prepostprocess_dropout)
+        pre_process_rlt3 = self._pre_process_layer3(
+            None, enc_attn_output, self._preprocess_cmd,
+            self._prepostprocess_dropout)
+        ffd_output = self._positionwise_feed_forward_layer(pre_process_rlt3)
+        dec_output = self._post_process_layer3(
+            enc_attn_output, ffd_output, self._postprocess_cmd,
+            self._prepostprocess_dropout)
+ return dec_output
+
+
+class DecoderLayer(Layer):
+ """
+ decoder
+ """
+
+ def __init__(self,
+ name_scope,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ caches=None,
+ gather_idx=None):
+ super(DecoderLayer, self).__init__(name_scope)
+ self._pre_process_layer = PrePostProcessLayer(self.full_name(),
+ preprocess_cmd, 3)
+ self._decoder_sub_layers = list()
+ self._n_layer = n_layer
+ self._preprocess_cmd = preprocess_cmd
+ self._prepostprocess_dropout = prepostprocess_dropout
+ for i in range(n_layer):
+ self._decoder_sub_layers.append(
+ self.add_sublayer(
+ 'dsl_%d' % i,
+ DecoderSubLayer(
+ self.full_name(),
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ cache=None if caches is None else caches[i],
+ gather_idx=gather_idx)))
+
+ def forward(self, dec_input, enc_output, dec_slf_attn_bias,
+ dec_enc_attn_bias):
+ """
+ forward
+ :param dec_input:
+ :param enc_output:
+ :param dec_slf_attn_bias:
+ :param dec_enc_attn_bias:
+ :return:
+ """
+ for i in range(self._n_layer):
+ tmp_dec_output = self._decoder_sub_layers[i](
+ dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias)
+ dec_input = tmp_dec_output
+
+ dec_output = self._pre_process_layer(None, tmp_dec_output,
+ self._preprocess_cmd,
+ self._prepostprocess_dropout)
+ return dec_output
+
+
+class WrapDecoderLayer(Layer):
+ """
+ decoder
+ """
+
+ def __init__(self,
+ name_scope,
+ trg_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ caches=None,
+ gather_idx=None):
+ """
+        The wrapper assembles together all needed layers for the decoder.
+ """
+ super(WrapDecoderLayer, self).__init__(name_scope)
+
+ self._prepare_decoder_layer = PrepareEncoderDecoderLayer(
+ self.full_name(),
+ trg_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ word_emb_param_name=word_emb_param_names[1],
+ pos_enc_param_name=pos_enc_param_names[1])
+ self._decoder_layer = DecoderLayer(
+ self.full_name(),
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ caches=caches,
+ gather_idx=gather_idx)
+ self._weight_sharing = weight_sharing
+ if not weight_sharing:
+ self._fc = FC(self.full_name(),
+ size=trg_vocab_size,
+ bias_attr=False)
+
+ def forward(self, dec_inputs=None, enc_output=None):
+ """
+ forward
+ :param dec_inputs:
+ :param enc_output:
+ :return:
+ """
+ trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
+ dec_input = self._prepare_decoder_layer(trg_word, trg_pos)
+ dec_output = self._decoder_layer(dec_input, enc_output,
+ trg_slf_attn_bias, trg_src_attn_bias)
+
+ dec_output_reshape = fluid.layers.reshape(
+ dec_output, shape=[-1, dec_output.shape[-1]], inplace=False)
+
+ if self._weight_sharing:
+ predict = fluid.layers.matmul(
+ x=dec_output_reshape,
+ y=self._prepare_decoder_layer._input_emb._w,
+ transpose_y=True)
+ else:
+ predict = self._fc(dec_output_reshape)
+
+ if dec_inputs is None:
+ # Return probs for independent decoder program.
+ predict_out = fluid.layers.softmax(predict)
+ return predict_out
+ return predict
+
+
+class TransFormer(Layer):
+ """
+ model
+ """
+
+ def __init__(self, name_scope, src_vocab_size, trg_vocab_size, max_length,
+ n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout, relu_dropout,
+ preprocess_cmd, postprocess_cmd, weight_sharing,
+ label_smooth_eps):
+ super(TransFormer, self).__init__(name_scope)
+ self._label_smooth_eps = label_smooth_eps
+ self._trg_vocab_size = trg_vocab_size
+ if weight_sharing:
+ assert src_vocab_size == trg_vocab_size, (
+ "Vocabularies in source and target should be same for weight sharing."
+ )
+ self._wrap_encoder_layer = WrapEncoderLayer(
+ self.full_name(), src_vocab_size, max_length, n_layer, n_head,
+ d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
+ weight_sharing)
+ self._wrap_decoder_layer = WrapDecoderLayer(
+ self.full_name(), trg_vocab_size, max_length, n_layer, n_head,
+ d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
+ weight_sharing)
+
+ if weight_sharing:
+ self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w
+
+ def forward(self, enc_inputs, dec_inputs, label, weights):
+ """
+ forward
+ :param enc_inputs:
+ :param dec_inputs:
+ :param label:
+ :param weights:
+ :return:
+ """
+ enc_output = self._wrap_encoder_layer(enc_inputs)
+ predict = self._wrap_decoder_layer(dec_inputs, enc_output)
+ if self._label_smooth_eps:
+ label_out = fluid.layers.label_smooth(
+ label=fluid.layers.one_hot(
+ input=label, depth=self._trg_vocab_size),
+ epsilon=self._label_smooth_eps)
+
+ cost = fluid.layers.softmax_with_cross_entropy(
+ logits=predict,
+ label=label_out,
+ soft_label=True if self._label_smooth_eps else False)
+ weighted_cost = cost * weights
+ sum_cost = fluid.layers.reduce_sum(weighted_cost)
+ token_num = fluid.layers.reduce_sum(weights)
+ token_num.stop_gradient = True
+ avg_cost = sum_cost / token_num
+ return sum_cost, avg_cost, predict, token_num
+
+
+def train():
+ """
+ train models
+ :return:
+ """
+
+ with guard():
+ transformer = TransFormer(
+ 'transformer', ModelHyperParams.src_vocab_size,
+ ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
+ ModelHyperParams.n_layer, ModelHyperParams.n_head,
+ ModelHyperParams.d_key, ModelHyperParams.d_value,
+ ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
+ ModelHyperParams.prepostprocess_dropout,
+ ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
+ ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
+ ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
+
+ optimizer = fluid.optimizer.SGD(learning_rate=0.003)
+
+ reader = paddle.batch(
+ wmt16.train(ModelHyperParams.src_vocab_size,
+ ModelHyperParams.trg_vocab_size),
+ batch_size=TrainTaskConfig.batch_size)
+ for i in range(200):
+ dy_step = 0
+ for batch in reader():
+ np_values = prepare_batch_input(
+ batch, ModelHyperParams.src_pad_idx,
+ ModelHyperParams.trg_pad_idx, ModelHyperParams.n_head)
+
+ enc_inputs, dec_inputs, label, weights = create_data(np_values)
+ dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
+ enc_inputs, dec_inputs, label, weights)
+ dy_avg_cost.backward()
+ optimizer.minimize(dy_avg_cost)
+ transformer.clear_gradients()
+ dy_step = dy_step + 1
+ if dy_step % 10 == 0:
+ print("pass num : {}, batch_id: {}, dy_graph avg loss: {}".
+ format(i, dy_step, dy_avg_cost.numpy()))
+ print("pass : {} finished".format(i))
+
+
+if __name__ == '__main__':
+ train()