Commit 00f3b76e authored by Guo Sheng, committed by Yibing Liu

Speed up Transformer inference (#1476)

* Add py-reader and parallel-executor support in Transformer inference

* Add static k, v cache for encoder output in Transformer inference

* Replace the cache from compute_qkv with the cache from split_heads in Transformer inference

* Fuse k, q, v projection in Transformer

* Revert the fused k, q, v projection in Transformer to be compatible with saved models

* Use gather_op to replace sequence_expand_op in Transformer inference

* Add fluid_transformer.md

* Refine README for released models and data in Transformer

* Refine README for released models and data in Transformer
Parent c9ccfe4d
......@@ -69,9 +69,9 @@ The WMT dataset is the recognized mainstream dataset in machine translation; the [WMT'16 EN-DE data
└── subword-nmt # code for BPE encoding
```
`gen_data/wmt16_ende_data_bpe` contains the EN-DE translation data we finally use, where `train.tok.clean.bpe.32000.en-de` is the training data, `newstest2016.tok.bpe.32000.en-de` and similar files are the validation and test data, and `vocab_all.bpe.32000` is the corresponding vocabulary file (the three special tokens `<s>`, `<e>` and `<unk>` have already been added, and the source and target languages share this vocabulary file).
`gen_data/wmt16_ende_data_bpe` contains the EN-DE translation data we finally use, where `train.tok.clean.bpe.32000.en-de` is the training data, `newstest2016.tok.bpe.32000.en-de` and similar files are the validation and test data, and `vocab_all.bpe.32000` is the corresponding vocabulary file (the three special tokens `<s>`, `<e>` and `<unk>` have already been added, and the source and target languages share this vocabulary file). In addition, we also provide a preprocessed copy of the WMT'16 EN-DE data for [download](https://transformer-res.bj.bcebos.com/wmt16_ende_data_bpe_clean.tar.gz) (it contains the BPE data and vocabulary needed for training, as well as the BPE data and tokenized data needed for inference and evaluation).
For other custom data, simply convert it to a format like `train.tok.clean.bpe.32000.en-de` (tab-separated source and target sentence pairs, with tokens inside a sentence separated by spaces); if BPE encoding is needed, you may refer to it, or use something similar to WMT and process it with `gen_data.sh`.
For other custom data, simply convert it to a format like `train.tok.clean.bpe.32000.en-de` (tab-separated source and target sentence pairs, with tokens inside a sentence separated by spaces); if BPE encoding is needed, you can also use a format similar to the raw WMT'16 EN-DE data and process it by following `gen_data.sh`.
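As a concrete illustration (a minimal Python sketch for this document only, not code from this repo; the helper name is made up), one line of a file in this format can be split into source and target token lists as follows:
```python
# Minimal sketch (illustration only): parse one line of a parallel-corpus file
# in the train.tok.clean.bpe.32000.en-de format, i.e. a tab-separated
# source/target sentence pair whose tokens are separated by spaces.
def parse_parallel_line(line, field_delimiter="\t", token_delimiter=" "):
    src, trg = line.rstrip("\n").split(field_delimiter)
    return src.split(token_delimiter), trg.split(token_delimiter)

# parse_parallel_line("a b c\td e f") -> (["a", "b", "c"], ["d", "e", "f"])
```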
### Model Training
......@@ -110,11 +110,9 @@ python -u train.py \
--batch_size 3200 \
--sort_type pool \
--pool_size 200000 \
n_layer 6 \
n_head 16 \
d_model 1024 \
d_inner_hid 4096 \
n_head 16 \
prepostprocess_dropout 0.3
```
For more detailed information about these parameters, please refer to the comments in `config.py`.
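As a rough sketch of how the trailing positional `opts` key-value pairs (e.g. `n_head 16` in the command above) get applied, the overrides might be merged into the config classes roughly as follows (an assumption for illustration; see `config.py` for the actual implementation):
```python
# Hypothetical sketch: merge positional `opts` key-value pairs, e.g.
# ["n_head", "16", "d_model", "1024"], into config classes such as
# ModelHyperParams. See config.py in the repo for the real logic.
def merge_cfg_from_list(cfg_list, g_cfgs):
    assert len(cfg_list) % 2 == 0, "opts must come in key-value pairs"
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    value = eval(value)  # cast "16" -> 16, "0.3" -> 0.3
                except Exception:
                    pass  # keep non-literal values (e.g. paths) as strings
                setattr(g_cfg, key, value)
                break
```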
......@@ -144,30 +142,53 @@ python -u infer.py \
--token_delimiter ' ' \
--batch_size 32 \
model_path trained_models/iter_100000.infer.model \
beam_size 4 \
beam_size 5 \
max_out_len 255
```
Similar to model training, inference also requires setting the data and reader related arguments, and you can run `python infer.py --help` to see their descriptions (the meaning of some arguments differs slightly from training); model hyper-parameters can likewise be set in the inference command, but they should be consistent with the settings used for training; in addition, compared with training, inference has some extra arguments, e.g. `model_path` needs to be set to give the directory of the model, and `beam_size` and `max_out_len` can be set to specify the beam width and maximum depth (translation length) of the beam search algorithm; these parameters can also be looked up in the comments of `InferTaskConfig` in `config.py` and changed there.
Similar to model training, inference also requires setting the data and reader related arguments, and you can run `python infer.py --help` to see their descriptions (the meaning of some arguments differs slightly from training); model hyper-parameters can likewise be set in the inference command, but they should be consistent with the settings used for training. For example, if the big model settings were used for training, the corresponding inference command is similar to the following:
```sh
python -u infer.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
--token_delimiter ' ' \
--batch_size 32 \
model_path trained_models/iter_100000.infer.model \
n_head 16 \
d_model 1024 \
d_inner_hid 4096 \
prepostprocess_dropout 0.3 \
beam_size 5 \
max_out_len 255
```
In addition, compared with training, inference has some extra arguments, e.g. `model_path` needs to be set to give the directory of the model, and `beam_size` and `max_out_len` can be set to specify the beam width and maximum depth (translation length) of the beam search algorithm; these parameters can also be looked up in the comments of `InferTaskConfig` in `config.py` and changed there.
Running the above inference command prints the translations to standard output; each output line is the highest-scoring translation for the corresponding input line. For the EN-DE data encoded with BPE, the predicted translations are also in BPE representation and must be restored to the original data (here, the tokenized data) before they can be evaluated correctly; the following command restores the translations in `predict.txt` to `predict.tok.txt` (no further tokenization is required):
```sh
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
```
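The same restoration can equivalently be done in Python (a minimal sketch using the standard `re` module, assuming the same file names as above):
```python
# Sketch: strip BPE continuation markers ("@@ " and a trailing "@@"),
# mirroring the sed command above.
import re

with open("predict.txt") as fin, open("predict.tok.txt", "w") as fout:
    for line in fin:
        fout.write(re.sub(r"(@@ )|(@@ ?$)", "", line))
```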
Next, the translation results can be evaluated against the reference translations using the BLEU metric. Taking the EN-DE `newstest2016.tok.de` data as an example, run the following command:
Next, the translation results can be evaluated against the reference translations using the BLEU metric; the evaluation uses a script from mosesdecoder, which can be obtained with the following command:
```sh
git clone https://github.com/moses-smt/mosesdecoder.git
```
Taking the EN-DE `newstest2014.tok.de` data as an example, after obtaining mosesdecoder, run the following command with `multi-bleu.perl` to evaluate the translation results:
```sh
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2016.tok.de < predict.tok.txt
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2014.tok.de < predict.tok.txt
```
You should see results similar to the following (predictions of a model trained for 200K iterations on a single machine with two GPUs).
You should see results similar to the following:
```
BLEU = 33.08, 64.2/39.2/26.4/18.5 (BP=0.994, ratio=0.994, hyp_len=61971, ref_len=62362)
BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
```
Currently, without model averaging, the test BLEU scores of the EN-DE base model after training for 100K iterations on 8 GPUs are as follows:
Currently, without model averaging, the test BLEU scores of the EN-DE base model and big model after training for 100K iterations on 8 GPUs are as follows:
| Test set | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
| BLEU | 26.25 | 29.15 | 33.64 |
| Base | 26.35 | 29.07 | 33.30 |
| Big | 27.07 | 30.09 | 34.38 |
We also provide the above [base model](https://transformer-res.bj.bcebos.com/base_model.tar.gz) and [big model](https://transformer-res.bj.bcebos.com/big_model.tar.gz) for download.
### Distributed Training
......
......@@ -164,7 +164,10 @@ input_descs = {
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32"],
"init_score": [(batch_size, 1), "float32", 2],
# This input is used in beam-search decoder for the first gather
# (cell states update)
"init_idx": [(batch_size, ), "int32"],
}
# Names of word embedding table which might be reused for weight sharing.
......@@ -194,4 +197,5 @@ label_data_input_fields = (
fast_decoder_data_input_fields = (
"trg_word",
"init_score",
"init_idx",
"trg_src_attn_bias", )
import argparse
import ast
import multiprocessing
import numpy as np
import os
from functools import partial
import paddle
import paddle.fluid as fluid
import model
import reader
from config import *
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
from train import pad_batch_data, prepare_data_generator
def parse_args():
......@@ -54,6 +56,21 @@ def parse_args():
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
"--use_mem_opt",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use memory optimization.")
parser.add_argument(
"--use_py_reader",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use py_reader.")
parser.add_argument(
"--use_parallel_exe",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use ParallelExecutor.")
parser.add_argument(
'opts',
help='See config.py for all options',
......@@ -123,106 +140,185 @@ def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
trg_word, dtype="float32").reshape(-1, 1),
place, [range(trg_word.shape[0] + 1)] * 2)
trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
init_idx = np.asarray(range(len(insts)), dtype="int32")
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, init_score,
trg_src_attn_bias
init_idx, trg_src_attn_bias
]))
return data_input_dict
def prepare_feed_dict_list(data_generator, count, place):
"""
Prepare the list of feed dicts for multiple devices.
"""
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.bos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model, place)
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def py_reader_provider_wrapper(data_reader, place):
"""
Data provider needed by fluid.layers.py_reader.
"""
input_dict = dict(data_input_dict.items())
return input_dict
def py_reader_provider():
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
for batch_id, data in enumerate(data_reader()):
data_input_dict = prepare_batch_input(
data, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.bos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model, place)
yield [data_input_dict[item] for item in data_input_names]
return py_reader_provider
def fast_infer(test_data, trg_idx2word):
def fast_infer(args):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
out_ids, out_scores, pyreader = fast_decoder(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
InferTaskConfig.beam_size,
InferTaskConfig.max_out_len,
ModelHyperParams.eos_idx,
use_py_reader=args.use_py_reader)
# This is used here to set dropout to the test mode.
infer_program = fluid.default_main_program().clone(for_test=True)
out_ids, out_scores = fast_decoder(
ModelHyperParams.src_vocab_size, ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
ModelHyperParams.n_head, ModelHyperParams.d_key,
ModelHyperParams.d_value, ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid, ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout, ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd, ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing, InferTaskConfig.beam_size,
InferTaskConfig.max_out_len, ModelHyperParams.eos_idx)
if args.use_mem_opt:
fluid.memory_optimize(infer_program)
if InferTaskConfig.use_gpu:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_vars(
exe,
InferTaskConfig.model_path,
vars=[
var for var in fluid.default_main_program().list_vars()
var for var in infer_program.list_vars()
if isinstance(var, fluid.framework.Parameter)
])
# This is used here to set dropout to the test mode.
infer_program = fluid.default_main_program().clone(for_test=True)
exec_strategy = fluid.ExecutionStrategy()
# For faster executor
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 1
build_strategy = fluid.BuildStrategy()
infer_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
main_program=infer_program,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
for batch_id, data in enumerate(test_data.batch_generator()):
data_input = prepare_batch_input(
data, encoder_data_input_fields + fast_decoder_data_input_fields,
ModelHyperParams.eos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model, place)
seq_ids, seq_scores = exe.run(infer_program,
feed=data_input,
fetch_list=[out_ids, out_scores],
return_numpy=False)
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps = [[] for i in range(len(data))]
scores = [[] for i in range(len(data))]
for i in range(len(seq_ids.lod()[0]) - 1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1])
if len(hyps[i]) >= InferTaskConfig.n_best:
break
def infer(args, inferencer=fast_infer):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
token_delimiter=args.token_delimiter,
use_token_batch=False,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
# count start and end tokens out
max_length=ModelHyperParams.max_length - 2,
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
# data reader settings for inference
args.train_file_pattern = args.test_file_pattern
args.use_token_batch = False
args.sort_type = reader.SortType.NONE
args.shuffle = False
args.shuffle_batch = False
test_data = prepare_data_generator(
args,
is_test=False,
count=dev_count,
pyreader=pyreader,
py_reader_provider_wrapper=py_reader_provider_wrapper,
place=place)
if args.use_py_reader:
pyreader.start()
data_generator = None
else:
data_generator = test_data()
trg_idx2word = reader.DataReader.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
inferencer(test_data, trg_idx2word)
while True:
try:
feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
place)
if args.use_parallel_exe:
seq_ids, seq_scores = infer_exe.run(
fetch_list=[out_ids.name, out_scores.name],
feed=feed_dict_list,
return_numpy=False)
else:
seq_ids, seq_scores = exe.run(
program=infer_program,
fetch_list=[out_ids.name, out_scores.name],
feed=feed_dict_list[0]
if feed_dict_list is not None else None,
return_numpy=False,
use_program_cache=True)
seq_ids_list, seq_scores_list = [seq_ids], [
seq_scores
] if isinstance(
seq_ids, paddle.fluid.core.LoDTensor) else (seq_ids, seq_scores)
for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list):
# How to parse the results:
# Suppose the lod of seq_ids is:
# [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
# then from lod[0]:
# there are 2 source sentences, beam width is 3.
# from lod[1]:
# the first source sentence has 3 hyps; the lengths are 12, 12, 16
# the second source sentence has 3 hyps; the lengths are 14, 13, 15
hyps = [[] for i in range(len(seq_ids.lod()[0]) - 1)]
scores = [[] for i in range(len(seq_scores.lod()[0]) - 1)]
for i in range(len(seq_ids.lod()[0]) -
1): # for each source sentence
start = seq_ids.lod()[0][i]
end = seq_ids.lod()[0][i + 1]
for j in range(end - start): # for each candidate
sub_start = seq_ids.lod()[1][start + j]
sub_end = seq_ids.lod()[1][start + j + 1]
hyps[i].append(" ".join([
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1])
if len(hyps[i]) >= InferTaskConfig.n_best:
break
except (StopIteration, fluid.core.EOFException):
# The data pass is over.
if args.use_py_reader:
pyreader.reset()
break
if __name__ == "__main__":
args = parse_args()
infer(args)
fast_infer(args)
......@@ -7,6 +7,43 @@ import paddle.fluid.layers as layers
from config import *
def wrap_layer_with_block(layer, block_idx):
"""
Make the layer definition support specifying a target block, so that from
within the current block we can add layers to other blocks. This makes it
easy to define caches outside the while loop.
"""
class BlockGuard(object):
"""
BlockGuard class.
BlockGuard class is used to switch to the given block in a program by
using the Python `with` keyword.
"""
def __init__(self, block_idx=None, main_program=None):
self.main_program = fluid.default_main_program(
) if main_program is None else main_program
self.old_block_idx = self.main_program.current_block().idx
self.new_block_idx = block_idx
def __enter__(self):
self.main_program.current_block_idx = self.new_block_idx
def __exit__(self, exc_type, exc_val, exc_tb):
self.main_program.current_block_idx = self.old_block_idx
if exc_type is not None:
return False # re-raise exception
return True
def layer_wrapper(*args, **kwargs):
with BlockGuard(block_idx):
return layer(*args, **kwargs)
return layer_wrapper
def position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
......@@ -35,7 +72,9 @@ def multi_head_attention(queries,
d_model,
n_head=1,
dropout_rate=0.,
cache=None):
cache=None,
gather_idx=None,
static_kv=False):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
......@@ -56,42 +95,86 @@ def multi_head_attention(queries,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
k = layers.fc(input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = layers.fc(input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
# For encoder-decoder attention in inference, insert the ops and vars
# into the global block to be used as cache across beam-search steps.
fc_layer = wrap_layer_with_block(
layers.fc, fluid.default_main_program().current_block()
.parent_idx) if cache is not None and static_kv else layers.fc
k = fc_layer(
input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = fc_layer(
input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads(x, n_head):
def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Reshape the last dimension of input tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
Reshape input tensors at the last dimension to split multi-heads
and then transpose. Specifically, transform the input tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] to the output tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
if n_head == 1:
return x
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
reshaped_q = layers.reshape(
x=queries, shape=[0, 0, n_head, d_key], inplace=True)
# permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3])
# For encoder-decoder attention in inference, insert the ops and vars
# into the global block to be used as cache across beam-search steps.
reshape_layer = wrap_layer_with_block(
layers.reshape,
fluid.default_main_program().current_block()
.parent_idx) if cache is not None and static_kv else layers.reshape
transpose_layer = wrap_layer_with_block(
layers.transpose,
fluid.default_main_program().current_block().
parent_idx) if cache is not None and static_kv else layers.transpose
reshaped_k = reshape_layer(
x=keys, shape=[0, 0, n_head, d_key], inplace=True)
k = transpose_layer(x=reshaped_k, perm=[0, 2, 1, 3])
reshaped_v = reshape_layer(
x=values, shape=[0, 0, n_head, d_value], inplace=True)
v = transpose_layer(x=reshaped_v, perm=[0, 2, 1, 3])
if cache is not None: # only for faster inference
if static_kv: # For encoder-decoder attention in inference
cache_k, cache_v = cache["static_k"], cache["static_v"]
# To init the static_k and static_v in cache.
# Maybe we could do this with a conditional op (if_else) at the first
# step of the while loop instead, but that might be less efficient.
static_cache_init = wrap_layer_with_block(
layers.assign,
fluid.default_main_program().current_block().parent_idx)
static_cache_init(k, cache_k)
static_cache_init(v, cache_v)
else: # For decoder self-attention in inference
cache_k, cache_v = cache["k"], cache["v"]
# gather cell states corresponding to selected parent
select_k = layers.gather(cache_k, index=gather_idx)
select_v = layers.gather(cache_v, index=gather_idx)
if not static_kv:
# For self attention in inference, use cache and concat time steps.
select_k = layers.concat([select_k, k], axis=2)
select_v = layers.concat([select_v, v], axis=2)
# update the cell states (caches) kept in the global block
layers.assign(select_k, cache_k)
layers.assign(select_v, cache_v)
return q, select_k, select_v
return q, k, v
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of input tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
......@@ -107,8 +190,7 @@ def multi_head_attention(queries,
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_key**-0.5)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
......@@ -122,23 +204,7 @@ def multi_head_attention(queries,
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = cache["k"] = layers.concat(
[layers.reshape(
cache["k"], shape=[0, 0, d_key * n_head]), k],
axis=1)
v = cache["v"] = layers.concat(
[layers.reshape(
cache["v"], shape=[0, 0, d_value * n_head]), v],
axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model,
dropout_rate)
......@@ -327,7 +393,8 @@ def decoder_layer(dec_input,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None):
cache=None,
gather_idx=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
......@@ -342,7 +409,8 @@ def decoder_layer(dec_input,
d_model,
n_head,
attention_dropout,
cache, )
cache=cache,
gather_idx=gather_idx)
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
......@@ -358,7 +426,10 @@ def decoder_layer(dec_input,
d_value,
d_model,
n_head,
attention_dropout, )
attention_dropout,
cache=cache,
gather_idx=gather_idx,
static_kv=True)
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
......@@ -393,7 +464,8 @@ def decoder(dec_input,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None):
caches=None,
gather_idx=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
......@@ -413,7 +485,8 @@ def decoder(dec_input,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i])
cache=None if caches is None else caches[i],
gather_idx=gather_idx)
dec_input = dec_output
dec_output = pre_process_layer(dec_output, preprocess_cmd,
prepostprocess_dropout)
......@@ -610,7 +683,8 @@ def wrap_decoder(trg_vocab_size,
weight_sharing,
dec_inputs=None,
enc_output=None,
caches=None):
caches=None,
gather_idx=None):
"""
The wrapper assembles together all needed layers for the decoder.
"""
......@@ -646,7 +720,8 @@ def wrap_decoder(trg_vocab_size,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches)
caches=caches,
gather_idx=gather_idx)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
......@@ -666,9 +741,43 @@ def wrap_decoder(trg_vocab_size,
return predict
def fast_decode(
def fast_decode(src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beam_size,
max_out_len,
eos_idx,
use_py_reader=False):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
if use_py_reader:
all_inputs, reader = make_all_py_reader_inputs(data_input_names)
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(fast_decoder_data_input_fields)
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
enc_output = wrap_encoder(
src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
......@@ -682,64 +791,60 @@ def fast_decode(
preprocess_cmd,
postprocess_cmd,
weight_sharing,
beam_size,
max_out_len,
eos_idx, ):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
enc_output = wrap_encoder(
src_vocab_size, max_in_len, n_layer, n_head, d_key, d_value, d_model,
d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
preprocess_cmd, postprocess_cmd, weight_sharing)
start_tokens, init_scores, trg_src_attn_bias = make_all_inputs(
fast_decoder_data_input_fields)
enc_inputs, )
start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs
def beam_search():
max_len = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=max_out_len)
shape=[1],
dtype=start_tokens.dtype,
value=max_out_len,
force_cpu=True)
step_idx = layers.fill_constant(
shape=[1], dtype=start_tokens.dtype, value=0)
cond = layers.less_than(x=step_idx, y=max_len)
shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True)
cond = layers.less_than(x=step_idx, y=max_len) # default force_cpu=True
while_op = layers.While(cond)
# array states will be stored for each step.
ids = layers.array_write(
layers.reshape(start_tokens, (-1, 1)), step_idx)
scores = layers.array_write(init_scores, step_idx)
# cell states will be overwritten at each step.
# caches contains states of history steps to reduce redundant
# computation in decoder.
caches = [{
"k": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype=enc_output.dtype,
value=0),
"v": layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, 0, d_model],
dtype=enc_output.dtype,
value=0)
} for i in range(n_layer)]
# caches contains states of history steps in decoder self-attention
# and static encoder output projections in encoder-decoder attention
# to reduce redundant computation.
caches = [
{
"k": # for self attention
layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, n_head, 0, d_key],
dtype=enc_output.dtype,
value=0),
"v": # for self attention
layers.fill_constant_batch_size_like(
input=start_tokens,
shape=[-1, n_head, 0, d_value],
dtype=enc_output.dtype,
value=0),
"static_k": # for encoder-decoder attention
layers.create_tensor(dtype=enc_output.dtype),
"static_v": # for encoder-decoder attention
layers.create_tensor(dtype=enc_output.dtype)
} for i in range(n_layer)
]
with while_op.block():
pre_ids = layers.array_read(array=ids, i=step_idx)
pre_ids = layers.reshape(pre_ids, (-1, 1, 1))
# Since beam_search_op doesn't enforce pre_ids' shape, we can do an
# inplace reshape here which actually changes the shape of pre_ids.
pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True)
pre_scores = layers.array_read(array=scores, i=step_idx)
# sequence_expand can gather sequences according to lod thus can be
# used in beam search to sift states corresponding to selected ids.
pre_src_attn_bias = layers.sequence_expand(
x=trg_src_attn_bias, y=pre_scores)
pre_enc_output = layers.sequence_expand(x=enc_output, y=pre_scores)
pre_caches = [{
"k": layers.sequence_expand(
x=cache["k"], y=pre_scores),
"v": layers.sequence_expand(
x=cache["v"], y=pre_scores),
} for cache in caches]
# gather cell states corresponding to selected parent
pre_src_attn_bias = layers.gather(
trg_src_attn_bias, index=parent_idx)
pre_pos = layers.elementwise_mul(
x=layers.fill_constant_batch_size_like(
input=pre_enc_output, # can't use pre_ids here since it has lod
input=pre_src_attn_bias, # can't use lod tensor here
value=1,
shape=[-1, 1, 1],
dtype=pre_ids.dtype),
......@@ -761,35 +866,33 @@ def fast_decode(
postprocess_cmd,
weight_sharing,
dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias),
enc_output=pre_enc_output,
caches=pre_caches)
enc_output=enc_output,
caches=caches,
gather_idx=parent_idx)
# intra-beam topK
topk_scores, topk_indices = layers.topk(
input=layers.softmax(logits), k=beam_size)
accu_scores = layers.elementwise_add(
x=layers.log(topk_scores),
y=layers.reshape(
pre_scores, shape=[-1]),
axis=0)
# beam_search op uses lod to distinguish branches.
x=layers.log(topk_scores), y=pre_scores, axis=0)
# beam_search op uses lod to differentiate branches.
topk_indices = layers.lod_reset(topk_indices, pre_ids)
selected_ids, selected_scores = layers.beam_search(
# top-K reduction across beams, which also contains special handling of
# finished beams and finished sentences (batch reduction)
selected_ids, selected_scores, gather_idx = layers.beam_search(
pre_ids=pre_ids,
pre_scores=pre_scores,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
end_id=eos_idx)
end_id=eos_idx,
return_parent_idx=True)
layers.increment(x=step_idx, value=1.0, in_place=True)
# update states
# cell states (caches) have been updated in wrap_decoder, so only
# the beam-search states need to be updated here.
layers.array_write(selected_ids, i=step_idx, array=ids)
layers.array_write(selected_scores, i=step_idx, array=scores)
layers.assign(gather_idx, parent_idx)
layers.assign(pre_src_attn_bias, trg_src_attn_bias)
layers.assign(pre_enc_output, enc_output)
for i in range(n_layer):
layers.assign(pre_caches[i]["k"], caches[i]["k"])
layers.assign(pre_caches[i]["v"], caches[i]["v"])
length_cond = layers.less_than(x=step_idx, y=max_len)
finish_cond = layers.logical_not(layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
......@@ -799,4 +902,4 @@ def fast_decode(
return finished_ids, finished_scores
finished_ids, finished_scores = beam_search()
return finished_ids, finished_scores
return finished_ids, finished_scores, reader if use_py_reader else None
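To illustrate the idea behind replacing `sequence_expand` with `gather` here (a standalone numpy sketch, not code from this commit): `beam_search` returns a `parent_idx` that names, for each surviving hypothesis, the row of the previous step's cache it descends from, so gathering those rows keeps the cached states aligned with the beams without expanding them by lod.
```python
# Standalone numpy illustration of gather-based beam-state selection.
import numpy as np

cache = np.array([[0.1], [0.2], [0.3], [0.4]])  # one cached row per hypothesis
parent_idx = np.array([0, 0, 3, 2])             # parents picked by beam search

selected = np.take(cache, parent_idx, axis=0)   # same effect as layers.gather
print(selected.ravel())                         # [0.1 0.1 0.4 0.3]
```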
......@@ -186,7 +186,7 @@ def main(args):
# Since the token number differs among devices, customize the gradient scale
# to use the token-averaged cost across devices; the gradient scale is
# `1 / token_number` for the average cost.
build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
# build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized
train_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
loss_name=avg_cost.name,
......
......@@ -10,7 +10,6 @@ import time
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.transpiler.details import program_to_code
import reader
from config import *
......@@ -258,7 +257,12 @@ def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
return data_input_dict, np.asarray([num_token], dtype="float32")
def prepare_data_generator(args, is_test, count, pyreader):
def prepare_data_generator(args,
is_test,
count,
pyreader,
py_reader_provider_wrapper,
place=None):
"""
Data generator wrapper for DataReader. If py_reader is used, set the data
provider for py_reader
......@@ -319,7 +323,7 @@ def prepare_data_generator(args, is_test, count, pyreader):
data_reader = split(data_reader, count)
if args.use_py_reader:
pyreader.decorate_tensor_provider(
py_reader_provider_wrapper(data_reader))
py_reader_provider_wrapper(data_reader, place))
data_reader = None
else: # Data generator for multi-devices
data_reader = stack(data_reader, count)
......@@ -357,7 +361,7 @@ def prepare_feed_dict_list(data_generator, init_flag, count):
return feed_dict_list if len(feed_dict_list) == count else None
def py_reader_provider_wrapper(data_reader):
def py_reader_provider_wrapper(data_reader, place):
"""
Data provider needed by fluid.layers.py_reader.
"""
......@@ -370,8 +374,7 @@ def py_reader_provider_wrapper(data_reader):
data, data_input_names, ModelHyperParams.eos_idx,
ModelHyperParams.eos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
total_dict = dict(data_input_dict.items())
yield [total_dict[item] for item in data_input_names]
yield [data_input_dict[item] for item in data_input_names]
return py_reader_provider
......@@ -406,7 +409,11 @@ def test_context(exe, train_exe, dev_count):
is_test=True)
test_prog = test_prog.clone(for_test=True)
test_data = prepare_data_generator(
args, is_test=True, count=dev_count, pyreader=pyreader)
args,
is_test=True,
count=dev_count,
pyreader=pyreader,
py_reader_provider_wrapper=py_reader_provider_wrapper)
exe.run(startup_prog) # to init pyreader for testing
if TrainTaskConfig.ckpt_path:
......@@ -477,7 +484,11 @@ def train_loop(exe,
logging.info("begin reader")
train_data = prepare_data_generator(
args, is_test=False, count=dev_count, pyreader=pyreader)
args,
is_test=False,
count=dev_count,
pyreader=pyreader,
py_reader_provider_wrapper=py_reader_provider_wrapper)
# For faster executor
exec_strategy = fluid.ExecutionStrategy()
......