Unverified commit 2bbe37ad authored by Guo Sheng, committed by GitHub

Add beam search for dygraph Transformer. (#3555)

* Add beam search for dygraph Transformer.
Re-organize dygraph Transformer.

* Add customized data support for dygraph Transformer.

* Add validation in dygraph Transformer.

* Update notes for multi-gpu for dygraph Transformer.
Parent 627e60c1
@@ -8,19 +8,19 @@
### Dataset Description
We use the dataset from the [translation task](http://www.statmt.org/wmt16/multimodal-task.html#task1) of the [multimodal task](http://www.statmt.org/wmt16/multimodal-task.html) newly added in [WMT-16](http://www.statmt.org/wmt16/) as an example. It is an English-German translation dataset containing 29001 training pairs and 1000 test pairs.
The dataset is built into Paddle and can be used through `paddle.dataset.wmt16`; when the training code in this project runs, the dataset is downloaded automatically to the `~/.cache/paddle/dataset/wmt16/` directory.
A pre-downloaded wmt16 dataset can also be placed in the `~/.cache/paddle/dataset/wmt16/` directory.
### Installation
1. Install PaddlePaddle
   This project requires PaddlePaddle Fluid 1.6.0 or later (1.6.0 is pending official release; the develop branch can be used until then). Please refer to the [installation guide](http://www.paddlepaddle.org/#quick-start).
2. Environment dependencies
   Multi-GPU runs require NCCL 2.4.7.
### Run Training
To train on a single GPU, start training with:
@@ -28,11 +28,19 @@
env CUDA_VISIBLE_DEVICES=0 python train.py
```
Here `CUDA_VISIBLE_DEVICES=0` means the job runs on GPU card 0; change it to match your setup. To adjust other model and training parameters, modify them in `config.py` or pass them in as follows:
```sh
python train.py \
n_head 16 \
d_model 1024 \
d_inner_hid 4096 \
prepostprocess_dropout 0.3
```
Paddle dygraph supports multi-process multi-GPU training; start it with:
```
python -m paddle.distributed.launch --started_port 9999 --selected_gpus=0,1,2,3 --log_dir ./mylog train.py --use_data_parallel 1
```
The program will then write each process's output log to the `./mylog` directory:
```
@@ -64,8 +72,105 @@ python -m paddle.distributed.launch --selected_gpus=0,1,2,3 --log_dir ./mylog tr
pass num : 0, batch_id: 110, dy_graph avg loss: [7.4179897]
pass num : 0, batch_id: 120, dy_graph avg loss: [7.318419]
### Run Prediction
After training completes, run prediction with the following command:
```
env CUDA_VISIBLE_DEVICES=0 python predict.py
```
The prediction results are written to the `predict.txt` file (changeable at runtime via `--output_file`); other model and prediction parameters can likewise be modified in `config.py` or passed in as follows:
```sh
python predict.py \
n_head 16 \
d_model 1024 \
d_inner_hid 4096 \
prepostprocess_dropout 0.3
```
After prediction completes, the BLEU score can be evaluated with third-party tools as follows:
```sh
# Extract the reference data
tar -zxvf ~/.cache/paddle/dataset/wmt16/wmt16.tar.gz
awk 'BEGIN {FS="\t"}; {print $2}' wmt16/test > ref.de
# Clone the mosesdecoder repository
git clone https://github.com/moses-smt/mosesdecoder.git
# Run the evaluation
perl mosesdecoder/scripts/generic/multi-bleu.perl ref.de < predict.txt
```
With the default configuration, a model trained for 20 epochs on a single GPU gives roughly the following evaluation result:
```
BLEU = 32.38, 64.3/39.1/25.9/16.9 (BP=1.000, ratio=1.001, hyp_len=12122, ref_len=12104)
```
## Advanced Usage
### Custom Data
- Training:
Modify the following code snippet in `train.py`:
```python
reader = paddle.batch(wmt16.train(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=TrainTaskConfig.batch_size)
```
Replace `wmt16.train` in it with a Python generator similar to the following:
```python
def reader(file_name, src_dict, trg_dict):
start_id = src_dict[START_MARK] # BOS
end_id = src_dict[END_MARK] # EOS
unk_id = src_dict[UNK_MARK] # UNK
src_col, trg_col = 0, 1
for line in open(file_name, "r"):
line_split = line.strip().split("\t")
if len(line_split) != 2:
continue
src_words = line_split[src_col].split()
src_ids = [start_id] + [
src_dict.get(w, unk_id) for w in src_words
] + [end_id]
trg_words = line_split[trg_col].split()
trg_ids = [trg_dict.get(w, unk_id) for w in trg_words]
trg_ids_next = trg_ids + [end_id]
trg_ids = [start_id] + trg_ids
yield src_ids, trg_ids, trg_ids_next
```
The data produced by this generator is a single sample: a tuple of three integer lists, the source sentence (src_ids), the target sentence (trg_ids), and the label (trg_ids_next); src_ids includes the BOS and EOS ids, trg_ids includes the BOS id, and trg_ids_next includes the EOS id. A minimal sketch of wiring such a generator into the batched reader is shown below.
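For illustration only, here is a minimal sketch of plugging such a generator into `paddle.batch`; the data file name, special-token names, and toy vocabularies below are hypothetical:
```python
import functools
import paddle

# Hypothetical special-token names referenced by the generator above.
START_MARK, END_MARK, UNK_MARK = "<s>", "<e>", "<unk>"
# Hypothetical toy vocabularies; in practice load them from vocabulary files.
src_dict = {"<s>": 0, "<e>": 1, "<unk>": 2, "hello": 3}
trg_dict = {"<s>": 0, "<e>": 1, "<unk>": 2, "hallo": 3}

# paddle.batch accepts any callable that returns an iterable of samples, so
# the custom generator can be bound to its arguments with functools.partial.
batched_reader = paddle.batch(
    functools.partial(reader, "train.tsv", src_dict, trg_dict),
    batch_size=32)  # or TrainTaskConfig.batch_size
```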
- Prediction:
Modify the following code snippet in `predict.py`:
```python
reader = paddle.batch(wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
id2word = wmt16.get_dict("de",
ModelHyperParams.trg_vocab_size,
reverse=True)
```
Replace `wmt16.test` in it with a Python generator similar to the one used for training; in addition, provide a Python dict mapping ids to words as `id2word`, e.g. built as sketched below.
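A minimal sketch of building such an `id2word` dict, assuming a plain-text vocabulary file with one token per line (the file name is hypothetical):
```python
def load_id2word(vocab_file):
    # Map each line number (the token id) to the token on that line.
    with open(vocab_file, "r") as f:
        return {i: line.strip() for i, line in enumerate(f)}

id2word = load_id2word("vocab.de")  # hypothetical vocabulary file
```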
### Model Overview
Transformer is a novel network architecture proposed in the paper [Attention Is All You Need](https://arxiv.org/abs/1706.03762) for sequence-to-sequence (Seq2Seq) learning tasks such as machine translation (MT). It also adopts the Encoder-Decoder framework typical of Seq2Seq tasks, but unlike the previously prevalent Recurrent Neural Network (RNN), it relies entirely on the attention mechanism for sequence-to-sequence modeling; the overall architecture is shown in Figure 1.
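As a minimal illustration of the core computation (not this project's implementation, which follows in `model.py` below), scaled dot-product attention can be sketched in NumPy:
```python
import numpy as np

def scaled_dot_product_attention(q, k, v):
    # q, k, v: [batch, seq_len, d]; computes softmax(q @ k^T / sqrt(d)) @ v.
    d = q.shape[-1]
    scores = np.matmul(q, k.transpose(0, 2, 1)) / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v)
```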
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
"""
TrainTaskConfig
"""
# the epoch number to train.
pass_num = 20
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied with the learning rate derived
# by the LearningRateScheduler to get the final learning rate.
learning_rate = 2.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
class InferTaskConfig(object):
# the number of examples in one run for sequence generation.
batch_size = 4
# the parameters for beam search.
beam_size = 4
alpha = 0.6
# max decoded length, should be less than ModelHyperParams.max_length
max_out_len = 30
class ModelHyperParams(object):
"""
ModelHyperParams
"""
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionary.
trg_vocab_size = 10000
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences deciding the size of position encoding table.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = False
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias":
[(batch_size, ModelHyperParams.n_head, seq_len, seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of padding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states update)
"init_idx": [(batch_size, ), "int32"],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table",
)
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table",
)
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias",
)
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output",
)
label_data_input_fields = (
"lbl_word",
"lbl_weight",
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
# "init_score",
# "init_idx",
"trg_src_attn_bias",
)
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
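# Example (illustrative): the trailing command line arguments from
# `python train.py n_head 16 d_model 1024` arrive here as
# cfg_list = ["n_head", "16", "d_model", "1024"], and
# merge_cfg_from_list(cfg_list, [TrainTaskConfig, ModelHyperParams])
# sets ModelHyperParams.n_head = 16 and ModelHyperParams.d_model = 1024.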
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
from paddle.fluid.dygraph import to_variable
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and has no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst))
for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
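# Example (illustrative): for insts = [[0, 5, 7, 1], [0, 9, 1]] with pad_idx=0,
# n_head=8 and is_target=False, pad_batch_data returns word data of shape
# (8, 1) (2 instances * max_len 4, flattened), position data of shape (8, 1),
# a self-attention bias of shape (2, 8, 4, 4) holding -1e9 at padding
# positions, and max_len = 4.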
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph import Embedding, LayerNorm, FC, to_variable, Layer, guard
from paddle.fluid.dygraph.learning_rate_scheduler import LearningRateDecay
from config import word_emb_param_names, pos_enc_param_names
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(
np.arange(num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(
inv_timescales, 0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
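# The table above follows the sinusoid encoding of "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000**(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000**(2i / d_model))
# except that the sin and cos halves are concatenated along the channel axis
# rather than interleaved, which is an equivalent reordering of the channels.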
class NoamDecay(LearningRateDecay):
"""
Noam learning rate scheduler:
lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5) * static_lr
"""
def __init__(self,
d_model,
warmup_steps,
static_lr=2.0,
begin=1,
step=1,
dtype='float32'):
super(NoamDecay, self).__init__(begin, step, dtype)
self.d_model = d_model
self.warmup_steps = warmup_steps
self.static_lr = static_lr
def step(self):
a = self.create_lr_var(self.step_num**-0.5)
b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
lr_value = (self.d_model**-0.5) * layers.elementwise_min(
a, b) * self.static_lr
return lr_value
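# Example (illustrative): with d_model=512, warmup_steps=8000 and static_lr=2.0,
# the base rate rises linearly to 512**-0.5 * 8000**-0.5 ~= 4.9e-4 at step 8000
# and then decays in proportion to step**-0.5, before scaling by static_lr.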
class PrePostProcessLayer(Layer):
"""
PrePostProcessLayer
"""
def __init__(self, name_scope, process_cmd, shape_len=None):
super(PrePostProcessLayer, self).__init__(name_scope)
for cmd in process_cmd:
if cmd == "n":
self._layer_norm = LayerNorm(
name_scope=self.full_name(),
begin_norm_axis=shape_len - 1,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
def forward(self, prev_out, out, process_cmd, dropout_rate=0.):
"""
forward
:param prev_out:
:param out:
:param process_cmd:
:param dropout_rate:
:return:
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = self._layer_norm(out)
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(out,
dropout_prob=dropout_rate,
is_test=False)
return out
class PositionwiseFeedForwardLayer(Layer):
"""
PositionwiseFeedForwardLayer
"""
def __init__(self, name_scope, d_inner_hid, d_hid, dropout_rate):
super(PositionwiseFeedForwardLayer, self).__init__(name_scope)
self._i2h = FC(name_scope=self.full_name(),
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
self._h2o = FC(name_scope=self.full_name(),
size=d_hid,
num_flatten_dims=2)
self._dropout_rate = dropout_rate
def forward(self, x):
"""
forward
:param x:
:return:
"""
hidden = self._i2h(x)
if self._dropout_rate:
hidden = layers.dropout(hidden,
dropout_prob=self._dropout_rate,
is_test=False)
out = self._h2o(hidden)
return out
class MultiHeadAttentionLayer(Layer):
"""
MultiHeadAttentionLayer
"""
def __init__(self,
name_scope,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
gather_idx=None,
static_kv=False):
super(MultiHeadAttentionLayer, self).__init__(name_scope)
self._n_head = n_head
self._d_key = d_key
self._d_value = d_value
self._d_model = d_model
self._dropout_rate = dropout_rate
self._q_fc = FC(name_scope=self.full_name(),
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
self._k_fc = FC(name_scope=self.full_name(),
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
self._v_fc = FC(name_scope=self.full_name(),
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
self._proj_fc = FC(name_scope=self.full_name(),
size=self._d_model,
bias_attr=False,
num_flatten_dims=2)
def forward(self,
queries,
keys,
values,
attn_bias,
cache=None,
gather_idx=None):
"""
forward
:param queries:
:param keys:
:param values:
:param attn_bias:
:return:
"""
# compute q ,k ,v
keys = queries if keys is None else keys
values = keys if values is None else values
q = self._q_fc(queries)
k = self._k_fc(keys)
v = self._v_fc(values)
# split head
reshaped_q = layers.reshape(x=q,
shape=[0, 0, self._n_head, self._d_key],
inplace=False)
transpose_q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
reshaped_k = layers.reshape(x=k,
shape=[0, 0, self._n_head, self._d_key],
inplace=False)
transpose_k = layers.transpose(x=reshaped_k, perm=[0, 2, 1, 3])
reshaped_v = layers.reshape(x=v,
shape=[0, 0, self._n_head, self._d_value],
inplace=False)
transpose_v = layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
if cache is not None:
cache_k, cache_v = cache["k"], cache["v"]
transpose_k = layers.concat([cache_k, transpose_k], axis=2)
transpose_v = layers.concat([cache_v, transpose_v], axis=2)
cache["k"], cache["v"] = transpose_k, transpose_v
# scale dot product attention
product = layers.matmul(x=transpose_q,
y=transpose_k,
transpose_y=True,
alpha=self._d_model**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if self._dropout_rate:
weights_droped = layers.dropout(weights,
dropout_prob=self._dropout_rate,
is_test=False)
out = layers.matmul(weights_droped, transpose_v)
else:
out = layers.matmul(weights, transpose_v)
# combine heads
if len(out.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(out, perm=[0, 2, 1, 3])
final_out = layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=False)
# fc to output
proj_out = self._proj_fc(final_out)
return proj_out
class EncoderSubLayer(Layer):
"""
EncoderSubLayer
"""
def __init__(self,
name_scope,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderSubLayer, self).__init__(name_scope)
self._preprocess_cmd = preprocess_cmd
self._postprocess_cmd = postprocess_cmd
self._prepostprocess_dropout = prepostprocess_dropout
self._preprocess_layer = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
self._multihead_attention_layer = MultiHeadAttentionLayer(
self.full_name(), d_key, d_value, d_model, n_head,
attention_dropout)
self._postprocess_layer = PrePostProcessLayer(self.full_name(),
self._postprocess_cmd,
None)
self._preprocess_layer2 = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
self._positionwise_feed_forward = PositionwiseFeedForwardLayer(
self.full_name(), d_inner_hid, d_model, relu_dropout)
self._postprocess_layer2 = PrePostProcessLayer(self.full_name(),
self._postprocess_cmd,
None)
def forward(self, enc_input, attn_bias):
"""
forward
:param enc_input:
:param attn_bias:
:return:
"""
pre_process_multihead = self._preprocess_layer(
None, enc_input, self._preprocess_cmd, self._prepostprocess_dropout)
attn_output = self._multihead_attention_layer(pre_process_multihead,
None, None, attn_bias)
attn_output = self._postprocess_layer(enc_input, attn_output,
self._postprocess_cmd,
self._prepostprocess_dropout)
pre_process2_output = self._preprocess_layer2(
None, attn_output, self._preprocess_cmd,
self._prepostprocess_dropout)
ffd_output = self._positionwise_feed_forward(pre_process2_output)
return self._postprocess_layer2(attn_output, ffd_output,
self._postprocess_cmd,
self._prepostprocess_dropout)
class EncoderLayer(Layer):
"""
encoder
"""
def __init__(self,
name_scope,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderLayer, self).__init__(name_scope)
self._preprocess_cmd = preprocess_cmd
self._encoder_sublayers = list()
self._prepostprocess_dropout = prepostprocess_dropout
self._n_layer = n_layer
self._preprocess_layer = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
for i in range(n_layer):
self._encoder_sublayers.append(
self.add_sublayer(
'esl_%d' % i,
EncoderSubLayer(self.full_name(), n_head, d_key, d_value,
d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)))
def forward(self, enc_input, attn_bias):
"""
forward
:param enc_input:
:param attn_bias:
:return:
"""
for i in range(self._n_layer):
enc_output = self._encoder_sublayers[i](enc_input, attn_bias)
enc_input = enc_output
return self._preprocess_layer(None, enc_output, self._preprocess_cmd,
self._prepostprocess_dropout)
class PrepareEncoderDecoderLayer(Layer):
"""
PrepareEncoderDecoderLayer
"""
def __init__(self,
name_scope,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate,
word_emb_param_name=None,
pos_enc_param_name=None):
super(PrepareEncoderDecoderLayer, self).__init__(name_scope)
self._src_max_len = src_max_len
self._src_emb_dim = src_emb_dim
self._src_vocab_size = src_vocab_size
self._dropout_rate = dropout_rate
self._input_emb = Embedding(name_scope=self.full_name(),
size=[src_vocab_size, src_emb_dim],
padding_idx=0,
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(
0., src_emb_dim**-0.5)))
pos_inp = position_encoding_init(src_max_len, src_emb_dim)
self._pos_emb = Embedding(
name_scope=self.full_name(),
size=[self._src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name,
initializer=fluid.initializer.NumpyArrayInitializer(pos_inp),
trainable=False))
# use in dygraph_mode to fit different length batch
# self._pos_emb._w = to_variable(
# position_encoding_init(self._src_max_len, self._src_emb_dim))
def forward(self, src_word, src_pos):
"""
forward
:param src_word:
:param src_pos:
:return:
"""
# print("here")
# print(self._input_emb._w._numpy().shape)
src_word_emb = self._input_emb(src_word)
src_word_emb = layers.scale(x=src_word_emb,
scale=self._src_emb_dim**0.5)
# # TODO change this to fit dynamic length input
src_pos_emb = self._pos_emb(src_pos)
src_pos_emb.stop_gradient = True
enc_input = src_word_emb + src_pos_emb
return layers.dropout(
enc_input, dropout_prob=self._dropout_rate,
is_test=False) if self._dropout_rate else enc_input
class WrapEncoderLayer(Layer):
"""
encoderlayer
"""
def __init__(self, name_scope, src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing):
"""
The wrapper assembles together all needed layers for the encoder.
"""
super(WrapEncoderLayer, self).__init__(name_scope)
self._prepare_encoder_layer = PrepareEncoderDecoderLayer(
self.full_name(),
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[0],
pos_enc_param_name=pos_enc_param_names[0])
self._encoder = EncoderLayer(self.full_name(), n_layer, n_head, d_key,
d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)
def forward(self, enc_inputs):
"""forward"""
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = self._prepare_encoder_layer(src_word, src_pos)
enc_output = self._encoder(enc_input, src_slf_attn_bias)
return enc_output
class DecoderSubLayer(Layer):
"""
decoder
"""
def __init__(self, name_scope, n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd):
super(DecoderSubLayer, self).__init__(name_scope)
self._postprocess_cmd = postprocess_cmd
self._preprocess_cmd = preprocess_cmd
self._prepostprcess_dropout = prepostprocess_dropout
self._pre_process_layer = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._multihead_attention_layer = MultiHeadAttentionLayer(
self.full_name(), d_key, d_value, d_model, n_head,
attention_dropout)
self._post_process_layer = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
self._pre_process_layer2 = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._multihead_attention_layer2 = MultiHeadAttentionLayer(
self.full_name(), d_key, d_value, d_model, n_head,
attention_dropout)
self._post_process_layer2 = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
self._pre_process_layer3 = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._positionwise_feed_forward_layer = PositionwiseFeedForwardLayer(
self.full_name(), d_inner_hid, d_model, relu_dropout)
self._post_process_layer3 = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
def forward(self,
dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
cache=None,
gather_idx=None):
"""
forward
:param dec_input:
:param enc_output:
:param slf_attn_bias:
:param dec_enc_attn_bias:
:return:
"""
pre_process_rlt = self._pre_process_layer(None, dec_input,
self._preprocess_cmd,
self._prepostprcess_dropout)
slf_attn_output = self._multihead_attention_layer(
pre_process_rlt, None, None, slf_attn_bias, cache, gather_idx)
slf_attn_output_pp = self._post_process_layer(
dec_input, slf_attn_output, self._postprocess_cmd,
self._prepostprcess_dropout)
pre_process_rlt2 = self._pre_process_layer2(None, slf_attn_output_pp,
self._preprocess_cmd,
self._prepostprcess_dropout)
enc_attn_output_pp = self._multihead_attention_layer2(
pre_process_rlt2, enc_output, enc_output, dec_enc_attn_bias)
enc_attn_output = self._post_process_layer2(slf_attn_output_pp,
enc_attn_output_pp,
self._postprocess_cmd,
self._prepostprcess_dropout)
pre_process_rlt3 = self._pre_process_layer3(None, enc_attn_output,
self._preprocess_cmd,
self._prepostprcess_dropout)
ffd_output = self._positionwise_feed_forward_layer(pre_process_rlt3)
dec_output = self._post_process_layer3(enc_attn_output, ffd_output,
self._postprocess_cmd,
self._prepostprcess_dropout)
return dec_output
class DecoderLayer(Layer):
"""
decoder
"""
def __init__(self, name_scope, n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd):
super(DecoderLayer, self).__init__(name_scope)
self._pre_process_layer = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._decoder_sub_layers = list()
self._n_layer = n_layer
self._preprocess_cmd = preprocess_cmd
self._prepostprocess_dropout = prepostprocess_dropout
for i in range(n_layer):
self._decoder_sub_layers.append(
self.add_sublayer(
'dsl_%d' % i,
DecoderSubLayer(self.full_name(), n_head, d_key, d_value,
d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd,
postprocess_cmd)))
def forward(self,
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
caches=None,
gather_idx=None):
"""
forward
:param dec_input:
:param enc_output:
:param dec_slf_attn_bias:
:param dec_enc_attn_bias:
:return:
"""
for i in range(self._n_layer):
tmp_dec_output = self._decoder_sub_layers[i](
dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias,
None if caches is None else caches[i], gather_idx)
dec_input = tmp_dec_output
dec_output = self._pre_process_layer(None, tmp_dec_output,
self._preprocess_cmd,
self._prepostprocess_dropout)
return dec_output
class WrapDecoderLayer(Layer):
"""
decoder
"""
def __init__(self,
name_scope,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
gather_idx=None):
"""
The wrapper assembles together all needed layers for the decoder.
"""
super(WrapDecoderLayer, self).__init__(name_scope)
self._prepare_decoder_layer = PrepareEncoderDecoderLayer(
self.full_name(),
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[1],
pos_enc_param_name=pos_enc_param_names[1])
self._decoder_layer = DecoderLayer(self.full_name(), n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout,
attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd)
self._weight_sharing = weight_sharing
if not weight_sharing:
self._fc = FC(self.full_name(),
size=trg_vocab_size,
bias_attr=False)
def forward(self, dec_inputs, enc_output, caches=None, gather_idx=None):
"""
forward
:param dec_inputs:
:param enc_output:
:return:
"""
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = self._prepare_decoder_layer(trg_word, trg_pos)
dec_output = self._decoder_layer(dec_input, enc_output,
trg_slf_attn_bias, trg_src_attn_bias,
caches, gather_idx)
dec_output_reshape = layers.reshape(dec_output,
shape=[-1, dec_output.shape[-1]],
inplace=False)
if self._weight_sharing:
predict = layers.matmul(x=dec_output_reshape,
y=self._prepare_decoder_layer._input_emb._w,
transpose_y=True)
else:
predict = self._fc(dec_output_reshape)
if dec_inputs is None:
# Return probs for independent decoder program.
predict_out = layers.softmax(predict)
return predict_out
return predict
class TransFormer(Layer):
"""
model
"""
def __init__(self,
name_scope,
src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
label_smooth_eps=0.0):
super(TransFormer, self).__init__(name_scope)
self._label_smooth_eps = label_smooth_eps
self._trg_vocab_size = trg_vocab_size
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
self._wrap_encoder_layer = WrapEncoderLayer(
self.full_name(), src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing)
self._wrap_decoder_layer = WrapDecoderLayer(
self.full_name(), trg_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing)
if weight_sharing:
self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w
self.n_layer = n_layer
self.n_head = n_head
self.d_key = d_key
self.d_value = d_value
def forward(self, enc_inputs, dec_inputs, label, weights):
"""
forward
:param enc_inputs:
:param dec_inputs:
:param label:
:param weights:
:return:
"""
enc_output = self._wrap_encoder_layer(enc_inputs)
predict = self._wrap_decoder_layer(dec_inputs, enc_output)
if self._label_smooth_eps:
label_out = layers.label_smooth(label=layers.one_hot(
input=label, depth=self._trg_vocab_size),
epsilon=self._label_smooth_eps)
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label_out,
soft_label=True if self._label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
token_num = layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict, token_num
def beam_search(self,
enc_inputs,
dec_inputs,
bos_id=0,
eos_id=1,
beam_size=4,
max_len=30,
alpha=0.6):
"""
Beam search with two queues, the alive and the finished, each with a capacity
of beam_size. Each decoding step runs `grow_topk`, `grow_alive` and
`grow_finish`:
1. `grow_topk` selects the top `2*beam_size` candidates to avoid all getting
EOS.
2. `grow_alive` selects the top `beam_size` non-EOS candidates as the inputs
of next decoding step.
3. `grow_finish` compares the already finished candidates in the finished queue
and newly added finished candidates from `grow_topk`, and selects the top
`beam_size` finished candidates.
"""
def expand_to_beam_size(tensor, beam_size):
tensor = layers.reshape(tensor,
[tensor.shape[0], 1] + tensor.shape[1:])
tile_dims = [1] * len(tensor.shape)
tile_dims[1] = beam_size
return layers.expand(tensor, tile_dims)
def merge_beam_dim(tensor):
return layers.reshape(tensor, [-1] + tensor.shape[2:])
# run encoder
enc_output = self._wrap_encoder_layer(enc_inputs)
# constant number
inf = float(1. * 1e7)
batch_size = enc_output.shape[0]
### initialize states of beam search ###
## init for the alive ##
initial_ids, trg_src_attn_bias = dec_inputs # (batch_size, 1)
initial_log_probs = to_variable(
np.array([[0.] + [-inf] * (beam_size - 1)], dtype="float32"))
alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
alive_seq = to_variable(
np.tile(np.array([[[bos_id]]], dtype="int64"),
(batch_size, beam_size, 1)))
## init for the finished ##
finished_scores = to_variable(
np.array([[-inf] * beam_size], dtype="float32"))
finished_scores = layers.expand(finished_scores, [batch_size, 1])
finished_seq = to_variable(
np.tile(np.array([[[bos_id]]], dtype="int64"),
(batch_size, beam_size, 1)))
finished_flags = layers.zeros_like(finished_scores)
### initialize inputs and states of transformer decoder ###
## init inputs for decoder, shaped `[batch_size*beam_size, ...]`
trg_word = layers.reshape(alive_seq[:, :, -1],
[batch_size * beam_size, 1, 1])
trg_pos = layers.zeros_like(trg_word)
trg_src_attn_bias = merge_beam_dim(
expand_to_beam_size(trg_src_attn_bias, beam_size))
enc_output = merge_beam_dim(expand_to_beam_size(enc_output, beam_size))
## init states (caches) for transformer, need to be updated according to selected beam
caches = [{
"k":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_key],
dtype=enc_output.dtype,
value=0),
"v":
layers.fill_constant(
shape=[batch_size * beam_size, self.n_head, 0, self.d_value],
dtype=enc_output.dtype,
value=0),
} for i in range(self.n_layer)]
def update_states(caches, beam_idx, beam_size):
for cache in caches:
cache["k"] = gather_2d_by_gather(cache["k"], beam_idx,
beam_size, batch_size, False)
cache["v"] = gather_2d_by_gather(cache["v"], beam_idx,
beam_size, batch_size, False)
return caches
def gather_2d_by_gather(tensor_nd,
beam_idx,
beam_size,
batch_size,
need_flat=True):
batch_idx = layers.range(0, batch_size, 1,
dtype="int64") * beam_size
flat_tensor = merge_beam_dim(tensor_nd) if need_flat else tensor_nd
idx = layers.reshape(layers.elementwise_add(beam_idx, batch_idx, 0),
[-1])
new_flat_tensor = layers.gather(flat_tensor, idx)
new_tensor_nd = layers.reshape(
new_flat_tensor,
shape=[batch_size, beam_idx.shape[1]] +
tensor_nd.shape[2:]) if need_flat else new_flat_tensor
return new_tensor_nd
def early_finish(alive_log_probs, finished_scores,
finished_in_finished):
max_length_penalty = np.power(((5. + max_len) / 6.), alpha)
# The best possible score of the most likely alive sequence
lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty
# Now to compute the lowest score of a finished sequence in finished.
# If the sequence isn't finished, we multiply its score by 0; since
# scores are all negative, taking the min will give us the score of the
# lowest finished item.
lowest_score_of_finished_in_finished = layers.reduce_min(
    finished_scores * finished_in_finished, 1)
# If none of the sequences have finished, then the min will be 0 and
# we have to replace it by -inf. The score of any seq in alive
# will be much higher than -inf and the termination condition will not
# be met.
lowest_score_of_finished_in_finished += (
    1. - layers.reduce_max(finished_in_finished, 1)) * -inf
bound_is_met = layers.reduce_all(
    layers.greater_than(lowest_score_of_finished_in_finished,
                        lower_bound_alive_scores))
return bound_is_met
def grow_topk(i, logits, alive_seq, alive_log_probs, states):
logits = layers.reshape(logits, [batch_size, beam_size, -1])
candidate_log_probs = layers.log(layers.softmax(logits, axis=2))
log_probs = layers.elementwise_add(candidate_log_probs,
alive_log_probs, 0)
length_penalty = np.power((5.0 + i + 1.0) / 6.0, alpha)
curr_scores = log_probs / length_penalty
flat_curr_scores = layers.reshape(curr_scores, [batch_size, -1])
topk_scores, topk_ids = layers.topk(flat_curr_scores,
k=beam_size * 2)
topk_log_probs = topk_scores * length_penalty
topk_beam_index = topk_ids // self._trg_vocab_size
topk_ids = topk_ids % self._trg_vocab_size
# use gather as gather_nd, TODO: use gather_nd
topk_seq = gather_2d_by_gather(alive_seq, topk_beam_index,
beam_size, batch_size)
topk_seq = layers.concat(
[topk_seq,
layers.reshape(topk_ids, topk_ids.shape + [1])],
axis=2)
states = update_states(states, topk_beam_index, beam_size)
eos = layers.fill_constant(shape=topk_ids.shape,
dtype="int64",
value=eos_id)
topk_finished = layers.cast(layers.equal(topk_ids, eos), "float32")
#topk_seq: [batch_size, 2*beam_size, i+1]
#topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
return topk_seq, topk_log_probs, topk_scores, topk_finished, states
def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
states):
curr_scores += curr_finished * -inf
_, topk_indexes = layers.topk(curr_scores, k=beam_size)
alive_seq = gather_2d_by_gather(curr_seq, topk_indexes,
beam_size * 2, batch_size)
alive_log_probs = gather_2d_by_gather(curr_log_probs, topk_indexes,
beam_size * 2, batch_size)
states = update_states(states, topk_indexes, beam_size * 2)
return alive_seq, alive_log_probs, states
def grow_finished(finished_seq, finished_scores, finished_flags,
curr_seq, curr_scores, curr_finished):
# finished scores
finished_seq = layers.concat([
finished_seq,
layers.fill_constant(shape=[batch_size, beam_size, 1],
dtype="int64",
value=eos_id)
],
axis=2)
# Set the scores of the unfinished seq in curr_seq to large negative
# values
curr_scores += (1. - curr_finished) * -inf
# concatenating the sequences and scores along beam axis
curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
curr_finished_scores = layers.concat([finished_scores, curr_scores],
axis=1)
curr_finished_flags = layers.concat([finished_flags, curr_finished],
axis=1)
_, topk_indexes = layers.topk(curr_finished_scores, k=beam_size)
finished_seq = gather_2d_by_gather(curr_finished_seq, topk_indexes,
beam_size * 3, batch_size)
finished_scores = gather_2d_by_gather(curr_finished_scores,
topk_indexes, beam_size * 3,
batch_size)
finished_flags = gather_2d_by_gather(curr_finished_flags,
topk_indexes, beam_size * 3,
batch_size)
return finished_seq, finished_scores, finished_flags
for i in range(max_len):
logits = self._wrap_decoder_layer(
(trg_word, trg_pos, None, trg_src_attn_bias), enc_output,
caches)
topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
i, logits, alive_seq, alive_log_probs, caches)
alive_seq, alive_log_probs, states = grow_alive(
topk_seq, topk_scores, topk_log_probs, topk_finished, states)
finished_seq, finished_scores, finished_flags = grow_finished(
finished_seq, finished_scores, finished_flags, topk_seq,
topk_scores, topk_finished)
trg_word = layers.reshape(alive_seq[:, :, -1],
[batch_size * beam_size, 1, 1])
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
if early_finish(alive_log_probs, finished_scores,
finished_flags).numpy():
break
return finished_seq, finished_scores
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import ast
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.dataset.wmt16 as wmt16
from model import TransFormer
from config import *
from data_util import *
def parse_args():
parser = argparse.ArgumentParser("Arguments for Inference")
parser.add_argument(
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to shuffle instances in each pass.")
parser.add_argument(
"--model_file",
type=str,
default="transformer_params",
help="Load model from the file named `model_file.pdparams`.")
parser.add_argument(
"--output_file",
type=str,
default="predict.txt",
help="The file to output the translation results of predict_file to.")
parser.add_argument('opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
merge_cfg_from_list(args.opts, [InferTaskConfig, ModelHyperParams])
return args
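# Example (illustrative): `python predict.py --model_file my_params max_out_len 50`
# loads `my_params.pdparams` and overrides InferTaskConfig.max_out_len.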
def prepare_infer_input(insts, src_pad_idx, bos_idx, n_head):
"""
inputs for inference
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
# start tokens
trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32")
trg_word = trg_word.reshape(-1, 1, 1)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
data_inputs = [
src_word, src_pos, src_slf_attn_bias, trg_word, trg_src_attn_bias
]
var_inputs = []
for i, field in enumerate(encoder_data_input_fields +
fast_decoder_data_input_fields):
var_inputs.append(to_variable(data_inputs[i], name=field))
enc_inputs = var_inputs[0:len(encoder_data_input_fields)]
dec_inputs = var_inputs[len(encoder_data_input_fields):]
return enc_inputs, dec_inputs
def post_process_seq(seq, bos_idx, eos_idx, output_bos=False, output_eos=False):
"""
Post-process the decoded sequence.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
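# Example (illustrative): with bos_idx=0 and eos_idx=1,
# post_process_seq([0, 5, 7, 1, 9], 0, 1) returns [5, 7]: tokens after the
# first EOS are dropped and BOS/EOS are stripped by default.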
def infer(args):
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
transformer = TransFormer(
'transformer', ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer, ModelHyperParams.n_head,
ModelHyperParams.d_key, ModelHyperParams.d_value,
ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing)
# load checkpoint
model_dict, _ = fluid.load_dygraph(args.model_file)
transformer.load_dict(model_dict)
print("checkpoint loaded")
# start evaluate mode
transformer.eval()
reader = paddle.batch(wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
id2word = wmt16.get_dict("de",
ModelHyperParams.trg_vocab_size,
reverse=True)
f = open(args.output_file, "wb")
for batch in reader():
enc_inputs, dec_inputs = prepare_infer_input(
batch, ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head)
finished_seq, finished_scores = transformer.beam_search(
    enc_inputs,
    dec_inputs,
    bos_id=ModelHyperParams.bos_idx,
    eos_id=ModelHyperParams.eos_idx,
    beam_size=InferTaskConfig.beam_size,
    max_len=InferTaskConfig.max_out_len,
    alpha=InferTaskConfig.alpha)
finished_seq = finished_seq.numpy()
finished_scores = finished_scores.numpy()
for ins in finished_seq:
for beam in ins:
id_list = post_process_seq(beam, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx)
word_list = [id2word[id] for id in id_list]
sequence = " ".join(word_list) + "\n"
f.write(sequence.encode("utf8"))
break # only print the best
if __name__ == '__main__':
args = parse_args()
infer(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,377 +15,48 @@
from __future__ import print_function
import argparse
import ast
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.dataset.wmt16 as wmt16
from model import TransFormer, NoamDecay
from config import *
from data_util import *
def parse_args():
    parser = argparse.ArgumentParser("Arguments for Training")
    parser.add_argument(
        "--use_data_parallel",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use multi-GPU.")
    parser.add_argument(
        "--model_file",
        type=str,
        default="transformer_params",
        help="Save the model as a file named `model_file.pdparams`.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    merge_cfg_from_list(args.opts, [TrainTaskConfig, ModelHyperParams])
    return args
args = parse_args() def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
# Copy from models
class TrainTaskConfig(object):
"""
TrainTaskConfig
"""
# support both CPU and GPU now.
use_gpu = True
# the epoch number to train.
pass_num = 30
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate = 2.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
class ModelHyperParams(object):
"""
ModelHyperParams
"""
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionay
trg_vocab_size = 10000
# # index for <bos> token
# bos_idx = 0
# # index for <eos> token
# eos_idx = 1
# # index for <unk> token
# unk_idx = 2
src_pad_idx = 0
# index for <pad> token in target language.
trg_pad_idx = 1
# max length of sequences deciding the size of position encoding table.
max_length = 50
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# random seed used in dropout for CE.
dropout_seed = 0
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = False
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for squence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states updation)
"init_idx": [(batch_size, ), "int32"],
}
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"init_idx",
"trg_src_attn_bias", )
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
""" """
channels = d_pos_vec inputs for training
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(np.arange(
num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
def create_data(np_values, is_static=False):
""" """
create_data src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
:param np_values:
:param is_static:
:return:
"""
# pdb.set_trace()
[
src_word_np, src_pos_np, trg_word_np, trg_pos_np, src_slf_attn_bias_np,
trg_slf_attn_bias_np, trg_src_attn_bias_np, lbl_word_np, lbl_weight_np
] = np_values
if is_static:
return [
src_word_np, src_pos_np, src_slf_attn_bias_np, trg_word_np,
trg_pos_np, trg_slf_attn_bias_np, trg_src_attn_bias_np, lbl_word_np,
lbl_weight_np
]
else:
enc_inputs = [
to_variable(
src_word_np, name='src_word'), to_variable(
src_pos_np, name='src_pos'), to_variable(
src_slf_attn_bias_np, name='src_slf_attn_bias')
]
dec_inputs = [
to_variable(
trg_word_np, name='trg_word'), to_variable(
trg_pos_np, name='trg_pos'), to_variable(
trg_slf_attn_bias_np, name='trg_slf_attn_bias'),
to_variable(
trg_src_attn_bias_np, name='trg_src_attn_bias')
]
label = to_variable(lbl_word_np, name='lbl_word')
weight = to_variable(lbl_weight_np, name='lbl_weight')
return enc_inputs, dec_inputs, label, weight
def create_feed_dict_list(data, init=False):
"""
    Create the feed dict mapping input names to data for the executor.
    :param data: flat list of numpy arrays for one batch
    :param init: whether to also feed the position encoding tables
    :return: dict mapping input names to data
"""
if init:
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields + pos_enc_param_names
else:
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
feed_dict_list = dict()
for i in range(len(data_input_names)):
feed_dict_list[data_input_names[i]] = data[i]
return feed_dict_list
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
inputs = []
for input_field in input_fields:
input_var = fluid.layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
def prepare_batch_input(insts, src_pad_idx, trg_pad_idx, n_head):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias. Then, convert the numpy
data to tensors and return a dict mapping names to tensors.
"""
def __pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
        # Any token included in the dict can be used for padding, since the
        # padding loss is masked out by weights and has no effect on gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array([[1.] * len(inst) + [0.] *
(max_len - len(inst)) for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones(
(inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(
slf_attn_bias_data, 1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
    src_word, src_pos, src_slf_attn_bias, src_max_len = __pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = __pad_batch_data(
        [inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
    trg_word = trg_word.reshape(-1, trg_max_len, 1)
    trg_pos = trg_pos.reshape(-1, trg_max_len, 1)

    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, trg_max_len, 1]).astype("float32")

    lbl_word, lbl_weight, num_token = __pad_batch_data(
        [inst[2] for inst in insts],
        trg_pad_idx,
        n_head,
        is_target=False,
        is_label=True,
        return_attn_bias=False,
        return_max_len=False,
        return_num_token=True)

    data_inputs = [
        src_word, src_pos, src_slf_attn_bias, trg_word, trg_pos,
        trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight
    ]
    var_inputs = []
    for i, field in enumerate(encoder_data_input_fields +
                              decoder_data_input_fields[:-1] +
                              label_data_input_fields):
        var_inputs.append(to_variable(data_inputs[i], name=field))

    enc_inputs = var_inputs[0:len(encoder_data_input_fields)]
    dec_inputs = var_inputs[len(encoder_data_input_fields
                                ):len(encoder_data_input_fields) +
                            len(decoder_data_input_fields[:-1])]
    label = var_inputs[-2]
    weights = var_inputs[-1]

    return enc_inputs, dec_inputs, label, weights


pos_inp1 = position_encoding_init(ModelHyperParams.max_length + 1,
                                  ModelHyperParams.d_model)
pos_inp2 = position_encoding_init(ModelHyperParams.max_length + 1,
                                  ModelHyperParams.d_model)
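
# Note on trg_src_attn_bias in prepare_batch_input above (comment added for
# clarity): the source self-attention bias has shape
# [batch, n_head, src_max_len, src_max_len] with identical rows per instance,
# so slicing every src_max_len-th row ([:, :, ::src_max_len, :]) keeps one
# row per instance, and tiling it trg_max_len times yields the
# [batch, n_head, trg_max_len, src_max_len] bias used by the decoder's
# encoder-decoder attention.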
class PrePostProcessLayer(Layer): return enc_inputs, dec_inputs, label, weights
"""
PrePostProcessLayer
"""
def __init__(self, name_scope, process_cmd, shape_len=None):
super(PrePostProcessLayer, self).__init__(name_scope)
for cmd in process_cmd:
if cmd == "n":
self._layer_norm = LayerNorm(
name_scope=self.full_name(),
begin_norm_axis=shape_len - 1,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.)))
    def forward(self, prev_out, out, process_cmd, dropout_rate=0.):
"""
forward
:param prev_out:
:param out:
:param process_cmd:
:param dropout_rate:
:return:
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
                out = out + prev_out if prev_out is not None else out
elif cmd == "n": # add layer normalization
out = self._layer_norm(out)
elif cmd == "d": # add dropout
if dropout_rate:
out = fluid.layers.dropout(
out,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False)
return out
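
# For reference (comment added for clarity): process_cmd is interpreted one
# character at a time, so the default preprocess_cmd "n" applies layer
# normalization before a sub-layer, while postprocess_cmd "da" applies
# dropout to the sub-layer output and then adds the residual connection.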
class PositionwiseFeedForwardLayer(Layer):
"""
PositionwiseFeedForwardLayer
"""
def __init__(self, name_scope, d_inner_hid, d_hid, dropout_rate):
super(PositionwiseFeedForwardLayer, self).__init__(name_scope)
self._i2h = FC(name_scope=self.full_name(),
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
self._h2o = FC(name_scope=self.full_name(),
size=d_hid,
num_flatten_dims=2)
self._dropout_rate = dropout_rate
def forward(self, x):
"""
forward
:param x:
:return:
"""
hidden = self._i2h(x)
if self._dropout_rate:
hidden = fluid.layers.dropout(
hidden,
dropout_prob=self._dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False)
out = self._h2o(hidden)
return out
class MultiHeadAttentionLayer(Layer):
"""
MultiHeadAttentionLayer
"""
def __init__(self,
name_scope,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
gather_idx=None,
static_kv=False):
super(MultiHeadAttentionLayer, self).__init__(name_scope)
self._n_head = n_head
self._d_key = d_key
self._d_value = d_value
self._d_model = d_model
self._dropout_rate = dropout_rate
self._q_fc = FC(name_scope=self.full_name(),
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
self._k_fc = FC(name_scope=self.full_name(),
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
self._v_fc = FC(name_scope=self.full_name(),
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
self._proj_fc = FC(name_scope=self.full_name(),
size=self._d_model,
bias_attr=False,
num_flatten_dims=2)
def forward(self, queries, keys, values, attn_bias):
"""
forward
:param queries:
:param keys:
:param values:
:param attn_bias:
:return:
"""
        # compute q, k, v
keys = queries if keys is None else keys
values = keys if values is None else values
q = self._q_fc(queries)
k = self._k_fc(keys)
v = self._v_fc(values)
# split head
reshaped_q = fluid.layers.reshape(
x=q, shape=[0, 0, self._n_head, self._d_key], inplace=False)
transpose_q = fluid.layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
reshaped_k = fluid.layers.reshape(
x=k, shape=[0, 0, self._n_head, self._d_key], inplace=False)
transpose_k = fluid.layers.transpose(x=reshaped_k, perm=[0, 2, 1, 3])
reshaped_v = fluid.layers.reshape(
x=v, shape=[0, 0, self._n_head, self._d_value], inplace=False)
transpose_v = fluid.layers.transpose(x=reshaped_v, perm=[0, 2, 1, 3])
# scale dot product attention
product = fluid.layers.matmul(
x=transpose_q,
y=transpose_k,
transpose_y=True,
alpha=self._d_model**-0.5)
if attn_bias:
product += attn_bias
weights = fluid.layers.softmax(product)
if self._dropout_rate:
weights_droped = fluid.layers.dropout(
weights,
dropout_prob=self._dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False)
out = fluid.layers.matmul(weights_droped, transpose_v)
else:
out = fluid.layers.matmul(weights, transpose_v)
# combine heads
if len(out.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = fluid.layers.transpose(out, perm=[0, 2, 1, 3])
final_out = fluid.layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=False)
# fc to output
proj_out = self._proj_fc(final_out)
return proj_out
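
# Shape walk-through (explanatory sketch): given queries of shape
# [batch, seq_len, d_model], the q/k/v projections produce
# [batch, seq_len, n_head * d_key]; reshape and transpose yield
# [batch, n_head, seq_len, d_key]; the scaled dot-product gives attention
# weights of shape [batch, n_head, seq_len, seq_len]; and the final
# transpose/reshape merges heads back to [batch, seq_len, n_head * d_value]
# before the output projection back to d_model.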
class EncoderSubLayer(Layer):
"""
EncoderSubLayer
"""
def __init__(self,
name_scope,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderSubLayer, self).__init__(name_scope)
self._preprocess_cmd = preprocess_cmd
self._postprocess_cmd = postprocess_cmd
self._prepostprocess_dropout = prepostprocess_dropout
self._preprocess_layer = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
self._multihead_attention_layer = MultiHeadAttentionLayer(
self.full_name(), d_key, d_value, d_model, n_head,
attention_dropout)
self._postprocess_layer = PrePostProcessLayer(
self.full_name(), self._postprocess_cmd, None)
self._preprocess_layer2 = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
self._positionwise_feed_forward = PositionwiseFeedForwardLayer(
self.full_name(), d_inner_hid, d_model, relu_dropout)
self._postprocess_layer2 = PrePostProcessLayer(
self.full_name(), self._postprocess_cmd, None)
def forward(self, enc_input, attn_bias):
"""
forward
:param enc_input:
:param attn_bias:
:return:
"""
pre_process_multihead = self._preprocess_layer(
None, enc_input, self._preprocess_cmd, self._prepostprocess_dropout)
attn_output = self._multihead_attention_layer(pre_process_multihead,
None, None, attn_bias)
attn_output = self._postprocess_layer(enc_input, attn_output,
self._postprocess_cmd,
self._prepostprocess_dropout)
pre_process2_output = self._preprocess_layer2(
None, attn_output, self._preprocess_cmd,
self._prepostprocess_dropout)
ffd_output = self._positionwise_feed_forward(pre_process2_output)
return self._postprocess_layer2(attn_output, ffd_output,
self._postprocess_cmd,
self._prepostprocess_dropout)
class EncoderLayer(Layer):
"""
encoder
"""
def __init__(self,
name_scope,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
super(EncoderLayer, self).__init__(name_scope)
self._preprocess_cmd = preprocess_cmd
self._encoder_sublayers = list()
self._prepostprocess_dropout = prepostprocess_dropout
self._n_layer = n_layer
self._preprocess_layer = PrePostProcessLayer(self.full_name(),
self._preprocess_cmd, 3)
for i in range(n_layer):
self._encoder_sublayers.append(
self.add_sublayer(
'esl_%d' % i,
EncoderSubLayer(
self.full_name(), n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd)))
def forward(self, enc_input, attn_bias):
"""
forward
:param enc_input:
:param attn_bias:
:return:
"""
for i in range(self._n_layer):
enc_output = self._encoder_sublayers[i](enc_input, attn_bias)
enc_input = enc_output
return self._preprocess_layer(None, enc_output, self._preprocess_cmd,
self._prepostprocess_dropout)
class PrepareEncoderDecoderLayer(Layer):
"""
PrepareEncoderDecoderLayer
"""
def __init__(self,
name_scope,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate,
word_emb_param_name=None,
pos_enc_param_name=None):
super(PrepareEncoderDecoderLayer, self).__init__(name_scope)
self._src_max_len = src_max_len
self._src_emb_dim = src_emb_dim
self._src_vocab_size = src_vocab_size
self._dropout_rate = dropout_rate
self._input_emb = Embedding(
name_scope=self.full_name(),
size=[src_vocab_size, src_emb_dim],
padding_idx=0,
param_attr=fluid.ParamAttr(
name=word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
        if pos_enc_param_name == pos_enc_param_names[0]:
pos_inp = pos_inp1
else:
pos_inp = pos_inp2
self._pos_emb = Embedding(
name_scope=self.full_name(),
size=[self._src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
name=pos_enc_param_name,
initializer=fluid.initializer.NumpyArrayInitializer(pos_inp),
trainable=False))
# use in dygraph_mode to fit different length batch
# self._pos_emb._w = to_variable(
# position_encoding_init(self._src_max_len, self._src_emb_dim))
def forward(self, src_word, src_pos):
"""
forward
:param src_word:
:param src_pos:
:return:
"""
# print("here")
# print(self._input_emb._w._numpy().shape)
src_word_emb = self._input_emb(src_word)
src_word_emb = fluid.layers.scale(
x=src_word_emb, scale=self._src_emb_dim**0.5)
        # TODO: change this to fit dynamic-length input
src_pos_emb = self._pos_emb(src_pos)
src_pos_emb.stop_gradient = True
enc_input = src_word_emb + src_pos_emb
return fluid.layers.dropout(
enc_input,
dropout_prob=self._dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False) if self._dropout_rate else enc_input
class WrapEncoderLayer(Layer):
"""
encoderlayer
"""
    def __init__(self, name_scope, src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing):
"""
        The wrapper assembles all layers needed by the encoder.
"""
        super(WrapEncoderLayer, self).__init__(name_scope)
self._prepare_encoder_layer = PrepareEncoderDecoderLayer(
self.full_name(),
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[0],
pos_enc_param_name=pos_enc_param_names[0])
self._encoder = EncoderLayer(
self.full_name(), n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd)
def forward(self, enc_inputs):
"""forward"""
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = self._prepare_encoder_layer(src_word, src_pos)
enc_output = self._encoder(enc_input, src_slf_attn_bias)
return enc_output
class DecoderSubLayer(Layer):
"""
decoder
"""
def __init__(self,
name_scope,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None,
gather_idx=None):
super(DecoderSubLayer, self).__init__(name_scope)
self._postprocess_cmd = postprocess_cmd
self._preprocess_cmd = preprocess_cmd
        self._prepostprocess_dropout = prepostprocess_dropout
self._pre_process_layer = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._multihead_attention_layer = MultiHeadAttentionLayer(
self.full_name(),
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache=cache,
gather_idx=gather_idx)
self._post_process_layer = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
self._pre_process_layer2 = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._multihead_attention_layer2 = MultiHeadAttentionLayer(
self.full_name(),
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache=cache,
gather_idx=gather_idx,
static_kv=True)
self._post_process_layer2 = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
self._pre_process_layer3 = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._positionwise_feed_forward_layer = PositionwiseFeedForwardLayer(
self.full_name(), d_inner_hid, d_model, relu_dropout)
self._post_process_layer3 = PrePostProcessLayer(self.full_name(),
postprocess_cmd, None)
def forward(self, dec_input, enc_output, slf_attn_bias, dec_enc_attn_bias):
"""
forward
:param dec_input:
:param enc_output:
:param slf_attn_bias:
:param dec_enc_attn_bias:
:return:
"""
pre_process_rlt = self._pre_process_layer(
            None, dec_input, self._preprocess_cmd, self._prepostprocess_dropout)
slf_attn_output = self._multihead_attention_layer(pre_process_rlt, None,
None, slf_attn_bias)
slf_attn_output_pp = self._post_process_layer(
dec_input, slf_attn_output, self._postprocess_cmd,
            self._prepostprocess_dropout)
pre_process_rlt2 = self._pre_process_layer2(None, slf_attn_output_pp,
self._preprocess_cmd,
                                                    self._prepostprocess_dropout)
enc_attn_output_pp = self._multihead_attention_layer2(
pre_process_rlt2, enc_output, enc_output, dec_enc_attn_bias)
enc_attn_output = self._post_process_layer2(
slf_attn_output_pp, enc_attn_output_pp, self._postprocess_cmd,
            self._prepostprocess_dropout)
pre_process_rlt3 = self._pre_process_layer3(None, enc_attn_output,
self._preprocess_cmd,
                                                    self._prepostprocess_dropout)
ffd_output = self._positionwise_feed_forward_layer(pre_process_rlt3)
dec_output = self._post_process_layer3(enc_attn_output, ffd_output,
self._postprocess_cmd,
                                               self._prepostprocess_dropout)
return dec_output
class DecoderLayer(Layer):
"""
decoder
"""
def __init__(self,
name_scope,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None,
gather_idx=None):
super(DecoderLayer, self).__init__(name_scope)
self._pre_process_layer = PrePostProcessLayer(self.full_name(),
preprocess_cmd, 3)
self._decoder_sub_layers = list()
self._n_layer = n_layer
self._preprocess_cmd = preprocess_cmd
self._prepostprocess_dropout = prepostprocess_dropout
for i in range(n_layer):
self._decoder_sub_layers.append(
self.add_sublayer(
'dsl_%d' % i,
DecoderSubLayer(
self.full_name(),
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i],
gather_idx=gather_idx)))
def forward(self, dec_input, enc_output, dec_slf_attn_bias,
dec_enc_attn_bias):
"""
forward
:param dec_input:
:param enc_output:
:param dec_slf_attn_bias:
:param dec_enc_attn_bias:
:return:
"""
for i in range(self._n_layer):
tmp_dec_output = self._decoder_sub_layers[i](
dec_input, enc_output, dec_slf_attn_bias, dec_enc_attn_bias)
dec_input = tmp_dec_output
dec_output = self._pre_process_layer(None, tmp_dec_output,
self._preprocess_cmd,
self._prepostprocess_dropout)
return dec_output
class WrapDecoderLayer(Layer):
"""
decoder
"""
def __init__(self,
name_scope,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
caches=None,
gather_idx=None):
"""
        The wrapper assembles all layers needed by the decoder.
"""
super(WrapDecoderLayer, self).__init__(name_scope)
self._prepare_decoder_layer = PrepareEncoderDecoderLayer(
self.full_name(),
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[1],
pos_enc_param_name=pos_enc_param_names[1])
self._decoder_layer = DecoderLayer(
self.full_name(),
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches,
gather_idx=gather_idx)
self._weight_sharing = weight_sharing
if not weight_sharing:
self._fc = FC(self.full_name(),
size=trg_vocab_size,
bias_attr=False)
def forward(self, dec_inputs=None, enc_output=None):
"""
forward
:param dec_inputs:
:param enc_output:
:return:
"""
trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = self._prepare_decoder_layer(trg_word, trg_pos)
dec_output = self._decoder_layer(dec_input, enc_output,
trg_slf_attn_bias, trg_src_attn_bias)
dec_output_reshape = fluid.layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=False)
if self._weight_sharing:
predict = fluid.layers.matmul(
x=dec_output_reshape,
y=self._prepare_decoder_layer._input_emb._w,
transpose_y=True)
else:
predict = self._fc(dec_output_reshape)
if dec_inputs is None:
# Return probs for independent decoder program.
predict_out = fluid.layers.softmax(predict)
return predict_out
return predict
class TransFormer(Layer):
"""
model
"""
def __init__(self, name_scope, src_vocab_size, trg_vocab_size, max_length,
n_layer, n_head, d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd, weight_sharing,
label_smooth_eps):
super(TransFormer, self).__init__(name_scope)
self._label_smooth_eps = label_smooth_eps
self._trg_vocab_size = trg_vocab_size
if weight_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
self._wrap_encoder_layer = WrapEncoderLayer(
self.full_name(), src_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing)
self._wrap_decoder_layer = WrapDecoderLayer(
self.full_name(), trg_vocab_size, max_length, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing)
if weight_sharing:
self._wrap_decoder_layer._prepare_decoder_layer._input_emb._w = self._wrap_encoder_layer._prepare_encoder_layer._input_emb._w
def forward(self, enc_inputs, dec_inputs, label, weights):
"""
forward
:param enc_inputs:
:param dec_inputs:
:param label:
:param weights:
:return:
"""
enc_output = self._wrap_encoder_layer(enc_inputs)
predict = self._wrap_decoder_layer(dec_inputs, enc_output)
if self._label_smooth_eps:
label_out = fluid.layers.label_smooth(
label=fluid.layers.one_hot(
input=label, depth=self._trg_vocab_size),
epsilon=self._label_smooth_eps)
cost = fluid.layers.softmax_with_cross_entropy(
logits=predict,
label=label_out,
soft_label=True if self._label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = fluid.layers.reduce_sum(weighted_cost)
token_num = fluid.layers.reduce_sum(weights)
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
return sum_cost, avg_cost, predict, token_num
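
# Note (comment added for clarity): sum_cost is the (optionally
# label-smoothed) cross entropy summed over all target tokens, with padding
# positions zeroed out by weights; avg_cost = sum_cost / token_num is the
# per-token loss that training below backpropagates and logs.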
def train(args):
    """
    train models
    :return:
    """
@@ -1117,6 +108,7 @@ def train():
    if args.use_data_parallel:
        strategy = fluid.dygraph.parallel.prepare_context()

    # define model
    transformer = TransFormer(
        'transformer', ModelHyperParams.src_vocab_size,
        ModelHyperParams.trg_vocab_size, ModelHyperParams.max_length + 1,
        ModelHyperParams.n_layer, ModelHyperParams.n_head,
        ModelHyperParams.d_key, ModelHyperParams.d_value,
        ModelHyperParams.d_model, ModelHyperParams.d_inner_hid,
        ModelHyperParams.prepostprocess_dropout,
        ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
        ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
        ModelHyperParams.weight_sharing, TrainTaskConfig.label_smooth_eps)
    # define optimizer
    optimizer = fluid.optimizer.Adam(learning_rate=NoamDecay(
        ModelHyperParams.d_model, TrainTaskConfig.warmup_steps,
        TrainTaskConfig.learning_rate),
                                     beta1=TrainTaskConfig.beta1,
                                     beta2=TrainTaskConfig.beta2,
                                     epsilon=TrainTaskConfig.eps)
    if args.use_data_parallel:
        transformer = fluid.dygraph.parallel.DataParallel(
            transformer, strategy)
    # define data generators for training and validation
    train_reader = paddle.batch(wmt16.train(
        ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size),
                                batch_size=TrainTaskConfig.batch_size)
    if args.use_data_parallel:
        train_reader = fluid.contrib.reader.distributed_batch_reader(
            train_reader)
    val_reader = paddle.batch(wmt16.test(ModelHyperParams.src_vocab_size,
                                         ModelHyperParams.trg_vocab_size),
                              batch_size=TrainTaskConfig.batch_size)
    # loop over training epochs
    for i in range(TrainTaskConfig.pass_num):
        dy_step = 0
        sum_cost = 0
        transformer.train()
        for batch in train_reader():
            enc_inputs, dec_inputs, label, weights = prepare_train_input(
                batch, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                ModelHyperParams.n_head)

            dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                enc_inputs, dec_inputs, label, weights)

            if args.use_data_parallel:
                # scale the loss and collect gradients across trainers
                # under data parallelism
                dy_avg_cost = transformer.scale_loss(dy_avg_cost)
                dy_avg_cost.backward()
                transformer.apply_collective_grads()
            else:
                dy_avg_cost.backward()

            optimizer.minimize(dy_avg_cost)
            transformer.clear_gradients()
            dy_step = dy_step + 1
            if dy_step % 10 == 0:
                print("pass num : {}, batch_id: {}, dy_graph avg loss: {}".
                      format(i, dy_step,
                             dy_avg_cost.numpy() * trainer_count))

        # switch to evaluation mode for validation
        transformer.eval()
        sum_cost = 0
        token_num = 0
        for batch in val_reader():
            enc_inputs, dec_inputs, label, weights = prepare_train_input(
                batch, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                ModelHyperParams.n_head)
            dy_sum_cost, dy_avg_cost, dy_predict, dy_token_num = transformer(
                enc_inputs, dec_inputs, label, weights)
            sum_cost += dy_sum_cost.numpy()
            token_num += dy_token_num.numpy()
        print("pass : {} finished, validation avg loss: {}".format(
            i, sum_cost / token_num))

        if fluid.dygraph.parallel.Env().dev_id == 0:
            fluid.save_dygraph(transformer.state_dict(), args.model_file)


if __name__ == '__main__':
    args = parse_args()
    train(args)