# Multi-agent Learning for Neural Machine Translation (MAL)
## Introduction
MAL is the first multi-agent end-to-end joint learning framework, recently proposed by the Baidu Translate team. It significantly improves the learning ability of a single agent and sets new state-of-the-art results on several machine translation test sets. The work was accepted at EMNLP 2019: [Multi-agent Learning for Neural Machine Translation](https://www.aclweb.org/anthology/D19-1079.pdf). The overall architecture is shown below:
<p align="center">
<img src="images/arch.png" width = "340" height = "300" /> <br />
Overall architecture of MAL
</p>
This repo contains the PaddlePaddle implementation of MAL. The framework makes a few modifications on top of the paper and reaches 30.04 BLEU on the WMT 2014 English-German test set, surpassing the result reported in the paper and setting a new SOTA without changing the model architecture.
### Experimental Results
#### WMT English-German
| Models | En-De |
| :------------- | :---------: |
| [ConvS2S](https://pdfs.semanticscholar.org/bb3e/bc09b65728d6eced04929df72a006fb5210b.pdf) | 25.20 |
| [Transformer](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) | 28.40 |
| [Rel](https://www.aclweb.org/anthology/N18-2074.pdf) | 29.20 |
| [DynamicConv](https://openreview.net/pdf?id=SkVhlh09tX) | 29.70 |
| L2R | 28.88 |
| MAL-L2R | **30.04** |
## Running
### Environment
The runtime environment must satisfy the following requirements:
+ python 2.7
+ paddlepaddle-gpu (1.6.1)
+ CUDA, CuDNN and NCCL (CUDA 9.0, CuDNN v7 and NCCL 2.3.5)

Reproducing the WMT English-German results requires 56 V100 GPUs (32 GB each) and about 300K training steps.
### Data Preparation
Run the get_data.sh script to download the raw data and preprocess it into the file format required for training:
```
sh get_data.sh
```
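With the defaults in get_data.sh, the script roughly produces the following layout (a sketch; the exact shard file names are determined by gen_records.py):
```
wmt16_en_de/             # downloaded raw corpus, BPE vocabulary and newstest sets
data/vocab.source.32000  # shared source/target vocabulary (copied from vocab.bpe.32000)
data/                    # 100 sharded training record files written by gen_records.py
testset/testfile         # preprocessed newstest2014 input used at inference time
```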
### Training
Before running, configure the CUDA, CuDNN and NCCL paths in env/env.sh.
Run train.sh to start MAL. The produced models are saved under output. The job alternates between training and inference; you can call evaluate.sh on the files decoded during training to measure BLEU.
train.sh contains a distributed_args parameter that must be adapted to your machines: set nproc_per_node to the number of GPUs used per machine and selected_gpus to the GPU ids. For example, to train on an 8-card V100 machine with all 8 cards, set nproc_per_node to 8 and selected_gpus to 0, 1, 2, 3, 4, 5, 6, 7 (see the sketch after the commands below).
```
sh train.sh ip1,ip2,ip3... (IP addresses of the machines; do not use 127.0.0.1, use the result of hostname -i) &
sh evaluate.sh file_path (the decoded file, located under the output directory)
```
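As a rough sketch (the exact wrapper invocation inside train.sh may differ), distributed_args for a single machine with 8 V100 cards could look like:
```
# illustrative only: adjust nproc_per_node and selected_gpus to your hardware
distributed_args="--nproc_per_node 8 --selected_gpus 0,1,2,3,4,5,6,7"
```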
### Reproducing the Paper's Results
We provide the model trained with MAL on the English-German task; run infer.sh to obtain the final result. Because the test set has to be generated in advance, run get_data.sh before calling infer.sh, and make sure the CUDA and CuDNN paths are set up.
```
sh infer.sh
```
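infer.sh downloads the released checkpoint (checkpoint.best.tgz), decodes the prepared test set and then calls evaluate.sh on the decoded file, so the final score can be read from the generated .bleu files, e.g.:
```
grep "BLEU_cased" trans/*
```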
### Code Structure
The main code lives in the src folder:
- train.py: entry point for training
- infer.py: entry point for inference
- config.py: configuration of the models in this project, including the model variants and their hyperparameters
- reader.py: data reading utilities
- bleu_hook.py: BLEU computation script
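bleu_hook.py can also be invoked directly against the tokenized reference; the translation path below is only a placeholder for a post-processed decoding result:
```
python ./src/bleu_hook.py --reference wmt16_en_de/newstest2014.tok.de --translation path/to/decoded_file.trans.post
```
It writes the BLEU_uncased and BLEU_cased scores to a .bleu file next to the translation file.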
################################## User Define Configuration ###########################
################################## Data Configuration ##################################
#type of storage cluster
#storage_type = "afs"
#attention: files for training should be put on hdfs
##the list contains all file locations should be specified here
#fs_name = "afs://xingtian.afs.baidu.com:9902"
##If force_reuse_output_path is True ,paddle will remove output_path without check output_path exist
#force_reuse_output_path = "True"
##ugi of hdfs
#fs_ugi = "NLP_KM_Data,NLP_km_2018"
#the initial model path on hdfs used to init parameters
#init_model_path=
#the initial model path for pservers
#pserver_model_dir=
#which pass
#pserver_model_pass=
#example of above 2 args:
#if set pserver_model_dir to /app/paddle/models
#and set pserver_model_pass to 123
#then rank 0 will download model from /app/paddle/models/rank-00000/pass-00123/
#and rank 1 will download model from /app/paddle/models/rank-00001/pass-00123/, etc.
##train data path on hdfs
#train_data_path = "/user/NLP_KM_Data/gongweibao/transformer/paddle_training_data/train_data"
##test data path on hdfs, can be null or not set
#test_data_path = "/app/inf/mpi/bml-guest/paddle-platform/dataset/mnist/data/test/"
#the output directory on hdfs
#output_path = "/user/NLP_KM_Data/gongweibao/transformer/output"
#add datareader to thirdparty
#thirdparty_path = "/user/NLP_KM_Data/gongweibao/transformer/thirdparty"
FLAGS_rpc_deadline=3000000
#whl_name=paddlepaddle_ab57d3_post97_gpu-0.0.0-cp27-cp27mu-linux_x86_64.whl
#dataset_path=/user/NLP_KM_Data/gongweibao/transformer/small/paddle_training_data
PROFILE=0
FUSE=1
NCCL_COMM_NUM=2
NUM_THREADS=3
USE_HIERARCHICAL_ALLREDUCE=True
NUM_CARDS=8
NUM_EPOCHS=100
BATCH_SIZE=4096
#!/bin/bash
export BASE_PATH="$PWD"
#NCCL
export NCCL_DEBUG=INFO
export NCCL_IB_GID_INDEX=3
#export NCCL_IB_RETRY_CNT=0
#PADDLE
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=0.0
#Cudnn
#export FLAGS_cudnn_exhaustive_search=1
export LD_LIBRARY_PATH=/home/work/cuda-9.0/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/home/work/cudnn/cudnn_v7/cuda/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH="${BASE_PATH}/nccl_2.3.5/lib/:$LD_LIBRARY_PATH"
#proxy
unset https_proxy http_proxy
# GLOG
export GLOG_v=1
#export GLOG_vmodule=fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,threaded_ssa_graph_executor=10,backward_op_deps_pass=10,graph=10
export GLOG_logtostderr=1
#!/bin/bash
set -u
function check_iplist() {
if [ ${iplist:-} ]; then
#paddle envs
export PADDLE_PSERVER_PORT=9184
export PADDLE_TRAINER_IPS=${iplist}
#export PADDLE_CURRENT_IP=`/sbin/ip a | grep inet | grep global | awk '{print $2}' | sed 's/\/[0-9][0-9].*$//g'`
export PADDLE_CURRENT_IP=`hostname -i`
iparray=(${iplist//,/ })
for i in "${!iparray[@]}"; do
echo $i
if [ ${iparray[$i]} == ${PADDLE_CURRENT_IP} ]; then
export PADDLE_TRAINER_ID=$i
fi
done
export TRAINING_ROLE=TRAINER
#export PADDLE_PSERVERS=127.0.0.1
export PADDLE_INIT_TRAINER_COUNT=${#iparray[@]}
export PADDLE_PORT=${PADDLE_PSERVER_PORT}
export PADDLE_TRAINERS=${PADDLE_TRAINER_IPS}
export POD_IP=${PADDLE_CURRENT_IP}
export PADDLE_TRAINERS_NUM=${PADDLE_INIT_TRAINER_COUNT}
#is local
export PADDLE_IS_LOCAL=0
echo "****************************************************"
#paddle debug envs
export GLOG_v=0
export GLOG_logtostderr=1
#nccl debug envs
export NCCL_DEBUG=INFO
#export NCCL_IB_DISABLE=1
#export NCCL_IB_GDR_LEVEL=4
export NCCL_IB_GID_INDEX=3
#export NCCL_SOCKET_IFNAME=eth2
fi
}
#! /bin/sh
path=$1
python ./src/id2word.py data/vocab.source.32000 < ${path} > ${path}_word
head -n 3003 ${path}_word > ${path}_word_tmp
mv ${path}_word_tmp ${path}_word
cat ${path}_word | sed 's/@@ //g' > ${path}.trans.post
python ./src/bleu_hook.py --reference wmt16_en_de/newstest2014.tok.de --translation ${path}.trans.post
#! /bin/sh
tmp_dir=wmt16_en_de
data_dir=data
source_file=train.tok.clean.bpe.32000.en
target_file=train.tok.clean.bpe.32000.de
source_vocab_size=32000
target_vocab_size=32000
num_shards=100
if [ ! -d wmt16_en_de ]
then
mkdir wmt16_en_de
fi
wget https://baidu-nlp.bj.bcebos.com/EMNLP2019-MAL/wmt16_en_de.tar.gz -O wmt16_en_de/wmt16_en_de.tar.gz
tar -zxf wmt16_en_de/wmt16_en_de.tar.gz -C wmt16_en_de
if [ ! -d $data_dir ]
then
mkdir data
fi
if [ ! -d testset ]
then
mkdir testset
fi
cp wmt16_en_de/vocab.bpe.32000 data/vocab.source.32000
python ./src/gen_records.py --tmp_dir ${tmp_dir} --data_dir ${data_dir} --source_train_files ${source_file} --target_train_files ${target_file} --source_vocab_size ${source_vocab_size} --target_vocab_size ${target_vocab_size} --num_shards ${num_shards} --token True --onevocab True
python ./src/preprocess/gen_utils.py --vocab $data_dir/vocab.source.${source_vocab_size} --testset ${tmp_dir}/newstest2014.tok.bpe.32000.en --output ./testset/testfile
#! /bin/sh
export LD_LIBRARY_PATH=/home/work/cuda-9.0/lib64:/home/work/cudnn/cudnn_v7/cuda/lib64:/home/work/cuda-9.0/extras/CUPTI/lib64:$LD_LIBRARY_PATH
wget https://baidu-nlp.bj.bcebos.com/EMNLP2019-MAL/checkpoint.best.tgz
tar -zxf checkpoint.best.tgz
infer(){
CUDA_VISIBLE_DEVICES=$1 python -u src/infer.py \
--val_file_pattern $3 \
--vocab_size $4 \
--special_token '<s>' '<e>' '<unk>' \
--use_mem_opt True \
--use_delay_load True \
--infer_batch_size 16 \
--decode_alpha 0.3 \
d_model 1024 \
d_inner_hid 4096 \
n_head 16 \
prepostprocess_dropout 0.0 \
attention_dropout 0.0 \
relu_dropout 0.0 \
model_path $2 \
beam_size 4 \
max_out_len 306 \
max_length 256
}
infer 0 checkpoint.best testset/testfile 37007
sh evaluate.sh trans/forward_checkpoint.best
grep "BLEU_cased" trans/*
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import argparse
def str2bool(v):
"""
argparse cannot parse strings such as "True"/"False" into Python
booleans directly, so this helper converts them
"""
return v.lower() in ("true", "t", "1")
class ArgumentGroup(object):
"""
ArgumentGroup
"""
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, positional_arg=False, **kwargs):
"""
add_arg
"""
prefix = "" if positional_arg else "--"
type = str2bool if type == bool else type
self._group.add_argument(
prefix + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def print_arguments(args):
"""
print_arguments
"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def inv_arguments(args):
"""
inv_arguments
"""
print('[Warning] Only keyword argument type is supported.')
args_list = []
for arg, value in sorted(six.iteritems(vars(args))):
args_list.extend(['--' + str(arg), str(value)])
return args_list
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layer_helper import LayerHelper as LayerHelper
from config import ModelHyperParams  # needed below for dropout_seed
def generate_relative_positions_matrix(length, max_relative_position, cache=False):
if not cache:
range_vec = layers.range(0, length, 1, 'int32')
range_vec.stop_gradient = True
shapes = layers.shape(range_vec)
range_vec = layers.reshape(range_vec, shape=[1, shapes[0]])
range_mat = layers.expand(range_vec, [shapes[0], 1])
distance_mat = range_mat - layers.transpose(range_mat, [1, 0])
else:
distance_mat = layers.range(-1 * length+1, 1, 1, 'int32')
distance_mat.stop_gradient = True
shapes = layers.shape(distance_mat)
distance_mat = layers.reshape(distance_mat, [1, shapes[0]])
distance_mat_clipped = layers.clip(layers.cast(distance_mat, dtype="float32"), float(-max_relative_position), float(max_relative_position))
final_mat = layers.cast(distance_mat_clipped, dtype = 'int32') + max_relative_position
return final_mat
def generate_relative_positions_embeddings(length, depth, max_relative_position, name, cache=False):
relative_positions_matrix = generate_relative_positions_matrix(
length, max_relative_position, cache=cache)
y = layers.reshape(relative_positions_matrix, [-1])
y.stop_gradient = True
vocab_size = max_relative_position * 2 + 1
#embeddings_table = layers.create_parameter(shape=[vocab_size, depth], dtype='float32', default_initializer=fluid.initializer.Constant(1.2345), name=name)
embeddings_table = layers.create_parameter(shape=[vocab_size, depth], dtype='float32', name=name)
#layers.Print(embeddings_table, message = "embeddings_table=====")
embeddings_1 = layers.gather(embeddings_table, y)
embeddings = layers.reshape(embeddings_1, [-1, length, depth])
return embeddings
def _relative_attention_inner(q, k, v, transpose):
batch_size = layers.shape(q)[0]
heads = layers.shape(q)[1]
length = layers.shape(q)[2]
xy_matmul = layers.matmul(q, k, transpose_y=transpose)
x_t = layers.transpose(q, [2, 0, 1, 3])
x_t_r = layers.reshape(x_t, [length, batch_size * heads, -1])
x_tz_matmul = layers.matmul(x_t_r, v, transpose_y = transpose)
x_tz_matmul_r = layers.reshape(x_tz_matmul, [length, batch_size, heads, -1])
x_tz_matmul_r_t = layers.transpose(x_tz_matmul_r, [1, 2, 0, 3])
return xy_matmul + x_tz_matmul_r_t
def _dot_product_relative(q, k, v, bias, dropout=0.1, cache=None, params_type="normal"):
depth_constant = int(k.shape[3])
heads = layers.shape(k)[1]
length = layers.shape(k)[2]
max_relative_position = 4
pre_name = "relative_positions_"
if params_type == "fixed":
pre_name = "fixed_relative_positions_"
elif params_type == "new":
pre_name = "new_relative_positions_"
relations_keys = generate_relative_positions_embeddings(
length, depth_constant, max_relative_position, name=pre_name + "keys",
cache=cache is not None)
relations_values = generate_relative_positions_embeddings(
length, depth_constant, max_relative_position,
name = pre_name + "values",
cache=cache is not None)
logits = _relative_attention_inner(q, k, relations_keys, True)
if bias is not None: logits += bias
weights = layers.softmax(logits, name = "attention_weights")
weights = layers.dropout(weights, dropout_prob=float(dropout))
output = _relative_attention_inner(weights, v, relations_values, False)
return output
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key ** -0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.matmul(weights, v)
return out
if __name__ == "__main__":
batch_size = 2
heads = 8
length = 5
depth = 3
cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)
startup_prog = fluid.Program()
train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard("forward"):
x = layers.reshape(layers.cast(layers.range(0, 18, 1, "int32"), dtype = "float32"), shape =[-1, 3, 3])
y = layers.reshape(layers.cast(layers.range(0, 2, 1, "int32"), dtype = "float32"), shape =[-1, 1])
z = x * y
exe.run(startup_prog)
outs = exe.run(train_prog, fetch_list=[x, y, z])
print(outs[0])
print(outs[1])
print(outs[2])
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers as layers
INF = 1. * 1e9
class BeamSearch(object):
"""
beam_search class
"""
def __init__(self, beam_size, batch_size, alpha, vocab_size, hidden_size):
self.beam_size = beam_size
self.batch_size = batch_size
self.alpha = alpha
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.gather_top2k_append_index = layers.range(0, 2 * self.batch_size * beam_size, 1, 'int64') // \
(2 * self.beam_size) * (self.beam_size)
self.gather_topk_append_index = layers.range(0, self.batch_size * beam_size, 1, 'int64') // \
self.beam_size * (2 * self.beam_size)
self.gather_finish_topk_append_index = layers.range(0, self.batch_size * beam_size, 1, 'int64') // \
self.beam_size * (3 * self.beam_size)
self.eos_id = layers.fill_constant([self.batch_size, 2 * self.beam_size], 'int64', value=1)
self.get_alive_index = layers.range(0, self.batch_size, 1, 'int64') * self.beam_size
def gather_cache(self, kv_caches, select_id):
"""
gather cache
"""
for index in xrange(len(kv_caches)):
kv_cache = kv_caches[index]
select_k = layers.gather(kv_cache['k'], [select_id])
select_v = layers.gather(kv_cache['v'], [select_id])
layers.assign(select_k, kv_caches[index]['k'])
layers.assign(select_v, kv_caches[index]['v'])
# topk_seq, topk_scores, topk_log_probs, topk_finished, cache
def compute_topk_scores_and_seq(self, sequences, scores, scores_to_gather, flags, pick_finish=False, cache=None):
"""
compute_topk_scores_and_seq
"""
topk_scores, topk_indexes = layers.topk(scores, k=self.beam_size) #[batch_size, beam_size]
if not pick_finish:
flat_topk_indexes = layers.reshape(topk_indexes, [-1]) + self.gather_topk_append_index
flat_sequences = layers.reshape(sequences, [2 * self.batch_size * self.beam_size, -1])
else:
flat_topk_indexes = layers.reshape(topk_indexes, [-1]) + self.gather_finish_topk_append_index
flat_sequences = layers.reshape(sequences, [3 * self.batch_size * self.beam_size, -1])
topk_seq = layers.gather(flat_sequences, [flat_topk_indexes])
topk_seq = layers.reshape(topk_seq, [self.batch_size, self.beam_size, -1])
flat_flags = layers.reshape(flags, [-1])
topk_flags = layers.gather(flat_flags, [flat_topk_indexes])
topk_flags = layers.reshape(topk_flags, [-1, self.beam_size])
flat_scores = layers.reshape(scores_to_gather, [-1])
topk_gathered_scores = layers.gather(flat_scores, [flat_topk_indexes])
topk_gathered_scores = layers.reshape(topk_gathered_scores, [-1, self.beam_size])
if cache:
self.gather_cache(cache, flat_topk_indexes)
return topk_seq, topk_gathered_scores, topk_flags, cache
def grow_topk(self, i, logits, alive_seq, alive_log_probs, cache, enc_output, enc_bias):
"""
grow_topk
"""
logits = layers.reshape(logits, [self.batch_size, self.beam_size, -1])
candidate_log_probs = layers.log(layers.softmax(logits, axis=2))
log_probs = candidate_log_probs + layers.unsqueeze(alive_log_probs, axes=[2])
base_1 = layers.cast(i, 'float32') + 6.0
base_1 /= 6.0
length_penalty = layers.pow(base_1, self.alpha)
#length_penalty = layers.pow(((5.0 + layers.cast(i+1, 'float32')) / 6.0), self.alpha)
curr_scores = log_probs / length_penalty
flat_curr_scores = layers.reshape(curr_scores, [self.batch_size, self.beam_size * self.vocab_size])
topk_scores, topk_ids = layers.topk(flat_curr_scores, k=self.beam_size * 2)
topk_log_probs = topk_scores * length_penalty
select_beam_index = topk_ids // self.vocab_size
select_id = topk_ids % self.vocab_size
#layers.Print(select_id, message="select_id", summarize=1024)
#layers.Print(topk_scores, message="topk_scores", summarize=10000000)
flat_select_beam_index = layers.reshape(select_beam_index, [-1]) + self.gather_top2k_append_index
topk_seq = layers.gather(alive_seq, [flat_select_beam_index])
topk_seq = layers.reshape(topk_seq, [self.batch_size, 2 * self.beam_size, -1])
#concat with current ids
topk_seq = layers.concat([topk_seq, layers.unsqueeze(select_id, axes=[2])], axis=2)
topk_finished = layers.cast(layers.equal(select_id, self.eos_id), 'float32')
#gather cache
self.gather_cache(cache, flat_select_beam_index)
#topk_seq: [batch_size, 2*beam_size, i+1]
#topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
return topk_seq, topk_log_probs, topk_scores, topk_finished, cache
def grow_alive(self, curr_seq, curr_scores, curr_log_probs, curr_finished, cache):
"""
grow_alive
"""
finish_float_flag = layers.cast(curr_finished, 'float32')
finish_float_flag = finish_float_flag * -INF
curr_scores += finish_float_flag
return self.compute_topk_scores_and_seq(curr_seq, curr_scores,
curr_log_probs, curr_finished, cache=cache)
def grow_finished(self, i, finished_seq, finished_scores, finished_flags, curr_seq,
curr_scores, curr_finished):
"""
grow_finished
"""
finished_seq = layers.concat([finished_seq,
layers.fill_constant([self.batch_size, self.beam_size, 1], dtype='int64', value=0)],
axis=2)
curr_scores = curr_scores + (1.0 - layers.cast(curr_finished, 'float32')) * -INF
curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
curr_finished_scores = layers.concat([finished_scores, curr_scores], axis=1)
curr_finished_flags = layers.concat([finished_flags, curr_finished], axis=1)
return self.compute_topk_scores_and_seq(curr_finished_seq, curr_finished_scores,
curr_finished_scores, curr_finished_flags,
pick_finish=True)
def inner_func(self, i, logits, alive_seq, alive_log_probs, finished_seq, finished_scores,
finished_flags, cache, enc_output, enc_bias):
"""
inner_func
"""
topk_seq, topk_log_probs, topk_scores, topk_finished, cache = self.grow_topk(
i, logits, alive_seq, alive_log_probs, cache, enc_output, enc_bias)
alive_seq, alive_log_probs, _, cache = self.grow_alive(
topk_seq, topk_scores, topk_log_probs, topk_finished, cache)
#layers.Print(alive_seq, message="alive_seq", summarize=1024)
finished_seq, finished_scores, finished_flags, _ = self.grow_finished(
i, finished_seq, finished_scores, finished_flags, topk_seq, topk_scores, topk_finished)
return alive_seq, alive_log_probs, finished_seq, finished_scores, finished_flags, cache
def is_finished(self, step_idx, source_length, alive_log_probs, finished_scores, finished_in_finished):
"""
is_finished
"""
base_1 = layers.cast(source_length, 'float32') + 55.0
base_1 /= 6.0
max_length_penalty = layers.pow(base_1, self.alpha)
flat_alive_log_probs = layers.reshape(alive_log_probs, [-1])
lower_bound_alive_scores_1 = layers.gather(flat_alive_log_probs, [self.get_alive_index])
lower_bound_alive_scores = lower_bound_alive_scores_1 / max_length_penalty
lowest_score_of_finished_in_finish = layers.reduce_min(finished_scores * finished_in_finished, dim=1)
finished_in_finished = layers.cast(finished_in_finished, 'bool')
lowest_score_of_finished_in_finish += \
((1.0 - layers.cast(layers.reduce_any(finished_in_finished, 1), 'float32')) * -INF)
#print lowest_score_of_finished_in_finish
bound_is_met = layers.reduce_all(layers.greater_than(lowest_score_of_finished_in_finish,
lower_bound_alive_scores))
decode_length = source_length + 50
length_cond = layers.less_than(x=step_idx, y=decode_length)
return layers.logical_and(x=layers.logical_not(bound_is_met), y=length_cond)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
import collections
import math
import os
import re
import sys
import time
import unicodedata
# Dependency imports
import numpy as np
import six
from six.moves import range
from six.moves import zip
from preprocess import text_encoder
def _get_ngrams(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams up to max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus,
translation_corpus,
max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
Returns:
BLEU score.
"""
reference_length = 0
translation_length = 0
bp = 1.0
geo_mean = 0
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
precisions = []
for (references, translations) in zip(reference_corpus, translation_corpus):
reference_length += len(references)
translation_length += len(translations)
ref_ngram_counts = _get_ngrams(references, max_order)
translation_ngram_counts = _get_ngrams(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram)-1] += translation_ngram_counts[ngram]
precisions = [0] * max_order
smooth = 1.0
for i in range(0, max_order):
if possible_matches_by_order[i] > 0:
precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
else:
precisions[i] = 0.0
if max(precisions) > 0:
p_log_sum = sum(math.log(p) for p in precisions if p)
geo_mean = math.exp(p_log_sum / max_order)
if use_bp:
ratio = (translation_length + 1e-6) / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
class UnicodeRegex(object):
"""Ad-hoc hack to recognize all punctuation and symbols."""
def __init__(self):
punctuation = self.property_chars("P")
self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
def property_chars(self, prefix):
"""
get unicode of specified chars
"""
return "".join(six.unichr(x) for x in range(sys.maxunicode)
if unicodedata.category(six.unichr(x)).startswith(prefix))
uregex = UnicodeRegex()
def bleu_tokenize(string):
r"""Tokenize a string following the official BLEU implementation.
See https://github.com/moses-smt/mosesdecoder/blob/master/scripts/generic/mteval-v14.pl#L954-L983
In our case, the input string is expected to be just one line
and no HTML entities de-escaping is needed.
So we just tokenize on punctuation and symbols,
except when a punctuation is preceded and followed by a digit
(e.g. a comma/dot as a thousand/decimal separator).
Note that a number (e.g. a year) followed by a dot at the end of sentence
is NOT tokenized,
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
does not match this case (unless we add a space after each sentence).
However, this error is already in the original mteval-v14.pl
and we want to be consistent with it.
Args:
string: the input string
Returns:
a list of tokens
"""
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
string = uregex.symbol_re.sub(r" \1 ", string)
return string.split()
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
"""Compute BLEU for two files (reference and hypothesis translation)."""
ref_lines = text_encoder.native_to_unicode(
open(ref_filename, "r").read()).splitlines()
hyp_lines = text_encoder.native_to_unicode(
open(hyp_filename, "r").read()).splitlines()
assert len(ref_lines) == len(hyp_lines)
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
return compute_bleu(ref_tokens, hyp_tokens)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Calc BLEU.")
parser.add_argument(
"--reference",
type=str,
required=True,
help="path of reference.")
parser.add_argument(
"--translation",
type=str,
required=True,
help="path of translation.")
args = parser.parse_args()
bleu_uncased = 100 * bleu_wrapper(args.reference, args.translation, case_sensitive=False)
bleu_cased = 100 * bleu_wrapper(args.reference, args.translation, case_sensitive=True)
f = open("%s.bleu" % args.translation, 'w')
f.write("BLEU_uncased = %6.2f\n" % (bleu_uncased))
f.write("BLEU_cased = %6.2f\n" % (bleu_cased))
f.close()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
"""
TrainTaskConfig
"""
# support both CPU and GPU now.
use_gpu = True
# the epoch number to train.
pass_num = 30
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied by the LearningRateScheduler-
# derived learning rate to get the final learning rate.
learning_rate = 4.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
# the directory for saving trained models.
model_dir = "trained_models"
# the directory for saving checkpoints.
ckpt_dir = "trained_ckpts"
# the directory for loading checkpoint.
# If provided, continue training from the checkpoint.
ckpt_path = None
# the parameter to initialize the learning rate scheduler.
# It should be provided if use checkpoints, since the checkpoint doesn't
# include the training step counter currently.
start_step = 0
# the frequency to save trained models.
save_freq = 5000
# the frequency to copy unfixed parameters to fixed parameters
fixed_freq = 50000
beta = 0.7
class InferTaskConfig(object):
"""
InferTaskConfig
"""
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 10
# the parameters for beam search.
beam_size = 5
max_out_len = 256
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = True
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
decode_alpha = 0.6
class ModelHyperParams(object):
"""
ModelHyperParams
"""
# The following vocabulary-related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionary
trg_vocab_size = 10000
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences deciding the size of position encoding table.
max_length = 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 1024
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 4096
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 16
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# random seed used in dropout for CE.
dropout_seed = None
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be the same for weight sharing.
weight_sharing = True
embedding_sharing = True
class DenseModelHyperParams(object):
"""
DenseModelHyperParams
"""
# The following vocabulary-related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 37007
# size of target word dictionary
trg_vocab_size = 37007
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences deciding the size of position encoding table.
max_length = 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
enc_n_layer = 25
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# random seed used in dropout for CE.
dropout_seed = None
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = True
embedding_sharing = True
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
"dense_src_slf_attn_bias": [(batch_size, DenseModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
"reverse_trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
"dense_trg_slf_attn_bias": [(batch_size, DenseModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
"dense_trg_src_attn_bias": [(batch_size, DenseModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
"reverse_lbl_word": [(batch_size * seq_len, 1), "int64"],
"eos_position": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of padding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32"],
# This input is used in beam-search decoder for the first gather
# (cell states update)
"init_idx": [(batch_size, ), "int32"],
"decode_length": [(batch_size, ), "int64"],
}
# Names of word embedding table which might be reused for weight sharing.
dense_word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
dense_pos_enc_param_names = (
"dense_src_pos_enc_table",
"dense_trg_pos_enc_table", )
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
# separated inputs for different usages.
dense_encoder_data_input_fields = (
"src_word",
"src_pos",
"dense_src_slf_attn_bias", )
decoder_data_input_fields = (
"trg_word",
"reverse_trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
dense_decoder_data_input_fields = (
"trg_word",
"reverse_trg_word",
"trg_pos",
"dense_trg_slf_attn_bias",
"dense_trg_src_attn_bias",
"enc_output", )
label_data_input_fields = (
"lbl_word",
"lbl_weight",
"reverse_lbl_word",
"eos_position")
dense_bias_input_fields = (
"dense_src_slf_attn_bias",
"dense_trg_slf_attn_bias",
"dense_trg_src_attn_bias")
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"dense_src_slf_attn_bias", )
fast_decoder_data_input_fields = (
"decode_length", )
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layer_helper import LayerHelper as LayerHelper
from config import *
from beam_search import BeamSearch
INF = 1. * 1e5
def layer_norm(x, begin_norm_axis=1, epsilon=1e-6, param_attr=None, bias_attr=None):
"""
layer_norm
"""
helper = LayerHelper('layer_norm', **locals())
mean = layers.reduce_mean(x, dim=begin_norm_axis, keep_dim=True)
shift_x = layers.elementwise_sub(x=x, y=mean, axis=0)
variance = layers.reduce_mean(layers.square(shift_x), dim=begin_norm_axis, keep_dim=True)
r_stdev = layers.rsqrt(variance + epsilon)
norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0)
param_shape = [reduce(lambda x, y: x * y, norm_x.shape[begin_norm_axis:])]
param_dtype = norm_x.dtype
scale = helper.create_parameter(
attr=param_attr,
shape=param_shape,
dtype=param_dtype,
default_initializer=fluid.initializer.Constant(1.))
bias = helper.create_parameter(
attr=bias_attr,
shape=param_shape,
dtype=param_dtype,
is_bias=True,
default_initializer=fluid.initializer.Constant(0.))
out = layers.elementwise_mul(x=norm_x, y=scale, axis=-1)
out = layers.elementwise_add(x=out, y=bias, axis=-1)
return out
def dense_position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(np.arange(
num_timescales) * -log_timescale_increment)
#num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
attention_type="dot_product",):
"""
Multi-Head Attention. Note that attn_bias is added to the logits before
computing the softmax activation to mask certain selected positions so
that they will not be considered in the attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
k = layers.fc(input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = layers.fc(input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads(x, n_head):
"""
Reshape the last dimension of input tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
if n_head == 1:
return x
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
# permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of input tensor x
so that it becomes one dimension, which is the reverse of __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key ** -0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=DenseModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = layers.concat([cache['k'], k], axis=1)
v = layers.concat([cache['v'], v], axis=1)
layers.assign(k, cache['k'])
layers.assign(v, cache['v'])
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key, #d_model,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
bias_attr=False,
num_flatten_dims=2)
return proj_out
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
seed=DenseModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
epsilon=1e-6,
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
seed=DenseModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def prepare_encoder_decoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0.,
word_emb_param_name=None,
training=True,
pos_enc_param_name=None,
is_src=True,
params_type="normal"):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
assert params_type == "fixed" or params_type == "normal" or params_type == "new"
pre_name = "densedense"
if params_type == "fixed":
pre_name = "fixed_densefixed_dense"
elif params_type == "new":
pre_name = "new_densenew_dense"
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=DenseModelHyperParams.bos_idx, # set embedding of bos to 0
param_attr=fluid.ParamAttr(
name = pre_name + word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim ** -0.5)))#, is_sparse=True)
if not is_src and training:
src_word_emb = layers.pad(src_word_emb, [0, 0, 1, 0, 0, 0])
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
trainable=False, name = pre_name + pos_enc_param_name))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
return layers.dropout(
enc_input,
dropout_prob=dropout_rate,
seed=DenseModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train') if dropout_rate else enc_input
prepare_encoder = partial(
prepare_encoder_decoder, pos_enc_param_name="src_pos_enc_table", is_src=True)
prepare_decoder = partial(
prepare_encoder_decoder, pos_enc_param_name="trg_pos_enc_table", is_src=False)
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""The encoder layers that can be stacked to form a deep encoder.
This module consists of a multi-head (self) attention followed by
position-wise feed-forward networks, both of which are accompanied
by the post_process_layer to add residual connection, layer normalization
and dropout.
"""
attn_output = multi_head_attention(
pre_process_layer(enc_input, preprocess_cmd,
prepostprocess_dropout), None, None, attn_bias, d_key,
d_value, d_model, n_head, attention_dropout)
attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
prepostprocess_dropout)
ffd_output = positionwise_feed_forward(
pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
d_inner_hid, d_model, relu_dropout)
return post_process_layer(attn_output, ffd_output, postprocess_cmd,
prepostprocess_dropout)
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
stack_layer_norm = []
bottom_embedding_output = pre_process_layer(enc_input, preprocess_cmd, prepostprocess_dropout)
stack_layer_norm.append(bottom_embedding_output)
#zeros = layers.zeros_like(enc_input)
#ones_flag = layers.equal(zeros, zeros)
#ones = layers.cast(ones_flag, 'float32')
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd, )
enc_output_2 = pre_process_layer(enc_output, preprocess_cmd, prepostprocess_dropout)
stack_layer_norm.append(enc_output_2)
pre_output = bottom_embedding_output
for index in xrange(1, len(stack_layer_norm)):
pre_output = pre_output + stack_layer_norm[index]
# pre_mean
enc_input = pre_output / len(stack_layer_norm)
enc_output = pre_process_layer(enc_output, preprocess_cmd,
prepostprocess_dropout)
return enc_output
def decoder_layer(dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(
pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
None,
None,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache, )
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
enc_attn_output = multi_head_attention(
pre_process_layer(slf_attn_output, preprocess_cmd,
prepostprocess_dropout),
enc_output,
enc_output,
dec_enc_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
ffd_output = positionwise_feed_forward(
pre_process_layer(enc_attn_output, preprocess_cmd,
prepostprocess_dropout),
d_inner_hid,
d_model,
relu_dropout, )
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout, )
return dec_output
def decoder(dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i])
dec_input = dec_output
dec_output = pre_process_layer(dec_output, preprocess_cmd,
prepostprocess_dropout)
return dec_output
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
inputs = []
for input_field in input_fields:
input_var = layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
def make_all_py_reader_inputs(input_fields, is_test=False):
"""
Define the input data layers for the transformer model.
"""
reader = layers.py_reader(
capacity=20,
name="test_reader" if is_test else "train_reader",
shapes=[dense_input_descs[input_field][0] for input_field in input_fields],
dtypes=[dense_input_descs[input_field][1] for input_field in input_fields],
lod_levels=[
dense_input_descs[input_field][2]
if len(dense_input_descs[input_field]) == 3 else 0
for input_field in input_fields
], use_double_buffer=True)
return layers.read_file(reader), reader
def dense_transformer(src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
enc_n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
label_smooth_eps,
use_py_reader=False,
is_test=False,
params_type="normal",
all_data_inputs=None):
"""
transformer
"""
if embedding_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields + dense_bias_input_fields
if use_py_reader:
all_inputs = all_data_inputs
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(decoder_data_input_fields[:-1])
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
real_label = all_inputs[enc_inputs_len + dec_inputs_len]
weights = all_inputs[enc_inputs_len + dec_inputs_len + 1]
reverse_label = all_inputs[enc_inputs_len + dec_inputs_len + 2]
enc_inputs[2] = all_inputs[-3] # dense_src_slf_attn_bias
dec_inputs[3] = all_inputs[-2] # dense_trg_slf_attn_bias
dec_inputs[4] = all_inputs[-1] # dense_trg_src_attn_bias
enc_output = wrap_encoder(
src_vocab_size,
max_length,
enc_n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
enc_inputs,
params_type=params_type)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
dec_inputs,
enc_output, is_train = True if not is_test else False,
params_type=params_type)
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
if label_smooth_eps:
label = layers.one_hot(input=real_label, depth=trg_vocab_size)
label = label * (1 - label_smooth_eps) + (1 - label) * (
label_smooth_eps / (trg_vocab_size - 1))
label.stop_gradient = True
else:
label = real_label
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
sum_cost.persistable = True
token_num = layers.reduce_sum(weights)
token_num.persistable = True
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
sen_count = layers.shape(dec_inputs[0])[0]
batch_predict = layers.reshape(predict, shape = [sen_count, -1, DenseModelHyperParams.trg_vocab_size])
batch_label = layers.reshape(real_label, shape=[sen_count, -1])
batch_weights = layers.reshape(weights, shape=[sen_count, -1, 1])
return sum_cost, avg_cost, token_num, batch_predict, cost, sum_cost, batch_label, batch_weights
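# A minimal numpy sketch (not used by the model code) of the label smoothing applied above:
# the gold class keeps 1 - eps of the probability mass and the remaining eps is spread
# uniformly over the other vocab - 1 classes, so every row still sums to 1.
def _label_smoothing_sketch(eps=0.1, vocab=5, gold_id=2):
    import numpy as np
    one_hot = np.eye(vocab, dtype="float32")[[gold_id]]
    smoothed = one_hot * (1. - eps) + (1. - one_hot) * (eps / (vocab - 1))
    return smoothed  # e.g. [[0.025, 0.025, 0.9, 0.025, 0.025]]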
def wrap_encoder(src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
enc_inputs=None,
params_type="normal"):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_all_inputs(
encoder_data_input_fields)
else:
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=dense_word_emb_param_names[0],
params_type=params_type)
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd, )
return enc_output
def wrap_decoder(trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
dec_inputs=None,
enc_output=None,
caches=None, is_train=True, params_type="normal"):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, reverse_trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
make_all_inputs(dense_decoder_data_input_fields)
else:
trg_word, reverse_trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=dense_word_emb_param_names[0]
if embedding_sharing else dense_word_emb_param_names[1],
training=is_train,
params_type=params_type)
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
assert params_type == "fixed" or params_type == "normal" or params_type == "new"
pre_name = "densedense"
if params_type == "fixed":
pre_name = "fixed_densefixed_dense"
elif params_type == "new":
pre_name = "new_densenew_dense"
if weight_sharing and embedding_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
pre_name + dense_word_emb_param_names[0]),
transpose_y=True)
elif weight_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
pre_name + dense_word_emb_param_names[1]),
transpose_y=True)
else:
predict = layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False)
#layers.Print(predict, message="logits", summarize=20)
if dec_inputs is None:
# Return probs for independent decoder program.
predict = layers.softmax(predict)
return predict
def get_enc_bias(source_inputs):
"""
get_enc_bias
"""
source_inputs = layers.cast(source_inputs, 'float32')
emb_sum = layers.reduce_sum(layers.abs(source_inputs), dim=-1)
zero = layers.fill_constant([1], 'float32', value=0)
bias = layers.cast(layers.equal(emb_sum, zero), 'float32') * -1e9
return layers.unsqueeze(layers.unsqueeze(bias, axes=[1]), axes=[1])
def dense_fast_decode(
src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
enc_n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
beam_size,
batch_size,
max_out_len,
decode_alpha,
eos_idx,
params_type="normal"):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
assert params_type == "normal" or params_type == "new" or params_type == "fixed"
data_input_names = dense_encoder_data_input_fields + fast_decoder_data_input_fields
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(fast_decoder_data_input_fields)
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
enc_output = wrap_encoder(src_vocab_size, max_in_len, enc_n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing, embedding_sharing, enc_inputs, params_type=params_type)
enc_bias = get_enc_bias(enc_inputs[0])
source_length, = dec_inputs
def beam_search(enc_output, enc_bias, source_length):
"""
beam_search
"""
max_len = layers.fill_constant(
shape=[1], dtype='int64', value=max_out_len)
step_idx = layers.fill_constant(
shape=[1], dtype='int64', value=0)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
caches_batch_size = batch_size * beam_size
init_score = np.zeros([1, beam_size]).astype('float32')
init_score[:, 1:] = -INF
initial_log_probs = layers.assign(init_score)
alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
# alive seq [batch_size, beam_size, 1]
initial_ids = layers.zeros([batch_size, 1, 1], 'float32')
alive_seq = layers.expand(initial_ids, [1, beam_size, 1])
alive_seq = layers.cast(alive_seq, 'int64')
enc_output = layers.unsqueeze(enc_output, axes=[1])
enc_output = layers.expand(enc_output, [1, beam_size, 1, 1])
enc_output = layers.reshape(enc_output, [caches_batch_size, -1, d_model])
tgt_src_attn_bias = layers.unsqueeze(enc_bias, axes=[1])
tgt_src_attn_bias = layers.expand(tgt_src_attn_bias, [1, beam_size, n_head, 1, 1])
enc_bias_shape = layers.shape(tgt_src_attn_bias)
tgt_src_attn_bias = layers.reshape(tgt_src_attn_bias, [-1, enc_bias_shape[2],
enc_bias_shape[3], enc_bias_shape[4]])
beam_search = BeamSearch(beam_size, batch_size, decode_alpha, trg_vocab_size, d_model)
caches = [{
"k": layers.fill_constant(
shape=[caches_batch_size, 0, d_model],
dtype=enc_output.dtype,
value=0),
"v": layers.fill_constant(
shape=[caches_batch_size, 0, d_model],
dtype=enc_output.dtype,
value=0)
} for i in range(n_layer)]
finished_seq = layers.zeros_like(alive_seq)
finished_scores = layers.fill_constant([batch_size, beam_size],
dtype='float32', value=-INF)
finished_flags = layers.fill_constant([batch_size, beam_size],
dtype='float32', value=0)
with while_op.block():
pos = layers.fill_constant([caches_batch_size, 1, 1], dtype='int64', value=1)
pos = layers.elementwise_mul(pos, step_idx, axis=0)
alive_seq_1 = layers.reshape(alive_seq, [caches_batch_size, -1])
alive_seq_2 = alive_seq_1[:, -1:]
alive_seq_2 = layers.unsqueeze(alive_seq_2, axes=[1])
logits = wrap_decoder(
trg_vocab_size, max_in_len, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing, embedding_sharing,
dec_inputs=(alive_seq_2, alive_seq_2, pos, None, tgt_src_attn_bias),
enc_output=enc_output, caches=caches, is_train=False, params_type=params_type)
alive_seq_2, alive_log_probs_2, finished_seq_2, finished_scores_2, finished_flags_2, caches_2 = \
beam_search.inner_func(step_idx, logits, alive_seq_1, alive_log_probs, finished_seq,
finished_scores, finished_flags, caches, enc_output,
tgt_src_attn_bias)
layers.increment(x=step_idx, value=1.0, in_place=True)
finish_cond = beam_search.is_finished(step_idx, source_length, alive_log_probs_2,
finished_scores_2, finished_flags_2)
layers.assign(alive_seq_2, alive_seq)
layers.assign(alive_log_probs_2, alive_log_probs)
layers.assign(finished_seq_2, finished_seq)
layers.assign(finished_scores_2, finished_scores)
layers.assign(finished_flags_2, finished_flags)
for i in xrange(len(caches_2)):
layers.assign(caches_2[i]["k"], caches[i]["k"])
layers.assign(caches_2[i]["v"], caches[i]["v"])
layers.logical_and(x=cond, y=finish_cond, out=cond)
finished_flags = layers.reduce_sum(finished_flags, dim=1, keep_dim=True) / beam_size
finished_flags = layers.cast(finished_flags, 'bool')
mask = layers.cast(layers.reduce_any(input=finished_flags, dim=1, keep_dim=True), 'float32')
mask = layers.expand(mask, [1, beam_size])
mask2 = 1.0 - mask
finished_seq = layers.cast(finished_seq, 'float32')
alive_seq = layers.cast(alive_seq, 'float32')
#print mask
finished_seq = layers.elementwise_mul(finished_seq, mask, axis=0) + \
layers.elementwise_mul(alive_seq, mask2, axis = 0)
finished_seq = layers.cast(finished_seq, 'int32')
finished_scores = layers.elementwise_mul(finished_scores, mask, axis=0) + \
layers.elementwise_mul(alive_log_probs, mask2)
finished_seq.persistable = True
finished_scores.persistable = True
return finished_seq, finished_scores
finished_ids, finished_scores = beam_search(enc_output, enc_bias, source_length)
return finished_ids, finished_scores
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layer_helper import LayerHelper as LayerHelper
from config import *
from beam_search import BeamSearch
from attention import _dot_product_relative
INF = 1. * 1e5
def layer_norm(x, begin_norm_axis=1, epsilon=1e-6, param_attr=None, bias_attr=None):
"""
layer_norm
"""
helper = LayerHelper('layer_norm', **locals())
mean = layers.reduce_mean(x, dim=begin_norm_axis, keep_dim=True)
shift_x = layers.elementwise_sub(x=x, y=mean, axis=0)
variance = layers.reduce_mean(layers.square(shift_x), dim=begin_norm_axis, keep_dim=True)
r_stdev = layers.rsqrt(variance + epsilon)
norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0)
param_shape = [reduce(lambda x, y: x * y, norm_x.shape[begin_norm_axis:])]
param_dtype = norm_x.dtype
scale = helper.create_parameter(
attr=param_attr,
shape=param_shape,
dtype=param_dtype,
default_initializer=fluid.initializer.Constant(1.))
bias = helper.create_parameter(
attr=bias_attr,
shape=param_shape,
dtype=param_dtype,
is_bias=True,
default_initializer=fluid.initializer.Constant(0.))
out = layers.elementwise_mul(x=norm_x, y=scale, axis=-1)
out = layers.elementwise_add(x=out, y=bias, axis=-1)
return out
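# A minimal numpy sketch of what layer_norm above computes when begin_norm_axis points at
# the last dimension and scale/bias keep their initial values (1 and 0): each feature
# vector is shifted to zero mean and scaled to unit variance.
def _layer_norm_sketch(epsilon=1e-6):
    import numpy as np
    x = np.random.rand(2, 3, 8).astype("float32")
    mean = x.mean(axis=-1, keepdims=True)
    variance = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(variance + epsilon)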
def forward_position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(np.arange(
num_timescales) * -log_timescale_increment)
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
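# A minimal sketch of the table produced by forward_position_encoding_init: row t holds
# sin(t * w_0..w_{c/2-1}) followed by cos(t * w_0..w_{c/2-1}), so position 0 encodes to
# zeros followed by ones.
def _position_encoding_sketch():
    table = forward_position_encoding_init(n_position=8, d_pos_vec=4)
    assert table.shape == (8, 4)
    return table[0]  # [0., 0., 1., 1.]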
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
attention_type="dot_product",):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
    computing the softmax activation to mask certain selected positions so that
    they will not be considered in attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
k = layers.fc(input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = layers.fc(input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads(x, n_head):
"""
        Reshape the last dimension of input tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
if n_head == 1:
return x
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
        # permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
        Transpose and then reshape the last two dimensions of input tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key ** -0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = layers.concat([cache['k'], k], axis=1)
v = layers.concat([cache['v'], v], axis=1)
layers.assign(k, cache['k'])
layers.assign(v, cache['v'])
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
    assert attention_type in ("dot_product", "dot_product_relative_encoder",
                              "dot_product_relative_decoder")
if attention_type == "dot_product":
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key, #d_model,
dropout_rate)
elif attention_type == "dot_product_relative_encoder":
q = layers.scale(x=q, scale=d_key ** -0.5)
ctx_multiheads = _dot_product_relative(q, k, v, attn_bias, dropout=dropout_rate)
else:
q = layers.scale(x=q, scale=d_key ** -0.5)
ctx_multiheads = _dot_product_relative(q, k, v, attn_bias, dropout=dropout_rate, cache = cache)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
bias_attr=False,
num_flatten_dims=2)
return proj_out
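# A minimal single-head numpy sketch of the scaled dot-product attention computed above
# (no projections, no bias, no dropout): scale queries by d_key ** -0.5, softmax the
# query-key logits and mix the values with the resulting weights.
def _attention_sketch():
    import numpy as np
    d_key = 8
    q = np.random.rand(2, 4, d_key)   # [batch, query_len, d_key]
    k = np.random.rand(2, 6, d_key)   # [batch, key_len, d_key]
    v = np.random.rand(2, 6, d_key)   # [batch, key_len, d_value]
    logits = np.matmul(q * d_key ** -0.5, k.transpose(0, 2, 1))
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    return np.matmul(weights, v)      # [batch, query_len, d_value]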
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
    Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
epsilon=1e-6,
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
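# A minimal sketch of how the process_cmd strings used throughout this file are read:
# each character selects one step, applied in order. With preprocess_cmd="n" and
# postprocess_cmd="da" (the defaults below), a sub-layer is layer-normalized on the way in,
# then dropout is applied to its output before the residual connection is added.
def _process_cmd_sketch(process_cmd):
    ops = {"a": "residual_add", "n": "layer_norm", "d": "dropout"}
    return [ops[c] for c in process_cmd]
# _process_cmd_sketch("n")  -> ["layer_norm"]
# _process_cmd_sketch("da") -> ["dropout", "residual_add"]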
def prepare_encoder_decoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0.,
word_emb_param_name=None,
training=True,
pos_enc_param_name=None,
is_src=True,
params_type="normal"):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
This module is used at the bottom of the encoder stacks.
"""
assert params_type == "fixed" or params_type == "normal" or params_type == "new"
pre_name = "forwardforward"
if params_type == "fixed":
pre_name = "fixed_forwardfixed_forward"
elif params_type == "new":
pre_name = "new_forwardnew_forward"
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=ModelHyperParams.bos_idx, # set embedding of bos to 0
param_attr=fluid.ParamAttr(
name = pre_name + word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim ** -0.5)))#, is_sparse=True)
if not is_src and training:
src_word_emb = layers.pad(src_word_emb, [0, 0, 1, 0, 0, 0])
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
trainable=False, name = pre_name + pos_enc_param_name))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
return layers.dropout(
enc_input,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train') if dropout_rate else enc_input
prepare_encoder = partial(
prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0], is_src=True)
prepare_decoder = partial(
prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1], is_src=False)
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention followed by a
    position-wise feed-forward network, both wrapped with post_process_layer
    to add residual connection, layer normalization and dropout.
"""
attn_output = multi_head_attention(
pre_process_layer(enc_input, preprocess_cmd,
prepostprocess_dropout), None, None, attn_bias, d_key,
d_value, d_model, n_head, attention_dropout)
attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
prepostprocess_dropout)
ffd_output = positionwise_feed_forward(
pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
d_inner_hid, d_model, relu_dropout)
return post_process_layer(attn_output, ffd_output, postprocess_cmd,
prepostprocess_dropout)
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da"):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd, )
enc_input = enc_output
enc_output = pre_process_layer(enc_output, preprocess_cmd,
prepostprocess_dropout)
return enc_output
def decoder_layer(dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None):
""" The layer to be stacked in decoder part.
    The structure of this module is similar to that in the encoder part, except
    that an extra multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(
pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
None,
None,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache)
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
enc_attn_output = multi_head_attention(
pre_process_layer(slf_attn_output, preprocess_cmd,
prepostprocess_dropout),
enc_output,
enc_output,
dec_enc_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout, )
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
ffd_output = positionwise_feed_forward(
pre_process_layer(enc_attn_output, preprocess_cmd,
prepostprocess_dropout),
d_inner_hid,
d_model,
relu_dropout, )
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout, )
return dec_output
def decoder(dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i])
dec_input = dec_output
dec_output = pre_process_layer(dec_output, preprocess_cmd,
prepostprocess_dropout)
return dec_output
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
inputs = []
for input_field in input_fields:
input_var = layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
def make_all_py_reader_inputs(input_fields, is_test=False):
"""
Define the input data layers for the transformer model.
"""
reader = layers.py_reader(
capacity=20,
name="test_reader" if is_test else "train_reader",
shapes=[input_descs[input_field][0] for input_field in input_fields],
dtypes=[input_descs[input_field][1] for input_field in input_fields],
lod_levels=[
input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0
for input_field in input_fields
], use_double_buffer=True)
return layers.read_file(reader), reader
def forward_transformer(src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
label_smooth_eps,
use_py_reader=False,
is_test=False,
params_type="normal",
all_data_inputs=None):
"""
transformer
"""
if embedding_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields + dense_bias_input_fields
if use_py_reader:
all_inputs = all_data_inputs
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(decoder_data_input_fields[:-1])
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
real_label = all_inputs[enc_inputs_len + dec_inputs_len]
weights = all_inputs[enc_inputs_len + dec_inputs_len + 1]
reverse_label = all_inputs[enc_inputs_len + dec_inputs_len + 2]
enc_output = wrap_encoder(
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
enc_inputs,
params_type=params_type)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
dec_inputs,
enc_output, is_train = True if not is_test else False,
params_type=params_type)
    # Padding indices do not contribute to the total loss. The weights are used
    # to cancel the effect of padding indices when calculating the loss.
if label_smooth_eps:
label = layers.one_hot(input=real_label, depth=trg_vocab_size)
label = label * (1 - label_smooth_eps) + (1 - label) * (
label_smooth_eps / (trg_vocab_size - 1))
label.stop_gradient = True
else:
label = real_label
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
sum_cost.persistable = True
token_num = layers.reduce_sum(weights)
token_num.persistable = True
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
sen_count = layers.shape(dec_inputs[0])[0]
batch_predict = layers.reshape(predict, shape = [sen_count, -1, ModelHyperParams.trg_vocab_size])
#batch_label = layers.reshape(real_label, shape=[sen_count, -1])
batch_weights = layers.reshape(weights, shape=[sen_count, -1, 1])
return sum_cost, avg_cost, token_num, batch_predict, cost, sum_cost, real_label, batch_weights
def wrap_encoder(src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
enc_inputs=None,
params_type="normal"):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_all_inputs(
encoder_data_input_fields)
else:
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[0],
params_type=params_type)
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd, )
return enc_output
def wrap_decoder(trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
dec_inputs=None,
enc_output=None,
caches=None, is_train=True, params_type="normal"):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, reverse_trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
make_all_inputs(decoder_data_input_fields)
else:
trg_word, reverse_trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[0]
if embedding_sharing else word_emb_param_names[1],
training=is_train,
params_type=params_type)
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
assert params_type == "fixed" or params_type == "normal" or params_type == "new"
pre_name = "forwardforward"
if params_type == "fixed":
pre_name = "fixed_forwardfixed_forward"
elif params_type == "new":
pre_name = "new_forwardnew_forward"
if weight_sharing and embedding_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
pre_name + word_emb_param_names[0]),
transpose_y=True)
elif weight_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
pre_name + word_emb_param_names[1]),
transpose_y=True)
else:
predict = layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False)
if dec_inputs is None:
# Return probs for independent decoder program.
predict = layers.softmax(predict)
return predict
def get_enc_bias(source_inputs):
"""
get_enc_bias
"""
source_inputs = layers.cast(source_inputs, 'float32')
emb_sum = layers.reduce_sum(layers.abs(source_inputs), dim=-1)
zero = layers.fill_constant([1], 'float32', value=0)
bias = layers.cast(layers.equal(emb_sum, zero), 'float32') * -1e9
return layers.unsqueeze(layers.unsqueeze(bias, axes=[1]), axes=[1])
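# A minimal numpy sketch of the padding bias built by get_enc_bias, assuming token id 0
# marks padding: padded positions receive -1e9 so their attention weights vanish after
# softmax, and the result broadcasts over heads and query positions.
def _enc_bias_sketch():
    import numpy as np
    src_word = np.array([[[3], [7], [0], [0]]], dtype="float32")  # [batch, src_len, 1]
    emb_sum = np.abs(src_word).sum(axis=-1)                       # [batch, src_len]
    bias = (emb_sum == 0).astype("float32") * -1e9
    return bias.reshape(1, 1, 1, -1)                              # [[[[0, 0, -1e9, -1e9]]]]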
def forward_fast_decode(
src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
beam_size,
batch_size,
max_out_len,
decode_alpha,
eos_idx,
params_type="normal"):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
assert params_type == "normal" or params_type == "new" or params_type == "fixed"
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(fast_decoder_data_input_fields)
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing, embedding_sharing, enc_inputs, params_type=params_type)
enc_bias = get_enc_bias(enc_inputs[0])
source_length, = dec_inputs
def beam_search(enc_output, enc_bias, source_length):
"""
beam_search
"""
max_len = layers.fill_constant(
shape=[1], dtype='int64', value=max_out_len)
step_idx = layers.fill_constant(
shape=[1], dtype='int64', value=0)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
caches_batch_size = batch_size * beam_size
init_score = np.zeros([1, beam_size]).astype('float32')
init_score[:, 1:] = -INF
initial_log_probs = layers.assign(init_score)
alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
# alive seq [batch_size, beam_size, 1]
initial_ids = layers.zeros([batch_size, 1, 1], 'float32')
alive_seq = layers.expand(initial_ids, [1, beam_size, 1])
alive_seq = layers.cast(alive_seq, 'int64')
enc_output = layers.unsqueeze(enc_output, axes=[1])
enc_output = layers.expand(enc_output, [1, beam_size, 1, 1])
enc_output = layers.reshape(enc_output, [caches_batch_size, -1, d_model])
tgt_src_attn_bias = layers.unsqueeze(enc_bias, axes=[1])
tgt_src_attn_bias = layers.expand(tgt_src_attn_bias, [1, beam_size, n_head, 1, 1])
enc_bias_shape = layers.shape(tgt_src_attn_bias)
tgt_src_attn_bias = layers.reshape(tgt_src_attn_bias, [-1, enc_bias_shape[2],
enc_bias_shape[3], enc_bias_shape[4]])
beam_search = BeamSearch(beam_size, batch_size, decode_alpha, trg_vocab_size, d_model)
caches = [{
"k": layers.fill_constant(
shape=[caches_batch_size, 0, d_model],
dtype=enc_output.dtype,
value=0),
"v": layers.fill_constant(
shape=[caches_batch_size, 0, d_model],
dtype=enc_output.dtype,
value=0)
} for i in range(n_layer)]
finished_seq = layers.zeros_like(alive_seq)
finished_scores = layers.fill_constant([batch_size, beam_size],
dtype='float32', value=-INF)
finished_flags = layers.fill_constant([batch_size, beam_size],
dtype='float32', value=0)
with while_op.block():
pos = layers.fill_constant([caches_batch_size, 1, 1], dtype='int64', value=1)
pos = layers.elementwise_mul(pos, step_idx, axis=0)
alive_seq_1 = layers.reshape(alive_seq, [caches_batch_size, -1])
alive_seq_2 = alive_seq_1[:, -1:]
alive_seq_2 = layers.unsqueeze(alive_seq_2, axes=[1])
logits = wrap_decoder(
trg_vocab_size, max_in_len, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing, embedding_sharing,
dec_inputs=(alive_seq_2, alive_seq_2, pos, None, tgt_src_attn_bias),
enc_output=enc_output, caches=caches, is_train=False, params_type=params_type)
alive_seq_2, alive_log_probs_2, finished_seq_2, finished_scores_2, finished_flags_2, caches_2 = \
beam_search.inner_func(step_idx, logits, alive_seq_1, alive_log_probs, finished_seq,
finished_scores, finished_flags, caches, enc_output,
tgt_src_attn_bias)
layers.increment(x=step_idx, value=1.0, in_place=True)
finish_cond = beam_search.is_finished(step_idx, source_length, alive_log_probs_2,
finished_scores_2, finished_flags_2)
layers.assign(alive_seq_2, alive_seq)
layers.assign(alive_log_probs_2, alive_log_probs)
layers.assign(finished_seq_2, finished_seq)
layers.assign(finished_scores_2, finished_scores)
layers.assign(finished_flags_2, finished_flags)
for i in xrange(len(caches_2)):
layers.assign(caches_2[i]["k"], caches[i]["k"])
layers.assign(caches_2[i]["v"], caches[i]["v"])
layers.logical_and(x=cond, y=finish_cond, out=cond)
finished_flags = layers.reduce_sum(finished_flags, dim=1, keep_dim=True) / beam_size
finished_flags = layers.cast(finished_flags, 'bool')
mask = layers.cast(layers.reduce_any(input=finished_flags, dim=1, keep_dim=True), 'float32')
mask = layers.expand(mask, [1, beam_size])
mask2 = 1.0 - mask
finished_seq = layers.cast(finished_seq, 'float32')
alive_seq = layers.cast(alive_seq, 'float32')
#print mask
finished_seq = layers.elementwise_mul(finished_seq, mask, axis=0) + \
layers.elementwise_mul(alive_seq, mask2, axis = 0)
finished_seq = layers.cast(finished_seq, 'int32')
finished_scores = layers.elementwise_mul(finished_scores, mask, axis=0) + \
layers.elementwise_mul(alive_log_probs, mask2)
finished_seq.persistable = True
finished_scores.persistable = True
return finished_seq, finished_scores
finished_ids, finished_scores = beam_search(enc_output, enc_bias, source_length)
return finished_ids, finished_scores
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os, sys
import random
import six
import ast
class TRDataGen(object):
"""record data generator
"""
def __init__(self, num_shards, data_dir):
self.num_shards = num_shards
self.data_dir = data_dir
def gen_data_fnames(self, is_train=True):
"""generate filenames for train and valid
return:
train_filenames, valid_filenames
"""
if not os.path.isdir(self.data_dir):
try:
os.mkdir(self.data_dir)
except Exception as e:
raise ValueError("%s is exists as one file", self.data_dir)
if is_train:
train_prefix = os.path.join(self.data_dir, "translate-train-%05d-of_unshuffle")
return [train_prefix % i for i in xrange(self.num_shards)]
return [os.path.join(self.data_dir, "translate-dev-00000-of_unshuffle")]
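    # A minimal sketch of the shard names this method builds (the directory name "data" is
    # only illustrative): with num_shards=2 and data_dir="data", is_train=True yields
    # ["data/translate-train-00000-of_unshuffle", "data/translate-train-00001-of_unshuffle"],
    # while is_train=False yields the single dev shard "data/translate-dev-00000-of_unshuffle".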
def generate(self, data_list, is_train=True, is_shuffle=True):
"""generating record file
:param data_list:
:param is_train:
:return:
"""
output_filename = self.gen_data_fnames(is_train)
#writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filename]
writers = [open(fname, 'w') for fname in output_filename]
ct = 0
shard = 0
for case in data_list:
ct += 1
if ct % 10000 == 0:
logging.info("Generating case %s ." % ct)
example = self.to_example(case)
writers[shard].write(example.strip() + "\n")
if is_train:
shard = (shard + 1) % self.num_shards
logging.info("Generating case %s ." % ct)
for writer in writers:
writer.close()
if is_shuffle:
self.shuffle_dataset(output_filename)
def to_example(self, dictionary):
"""
        :param dictionary: dict holding token id lists under "inputs" and "targets"
        :return: a tab-separated line of space-joined source and target ids
"""
if "inputs" not in dictionary or "targets" not in dictionary:
raise ValueError("Empty generated field: inputs or target")
inputs = " ".join(str(x) for x in dictionary["inputs"])
targets = " ".join(str(x) for x in dictionary["targets"])
return inputs + "\t" + targets
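    # A minimal sketch of the produced line format: to_example({"inputs": [3, 7, 9],
    # "targets": [4, 8]}) returns "3 7 9\t4 8", i.e. space-joined source ids, a tab,
    # then space-joined target ids; generate() writes one such line per training case.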
def shuffle_dataset(self, filenames):
"""
        Shuffle each generated shard in place and remove the unshuffled file.
        :return: None
"""
logging.info("Shuffling data...")
for fname in filenames:
records = self.read_records(fname)
random.shuffle(records)
out_fname = fname.replace("_unshuffle", "-shuffle")
self.write_records(records, out_fname)
os.remove(fname)
def read_records(self, filename):
"""
:param filename:
:return:
"""
records = []
with open(filename, 'r') as reader:
for record in reader:
records.append(record)
if len(records) % 100000 == 0:
logging.info("read: %d", len(records))
return records
def write_records(self, records, out_filename):
"""
:param records:
:param out_filename:
:return:
"""
with open(out_filename, 'w') as f:
for count, record in enumerate(records):
f.write(record)
if count > 0 and count % 100000 == 0:
logging.info("write: %d", count)
if __name__ == "__main__":
from preprocess.problem import SubwordVocabProblem
from preprocess.problem import TokenVocabProblem
import argparse
parser = argparse.ArgumentParser("Tips for generating subword.")
parser.add_argument(
"--tmp_dir",
type=str,
required=True,
help="dir that includes original corpus.")
parser.add_argument(
"--data_dir",
type=str,
required=True,
help="dir that generates training files")
parser.add_argument(
"--source_train_files",
type=str,
required=True,
help="train file for source")
parser.add_argument(
"--target_train_files",
type=str,
required=True,
help="train file for target")
parser.add_argument(
"--source_vocab_size",
type=int,
required=True,
help="source_vocab_size")
parser.add_argument(
"--target_vocab_size",
type=int,
required=True,
help="target_vocab_size")
parser.add_argument(
"--num_shards",
type=int,
default=100,
help="number of shards")
parser.add_argument(
"--subword",
type=ast.literal_eval,
default=False,
help="subword")
parser.add_argument(
"--token",
type=ast.literal_eval,
default=False,
help="token")
parser.add_argument(
"--onevocab",
type=ast.literal_eval,
default=False,
help="share vocab")
args = parser.parse_args()
print args
gen = TRDataGen(args.num_shards, args.data_dir)
source_train_files = args.source_train_files.split(",")
target_train_files = args.target_train_files.split(",")
if args.token == args.subword:
print "one of subword or token is True"
import sys
sys.exit(1)
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT)
if args.subword:
problem = SubwordVocabProblem(args.source_vocab_size,
args.target_vocab_size,
source_train_files,
target_train_files,
None,
None,
args.onevocab)
else:
problem = TokenVocabProblem(args.source_vocab_size,
args.target_vocab_size,
source_train_files,
target_train_files,
None,
None,
args.onevocab)
gen.generate(problem.generate_data(args.data_dir, args.tmp_dir, True), True, True)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
id2word = {}
ln = sys.stdin
def load_vocab(file_path):
start_index = 0
f = open(file_path, 'r')
for line in f:
line = line.strip()
id2word[start_index] = line
start_index += 1
f.close()
if __name__=="__main__":
load_vocab(sys.argv[1])
while True:
line = ln.readline().strip()
if not line:
break
split_res = line.split(" ")
output_str = ""
for item in split_res:
output_str += id2word[int(item.strip())]
output_str += " "
output_str = output_str.strip()
print output_str
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import multiprocessing
import numpy as np
import os
from functools import partial
import contextlib
import time
import paddle.fluid.profiler as profiler
import paddle
import paddle.fluid as fluid
import forward_model
import reader
import sys
from config import *
from forward_model import wrap_encoder as encoder
from forward_model import wrap_decoder as decoder
from forward_model import forward_fast_decode
from dense_model import dense_fast_decode
from relative_model import relative_fast_decode
from forward_model import forward_position_encoding_init
from reader import *
def parse_args():
"""
parse_args
"""
parser = argparse.ArgumentParser("Training for Transformer.")
parser.add_argument(
"--val_file_pattern",
type=str,
required=True,
help="The pattern to match test data files.")
parser.add_argument(
"--batch_size",
type=int,
default=50,
help="The number of examples in one run for sequence generation.")
parser.add_argument(
"--pool_size",
type=int,
default=10000,
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
"--use_mem_opt",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use memory optimization.")
parser.add_argument(
"--use_py_reader",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use py_reader.")
parser.add_argument(
"--use_parallel_exe",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use ParallelExecutor.")
parser.add_argument(
"--use_candidate",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use candidates.")
parser.add_argument(
"--common_ids",
type=str,
default="",
help="The file path of common ids.")
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
parser.add_argument(
"--use_delay_load",
type=ast.literal_eval,
default=True,
help=
"The flag indicating whether to load all data into memories at once.")
parser.add_argument(
"--vocab_size",
type=str,
required=True,
help="Size of Vocab.")
parser.add_argument(
"--infer_batch_size",
type=int,
help="Infer batch_size")
parser.add_argument(
"--decode_alpha",
type=float,
help="decode_alpha")
args = parser.parse_args()
# Append args related to dict
#src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
#trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
#dict_args = [
# "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
# str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
# "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
# str(src_dict[args.special_token[2]])
#]
voc_size = args.vocab_size
dict_args = [
"src_vocab_size", voc_size,
"trg_vocab_size", voc_size,
"bos_idx", str(0),
"eos_idx", str(1),
"unk_idx", str(int(voc_size) - 1)
]
merge_cfg_from_list(args.opts + dict_args,
[InferTaskConfig, ModelHyperParams])
return args
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
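# A minimal sketch of the post-processing above, assuming bos_idx=0 and eos_idx=1
# (the indices set in parse_args): the sequence is cut at the first <eos> and the
# <bos>/<eos> tokens are dropped unless output_bos/output_eos ask to keep them.
def _post_process_seq_sketch():
    seq = [0, 11, 25, 1, 99, 1]
    return post_process_seq(seq, bos_idx=0, eos_idx=1,
                            output_bos=False, output_eos=False)  # -> [11, 25]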
def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
d_model):
"""
Put all padded data needed by beam search decoder into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
source_length = np.asarray([src_max_len], dtype="int64")
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, source_length
]))
return data_input_dict
def prepare_feed_dict_list(data_generator, count):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.bos_idx,
ModelHyperParams.bos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def prepare_dense_feed_dict_list(data_generator, count):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = dense_encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict = prepare_batch_input(
data_buffer, data_input_names, DenseModelHyperParams.bos_idx,
DenseModelHyperParams.bos_idx, DenseModelHyperParams.n_head,
DenseModelHyperParams.d_model)
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def prepare_infer_feed_dict_list(data_generator, count):
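    """
    Prepare the list of feed dicts for multi-devices, merging the regular and dense
    model inputs built from the same batch into one dict per device.
    """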
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
dense_data_input_names = dense_encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
dense_data_input_dict = prepare_batch_input(
data_buffer, dense_data_input_names, DenseModelHyperParams.bos_idx,
DenseModelHyperParams.bos_idx, DenseModelHyperParams.n_head,
DenseModelHyperParams.d_model)
data_input_dict = prepare_batch_input(data_buffer, data_input_names,
ModelHyperParams.bos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
for key in dense_data_input_dict:
if key not in data_input_dict:
data_input_dict[key] = dense_data_input_dict[key]
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def get_trans_res(batch_size, out_list, final_list):
"""
    Collect the top-1 decoded sequence for each sentence in the batch into final_list.
"""
for index in xrange(batch_size):
seq = out_list[index][0] #top1 seq
if 1 not in seq:
res = seq[1:-1]
else:
res = seq[1:seq.index(1)]
res = map(str, res)
final_list.append(" ".join(res))
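# A minimal sketch of get_trans_res with two sentences and hypothetical beam outputs:
# the top beam of each sentence is taken, the leading <bos> is dropped and the sequence
# is cut at the first <eos> (id 1) when one is present.
def _get_trans_res_sketch():
    final_list = []
    get_trans_res(2, [[[0, 5, 7, 1]], [[0, 9, 9, 9]]], final_list)
    return final_list  # -> ["5 7", "9 9"]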
def fast_infer(args):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
test_prog = fluid.Program()
startup_prog = fluid.Program()
#with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard("new_forward"):
out_ids1, out_scores1 = forward_fast_decode(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
InferTaskConfig.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
with fluid.unique_name.guard("new_relative_position"):
out_ids2, out_scores2 = relative_fast_decode(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
InferTaskConfig.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
DenseModelHyperParams.src_vocab_size = ModelHyperParams.src_vocab_size
DenseModelHyperParams.trg_vocab_size = ModelHyperParams.trg_vocab_size
DenseModelHyperParams.weight_sharing = ModelHyperParams.weight_sharing
DenseModelHyperParams.embedding_sharing = ModelHyperParams.embedding_sharing
with fluid.unique_name.guard("new_dense"):
out_ids3, out_scores3 = dense_fast_decode(
DenseModelHyperParams.src_vocab_size,
DenseModelHyperParams.trg_vocab_size,
DenseModelHyperParams.max_length + 50,
DenseModelHyperParams.n_layer,
DenseModelHyperParams.enc_n_layer,
DenseModelHyperParams.n_head,
DenseModelHyperParams.d_key,
DenseModelHyperParams.d_value,
DenseModelHyperParams.d_model,
DenseModelHyperParams.d_inner_hid,
DenseModelHyperParams.prepostprocess_dropout,
DenseModelHyperParams.attention_dropout,
DenseModelHyperParams.relu_dropout,
DenseModelHyperParams.preprocess_cmd,
DenseModelHyperParams.postprocess_cmd,
DenseModelHyperParams.weight_sharing,
DenseModelHyperParams.embedding_sharing,
InferTaskConfig.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
test_prog = fluid.default_main_program().clone(for_test=True)
# This is used here to set dropout to the test mode.
if InferTaskConfig.use_gpu:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_params(
exe,
InferTaskConfig.model_path,
main_program=test_prog)
if args.use_mem_opt:
fluid.memory_optimize(test_prog)
exec_strategy = fluid.ExecutionStrategy()
# For faster executor
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 1
build_strategy = fluid.BuildStrategy()
# data reader settings for inference
args.use_token_batch = False
#args.sort_type = reader.SortType.NONE
args.shuffle = False
args.shuffle_batch = False
dev_count = 1
lines_cnt = len(open(args.val_file_pattern, 'r').readlines())
data_reader = line_reader(args.val_file_pattern, args.infer_batch_size, dev_count,
token_delimiter=args.token_delimiter,
max_len=ModelHyperParams.max_length,
parse_line=parse_src_line)
test_data = prepare_data_generator(
args,
is_test=True,
count=dev_count,
pyreader=None,
batch_size=args.infer_batch_size, data_reader=data_reader)
data_generator = test_data()
iter_num = 0
if not os.path.exists("trans"):
os.mkdir("trans")
model_name = InferTaskConfig.model_path.split("/")[-1]
forward_res = open(os.path.join("trans", "forward_%s" % model_name), 'w')
relative_res = open(os.path.join("trans", "relative_%s" % model_name), 'w')
dense_res = open(os.path.join("trans", "dense_%s" % model_name), 'w')
forward_list = []
relative_list = []
dense_list = []
with profile_context(False):
while True:
try:
feed_dict_list = prepare_infer_feed_dict_list(data_generator, dev_count)
forward_seq_ids, relative_seq_ids, dense_seq_ids = exe.run(
program=test_prog,
fetch_list=[out_ids1.name, out_ids2.name, out_ids3.name],
feed=feed_dict_list[0]
if feed_dict_list is not None else None,
return_numpy=False,
use_program_cache=False)
fseq_ids = np.asarray(forward_seq_ids).tolist()
rseq_ids = np.asarray(relative_seq_ids).tolist()
dseq_ids = np.asarray(dense_seq_ids).tolist()
get_trans_res(args.infer_batch_size, fseq_ids, forward_list)
get_trans_res(args.infer_batch_size, rseq_ids, relative_list)
get_trans_res(args.infer_batch_size, dseq_ids, dense_list)
except (StopIteration, fluid.core.EOFException):
break
forward_list = forward_list[:lines_cnt]
relative_list = relative_list[:lines_cnt]
dense_list = dense_list[:lines_cnt]
forward_res.writelines("\n".join(forward_list))
forward_res.flush()
forward_res.close()
relative_res.writelines("\n".join(relative_list))
relative_res.flush()
relative_res.close()
dense_res.writelines("\n".join(dense_list))
dense_res.flush()
dense_res.close()
@contextlib.contextmanager
def profile_context(profile=True):
"""
profile_context
"""
if profile:
with profiler.profiler('All', 'total', './profile_dir/profile_file_tmp'):
yield
else:
yield
if __name__ == "__main__":
args = parse_args()
fast_infer(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import subprocess
import commands
import os
import six
import copy
import argparse
import time
from args import ArgumentGroup, print_arguments, inv_arguments
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
multip_g = ArgumentGroup(parser, "multiprocessing",
"start paddle training using multi-processing mode.")
multip_g.add_arg("node_ips", str, None,
"paddle trainer ips")
multip_g.add_arg("node_id", int, None,
"the trainer id of the node for multi-node distributed training.")
multip_g.add_arg("print_config", bool, True,
"print the config of multi-processing mode.")
multip_g.add_arg("current_node_ip", str, None,
"the ip of current node.")
multip_g.add_arg("split_log_path", str, "log",
"log path for each trainer.")
multip_g.add_arg("log_prefix", str, "",
"the prefix name of job log.")
multip_g.add_arg("nproc_per_node", int, 8,
"the number of process to use on each node.")
multip_g.add_arg("selected_gpus", str, "0,1,2,3,4,5,6,7",
"the gpus selected to use.")
multip_g.add_arg("training_script", str, None, "the program/script to be lauched "
"in parallel followed by all the arguments", positional_arg=True)
multip_g.add_arg("training_script_args", str, None,
"training script args", positional_arg=True, nargs=argparse.REMAINDER)
# yapf: enable
def start_procs(args):
"""
start_procs
"""
procs = []
log_fns = []
default_env = os.environ.copy()
node_id = args.node_id
node_ips = [x.strip() for x in args.node_ips.split(',')]
current_ip = args.current_node_ip
num_nodes = len(node_ips)
selected_gpus = [x.strip() for x in args.selected_gpus.split(',')]
selected_gpu_num = len(selected_gpus)
all_trainer_endpoints = ""
for ip in node_ips:
for i in range(args.nproc_per_node):
if all_trainer_endpoints != "":
all_trainer_endpoints += ","
all_trainer_endpoints += "%s:617%d" % (ip, i)
nranks = num_nodes * args.nproc_per_node
# distribute the selected gpus evenly across the processes on this node
if selected_gpu_num % args.nproc_per_node == 0:
    gpus_per_proc = selected_gpu_num // args.nproc_per_node
else:
    gpus_per_proc = selected_gpu_num // args.nproc_per_node + 1
selected_gpus_per_proc = [selected_gpus[i:i + gpus_per_proc]
for i in range(0, len(selected_gpus), gpus_per_proc)]
if args.print_config:
print("all_trainer_endpoints: ", all_trainer_endpoints,
", node_id: ", node_id,
", current_ip: ", current_ip,
", num_nodes: ", num_nodes,
", node_ips: ", node_ips,
", gpus_per_proc: ", gpus_per_proc,
", selected_gpus_per_proc: ", selected_gpus_per_proc,
", nranks: ", nranks)
current_env = copy.copy(default_env)
procs = []
cmds = []
log_fns = []
for i in range(0, args.nproc_per_node):
trainer_id = node_id * args.nproc_per_node + i
current_env.update({
"FLAGS_selected_gpus": "%s" % ",".join([str(s) for s in selected_gpus_per_proc[i]]),
"PADDLE_TRAINER_ID": "%d" % trainer_id,
"PADDLE_CURRENT_ENDPOINT": "%s:617%d" % (current_ip, i),
"PADDLE_TRAINERS_NUM": "%d" % nranks,
"PADDLE_TRAINER_ENDPOINTS": all_trainer_endpoints,
"PADDLE_NODES_NUM": "%d" % num_nodes
})
cmd = [sys.executable, "-u",
args.training_script] + args.training_script_args
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
else:
process = subprocess.Popen(cmd, env=current_env)
procs.append(process)
for i in range(len(procs)):
proc = procs[i]
proc.wait()
if len(log_fns) > 0:
log_fns[i].close()
if proc.returncode != 0:
raise subprocess.CalledProcessError(returncode=procs[i].returncode,
cmd=cmds[i])
else:
print("proc %d finsh" % i)
def main(args):
"""
main_func
"""
if args.print_config:
print_arguments(args)
start_procs(args)
if __name__ == "__main__":
launch_args = parser.parse_args()
main(launch_args)
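# Illustrative sketch (not part of the original script): for a hypothetical
# 2-node setup with nproc_per_node=2 and selected_gpus="0,1", the launcher
# above builds the endpoint list "ip1:6170,ip1:6171,ip2:6170,ip2:6171" and,
# for the first process on node 0, sets roughly:
#   FLAGS_selected_gpus=0, PADDLE_TRAINER_ID=0,
#   PADDLE_CURRENT_ENDPOINT=ip1:6170, PADDLE_TRAINERS_NUM=4,
#   PADDLE_TRAINER_ENDPOINTS=<all endpoints>, PADDLE_NODES_NUM=2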
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import argparse
from text_encoder import SubwordTextEncoder, TokenTextEncoder
from text_encoder import EOS_ID
def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
sources, file_byte_budget=1e6):
"""Generate a vocabulary from the datasets in sources."""
def generate():
"""Generate lines for vocabulary generation."""
logging.info("Generating vocab from: %s", str(sources))
for source in sources:
for lang_file in source[1]:
logging.info("Reading file: %s" % lang_file)
filepath = os.path.join(tmp_dir, lang_file)
with open(filepath, mode="r") as source_file:
file_byte_budget_ = file_byte_budget
counter = 0
countermax = int(os.path.getsize(filepath) / file_byte_budget_ / 2)
logging.info("countermax: %d" % countermax)
for line in source_file:
if counter < countermax:
counter += 1
else:
if file_byte_budget_ <= 0:
break
line = line.strip()
file_byte_budget_ -= len(line)
counter = 0
yield line
return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generate())
def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generator, max_subtoken_length=None,
reserved_tokens=None):
"""Inner implementation for vocab generators.
Args:
data_dir: The base directory where data and vocab files are stored. If None,
then do not save the vocab even if it doesn't exist.
vocab_filename: relative filename where vocab file is stored
vocab_size: target size of the vocabulary constructed by SubwordTextEncoder
generator: a generator that produces tokens from the vocabulary
max_subtoken_length: an optional integer. Set this to a finite value to
avoid quadratic costs during vocab building.
reserved_tokens: List of reserved tokens. `text_encoder.RESERVED_TOKENS`
should be a prefix of `reserved_tokens`. If `None`, defaults to
`RESERVED_TOKENS`.
Returns:
A SubwordTextEncoder vocabulary object.
"""
if data_dir and vocab_filename:
vocab_filepath = os.path.join(data_dir, vocab_filename)
if os.path.exists(vocab_filepath):
logging.info("Found vocab file: %s", vocab_filepath)
return SubwordTextEncoder(vocab_filepath)
else:
vocab_filepath = None
logging.info("Generating vocab file: %s", vocab_filepath)
vocab = SubwordTextEncoder.build_from_generator(
generator, vocab_size, max_subtoken_length=max_subtoken_length,
reserved_tokens=reserved_tokens)
if vocab_filepath:
if not os.path.exists(data_dir):
os.makedirs(data_dir)
vocab.store_to_file(vocab_filepath)
return vocab
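# Minimal usage sketch (not part of the original module): `sources` is a list
# of [prefix, [filenames relative to tmp_dir]] pairs, mirroring how the
# *VocabProblem classes below call get_or_generate_vocab. All paths and sizes
# here are hypothetical.
def _demo_get_or_generate_vocab():
    """Illustrative only: build (or load) an 8000-subword vocab from two files."""
    sources = [["", ["train.src", "train.tgt"]]]
    return get_or_generate_vocab("train_data", "tmp_data", "vocab.8000",
                                 8000, sources, file_byte_budget=1e8)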
def txt_line_iterator(fname):
"""
generator for line
:param fname:
:return:
"""
with open(fname, 'r') as f:
for line in f:
yield line.strip()
def txt2txt_generator(source_fname, target_fname):
"""
:param source_fname:
:param target_fname:
:return:
"""
for source, target in zip(
txt_line_iterator(source_fname),
txt_line_iterator(target_fname)
):
yield {"inputs": source, "targets": target}
def txt2txt_encoder(sample_generator, vocab, target_vocab=None):
"""
:param sample_generator:
:param vocab:
:param target_vocab:
:return:
"""
target_vocab = target_vocab or vocab
for sample in sample_generator:
sample["inputs"] = vocab.encode(sample["inputs"])
sample["inputs"].append(EOS_ID)
sample["targets"] = target_vocab.encode(sample["targets"])
sample["targets"].append(EOS_ID)
yield sample
def txt_encoder(filename, batch_size=1, vocab=None):
"""
:param filename:
:param batch_size:
:param vocab:
:return:
"""
def pad_mini_batch(batch):
"""
:param batch:
:return:
"""
lens = [len(x) for x in batch]
max_len = max(lens)
for i in range(len(batch)):
batch[i] = batch[i] + [0] * (max_len - lens[i])
return batch
fp = open(filename, 'r')
samples = []
batches = []
ct = 0
for sample in fp:
sample = sample.strip()
if vocab:
sample = vocab.encode(sample)
else:
sample = [int(s) for s in sample]
#sample.append(EOS_ID)
batches.append(sample)
ct += 1
if ct % batch_size == 0:
batches = pad_mini_batch(batches)
samples.extend(batches)
batches = []
if ct % batch_size != 0:
batches += [batches[-1]] * (batch_size - ct % batch_size)
batches = pad_mini_batch(batches)
samples.extend(batches)
return samples
if __name__ == "__main__":
parser = argparse.ArgumentParser("Tips for generating testset")
parser.add_argument(
"--vocab",
type=str,
required=True,
help="The path of source vocab.")
parser.add_argument(
"--testset",
type=str,
required=True,
help="The path of testset.")
parser.add_argument(
"--output",
type=str,
required=True,
help="The path of result.")
args = parser.parse_args()
token = TokenTextEncoder(args.vocab)
samples = txt_encoder(args.testset, 1, token)
with open(args.output, 'w') as f:
for sample in samples:
res = [str(item) for item in sample]
f.write("%s\n" % " ".join(res))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from gen_utils import get_or_generate_vocab
from gen_utils import txt_line_iterator
import os, sys
from gen_utils import txt2txt_encoder
from gen_utils import txt2txt_generator
from text_encoder import TokenTextEncoder
import logging
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT)
class GenSubword(object):
"""
gen subword
"""
def __init__(self,
vocab_size=8000,
training_dataset_filenames="train.txt"):
"""
:param vocab_size:
:param training_dataset_filenames: a filename or list of filenames
"""
self.vocab_size = vocab_size
self.vocab_name = "vocab.%s" % self.vocab_size
if not isinstance(training_dataset_filenames, list):
training_dataset_filenames = [training_dataset_filenames]
self.training_dataset_filenames = training_dataset_filenames
def generate_data(self, data_dir, tmp_dir):
"""
:param data_dir: target dir(includes vocab file)
:param tmp_dir: original dir(includes training dataset filenames)
:return:
"""
data_set = [["", self.training_dataset_filenames]]
source_vocab = get_or_generate_vocab(
data_dir,
tmp_dir,
self.vocab_name,
self.vocab_size,
data_set,
file_byte_budget=1e8)
source_vocab.store_to_file(os.path.join(data_dir, self.vocab_name))
class SubwordVocabProblem(object):
"""subword input"""
def __init__(self,
source_vocab_size=8000,
target_vocab_size=8000,
source_train_filenames="train.src",
target_train_filenames="train.tgt",
source_dev_filenames="dev.src",
target_dev_filenames="dev.tgt",
one_vocab=False):
"""
:param source_vocab_size:
:param target_vocab_size:
:param source_train_filenames:
:param target_train_filenames:
:param source_dev_filenames:
:param target_dev_filenames:
"""
self.source_vocab_size = source_vocab_size
self.target_vocab_size = target_vocab_size
self.source_vocab_name = "vocab.source.%s" % self.source_vocab_size
self.target_vocab_name = "vocab.target.%s" % self.target_vocab_size
if not isinstance(source_train_filenames, list):
source_train_filenames = [source_train_filenames]
if not isinstance(target_train_filenames, list):
target_train_filenames = [target_train_filenames]
if not isinstance(source_dev_filenames, list):
source_dev_filenames = [source_dev_filenames]
if not isinstance(target_dev_filenames, list):
target_dev_filenames = [target_dev_filenames]
self.source_train_filenames = source_train_filenames
self.target_train_filenames = target_train_filenames
self.source_dev_filenames = source_dev_filenames
self.target_dev_filenames = target_dev_filenames
self.one_vocab = one_vocab
def generate_data(self, data_dir, tmp_dir, is_train=True):
"""
:param data_dir:
:param tmp_dir:
:return:
"""
self.source_train_ds = [["", self.source_train_filenames]]
self.target_train_ds = [["", self.target_train_filenames]]
logging.info("building source vocab ...")
logging.info(self.one_vocab)
if not self.one_vocab:
source_vocab = get_or_generate_vocab(data_dir, tmp_dir,
self.source_vocab_name,
self.source_vocab_size,
self.source_train_ds,
file_byte_budget=1e8)
logging.info("building target vocab ...")
target_vocab = get_or_generate_vocab(data_dir, tmp_dir,
self.target_vocab_name,
self.target_vocab_size,
self.target_train_ds,
file_byte_budget=1e8)
else:
train_ds = [["", self.source_train_filenames + self.target_train_filenames]]
source_vocab = get_or_generate_vocab(data_dir, tmp_dir,
self.source_vocab_name,
self.source_vocab_size,
train_ds,
file_byte_budget=1e8)
target_vocab = source_vocab
target_vocab.store_to_file(os.path.join(data_dir, self.target_vocab_name))
pair_filenames = [self.source_train_filenames, self.target_train_filenames]
if not is_train:
pair_filenames = [self.source_dev_filenames, self.target_dev_filenames]
self.compile_data(tmp_dir, pair_filenames, is_train)
source_fname = "train.lang1" if is_train else "dev.lang1"
target_fname = "train.lang2" if is_train else "dev.lang2"
source_fname = os.path.join(tmp_dir, source_fname)
target_fname = os.path.join(tmp_dir, target_fname)
return txt2txt_encoder(txt2txt_generator(source_fname, target_fname),
source_vocab,
target_vocab)
def compile_data(self, tmp_dir, pair_filenames, is_train=True):
"""
combine the input files
:param tmp_dir:
:param pair_filenames:
:param is_train:
:return:
"""
filename = "train.lang1" if is_train else "dev.lang1"
out_file_1 = open(os.path.join(tmp_dir, filename), "w")
filename = "train.lang2" if is_train else "dev.lang2"
out_file_2 = open(os.path.join(tmp_dir, filename), "w")
for file1, file2 in zip(pair_filenames[0], pair_filenames[1]):
for line in txt_line_iterator(os.path.join(tmp_dir, file1)):
out_file_1.write(line + "\n")
for line in txt_line_iterator(os.path.join(tmp_dir, file2)):
out_file_2.write(line + "\n")
out_file_2.close()
out_file_1.close()
class TokenVocabProblem(object):
"""token input"""
def __init__(self,
source_vocab_size=8000,
target_vocab_size=8000,
source_train_filenames="train.src",
target_train_filenames="train.tgt",
source_dev_filenames="dev.src",
target_dev_filenames="dev.tgt",
one_vocab=False):
"""
:param source_vocab_size:
:param target_vocab_size:
:param source_train_filenames:
:param target_train_filenames:
:param source_dev_filenames:
:param target_dev_filenames:
"""
self.source_vocab_size = source_vocab_size
self.target_vocab_size = target_vocab_size
self.source_vocab_name = "vocab.source.%s" % self.source_vocab_size
self.target_vocab_name = "vocab.target.%s" % self.target_vocab_size
if not isinstance(source_train_filenames, list):
source_train_filenames = [source_train_filenames]
if not isinstance(target_train_filenames, list):
target_train_filenames = [target_train_filenames]
if not isinstance(source_dev_filenames, list):
source_dev_filenames = [source_dev_filenames]
if not isinstance(target_dev_filenames, list):
target_dev_filenames = [target_dev_filenames]
self.source_train_filenames = source_train_filenames
self.target_train_filenames = target_train_filenames
self.source_dev_filenames = source_dev_filenames
self.target_dev_filenames = target_dev_filenames
self.one_vocab = one_vocab
def add_exsits_vocab(self, filename):
"""
:param filename
"""
token_list = []
with open(filename) as f:
for line in f:
line = line.strip()
token_list.append(line)
token_list.append("UNK")
return token_list
def generate_data(self, data_dir, tmp_dir, is_train=True):
"""
:param data_dir:
:param tmp_dir:
:return:
"""
self.source_train_ds = [["", self.source_train_filenames]]
self.target_train_ds = [["", self.target_train_filenames]]
pair_filenames = [self.source_train_filenames, self.target_train_filenames]
if not is_train:
pair_filenames = [self.source_dev_filenames, self.target_dev_filenames]
self.compile_data(tmp_dir, pair_filenames, is_train)
source_fname = "train.lang1" if is_train else "dev.lang1"
target_fname = "train.lang2" if is_train else "dev.lang2"
source_fname = os.path.join(tmp_dir, source_fname)
target_fname = os.path.join(tmp_dir, target_fname)
if is_train:
source_vocab_path = os.path.join(data_dir, self.source_vocab_name)
target_vocab_path = os.path.join(data_dir, self.target_vocab_name)
if not self.one_vocab:
if os.path.exists(source_vocab_path) and os.path.exists(target_vocab_path):
logging.info("found source vocab ...")
source_vocab = TokenTextEncoder(None, vocab_list=self.add_exsits_vocab(source_vocab_path))
logging.info("found target vocab ...")
target_vocab = TokenTextEncoder(None, vocab_list=self.add_exsits_vocab(target_vocab_path))
else:
logging.info("building source vocab ...")
source_vocab = TokenTextEncoder.build_from_corpus(source_fname,
self.source_vocab_size)
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
logging.info("building target vocab ...")
target_vocab = TokenTextEncoder.build_from_corpus(target_fname,
self.target_vocab_size)
else:
if os.path.exists(source_vocab_path):
logging.info("found source vocab ...")
source_vocab = TokenTextEncoder(None, vocab_list=self.add_exsits_vocab(source_vocab_path))
else:
source_vocab = TokenTextEncoder.build_from_corpus([source_fname, target_fname],
self.source_vocab_size)
logging.info("building target vocab ...")
target_vocab = source_vocab
source_vocab.store_to_file(source_vocab_path)
target_vocab.store_to_file(target_vocab_path)
else:
source_vocab = TokenTextEncoder(os.path.join(data_dir, self.source_vocab_name))
target_vocab = TokenTextEncoder(os.path.join(data_dir, self.target_vocab_name))
return txt2txt_encoder(txt2txt_generator(source_fname, target_fname),
source_vocab,
target_vocab)
def compile_data(self, tmp_dir, pair_filenames, is_train=True):
"""
combine the input files
:param tmp_dir:
:param pair_filenames:
:param is_train:
:return:
"""
filename = "train.lang1" if is_train else "dev.lang1"
out_file_1 = open(os.path.join(tmp_dir, filename), "w")
filename = "train.lang2" if is_train else "dev.lang2"
out_file_2 = open(os.path.join(tmp_dir, filename), "w")
for file1, file2 in zip(pair_filenames[0], pair_filenames[1]):
for line in txt_line_iterator(os.path.join(tmp_dir, file1)):
out_file_1.write(line + "\n")
for line in txt_line_iterator(os.path.join(tmp_dir, file2)):
out_file_2.write(line + "\n")
out_file_2.close()
out_file_1.close()
if __name__ == "__main__":
gen_sub = GenSubword().generate_data("train_data", "../asr/")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import argparse
from text_encoder import SubwordTextEncoder
from text_encoder import EOS_ID
def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
sources, file_byte_budget=1e6):
"""Generate a vocabulary from the datasets in sources."""
def generate():
"""Generate lines for vocabulary generation."""
logging.info("Generating vocab from: %s", str(sources))
for source in sources:
for lang_file in source[1]:
logging.info("Reading file: %s" % lang_file)
filepath = os.path.join(tmp_dir, lang_file)
with open(filepath, mode="r") as source_file:
file_byte_budget_ = file_byte_budget
counter = 0
countermax = int(os.path.getsize(filepath) / file_byte_budget_ / 2)
logging.info("countermax: %d" % countermax)
for line in source_file:
if counter < countermax:
counter += 1
else:
if file_byte_budget_ <= 0:
break
line = line.strip()
file_byte_budget_ -= len(line)
counter = 0
yield line
return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generate())
def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generator, max_subtoken_length=None,
reserved_tokens=None):
"""Inner implementation for vocab generators.
Args:
data_dir: The base directory where data and vocab files are stored. If None,
then do not save the vocab even if it doesn't exist.
vocab_filename: relative filename where vocab file is stored
vocab_size: target size of the vocabulary constructed by SubwordTextEncoder
generator: a generator that produces tokens from the vocabulary
max_subtoken_length: an optional integer. Set this to a finite value to
avoid quadratic costs during vocab building.
reserved_tokens: List of reserved tokens. `text_encoder.RESERVED_TOKENS`
should be a prefix of `reserved_tokens`. If `None`, defaults to
`RESERVED_TOKENS`.
Returns:
A SubwordTextEncoder vocabulary object.
"""
if data_dir and vocab_filename:
vocab_filepath = os.path.join(data_dir, vocab_filename)
if os.path.exists(vocab_filepath):
logging.info("Found vocab file: %s", vocab_filepath)
return SubwordTextEncoder(vocab_filepath)
else:
vocab_filepath = None
logging.info("Generating vocab file: %s", vocab_filepath)
vocab = SubwordTextEncoder.build_from_generator(
generator, vocab_size, max_subtoken_length=max_subtoken_length,
reserved_tokens=reserved_tokens)
if vocab_filepath:
if not os.path.exists(data_dir):
os.makedirs(data_dir)
vocab.store_to_file(vocab_filepath)
return vocab
def txt_line_iterator(fname):
"""
generator for line
:param fname:
:return:
"""
with open(fname, 'r') as f:
for line in f:
yield line.strip()
def txt2txt_generator(source_fname, target_fname):
"""
:param source_fname:
:param target_fname:
:return:
"""
for source, target in zip(
txt_line_iterator(source_fname),
txt_line_iterator(target_fname)
):
yield {"inputs": source, "targets": target}
def txt2txt_encoder(sample_generator, vocab, target_vocab=None):
"""
:param sample_generator:
:param vocab:
:param target_vocab:
:return:
"""
target_vocab = target_vocab or vocab
for sample in sample_generator:
sample["inputs"] = vocab.encode(sample["inputs"])
sample["inputs"].append(EOS_ID)
sample["targets"] = target_vocab.encode(sample["targets"])
sample["targets"].append(EOS_ID)
yield sample
def txt_encoder(filename, batch_size=1, vocab=None):
"""
:param filename:
:param batch_size:
:param vocab:
:return:
"""
def pad_mini_batch(batch):
"""
:param batch:
:return:
"""
lens = [len(x) for x in batch]
max_len = max(lens)
for i in range(len(batch)):
batch[i] = batch[i] + [0] * (max_len - lens[i])
return batch
fp = open(filename, 'r')
samples = []
batches = []
ct = 0
for sample in fp:
sample = sample.strip()
if vocab:
sample = vocab.encode(sample)
else:
sample = [int(s) for s in sample]
#sample.append(EOS_ID)
batches.append(sample)
ct += 1
if ct % batch_size == 0:
batches = pad_mini_batch(batches)
samples.extend(batches)
batches = []
if ct % batch_size != 0:
batches += [batches[-1]] * (batch_size - ct % batch_size)
batches = pad_mini_batch(batches)
samples.extend(batches)
return samples
if __name__ == "__main__":
parser = argparse.ArgumentParser("Tips for generating testset")
parser.add_argument(
"--vocab",
type=str,
required=True,
help="The path of source vocab.")
parser.add_argument(
"--input",
type=str,
required=True,
help="The path of testset.")
parser.add_argument(
"--output",
type=str,
required=True,
help="The path of result.")
args = parser.parse_args()
subword = SubwordTextEncoder(args.vocab)
samples = []
with open(args.input, 'r') as f:
for line in f:
line = line.strip()
ids_list = [int(num) for num in line.split(" ")]
samples.append(ids_list)
with open(args.output, 'w') as f:
for sample in samples:
ret = subword.decode(sample)
f.write("%s\n" % ret)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import os
import re
import logging
from tokenizer import encode as tokenizer_encode
from tokenizer import decode as tokenizer_decode
from itertools import chain
import collections
PAD = "<pad>"
EOS = "<EOS>"
RESERVED_TOKENS = [PAD, EOS]
NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0
EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1
if six.PY2:
RESERVED_TOKENS_BYTES = RESERVED_TOKENS
else:
RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
_ESCAPE_CHARS = set(u"\\_u;0123456789")
def strip_ids(ids, ids_to_strip):
"""Strip ids_to_strip from the end ids."""
ids = list(ids)
while ids[-1] in ids_to_strip:
ids.pop()
return ids
def native_to_unicode(s):
"""
:param s:
:return:
"""
return s if is_unicode(s) else to_unicode(s)
def is_unicode(s):
"""
:param s:
:return:
"""
if six.PY2:
if isinstance(s, unicode):
return True
else:
if isinstance(s, str):
return True
return False
def unicode_to_native(s):
"""
:param s:
:return:
"""
if six.PY2:
return s.encode("utf-8") if is_unicode(s) else s
else:
return s
def to_unicode(s, ignore_errors=False):
"""
:param s:
:param ignore_errors:
:return:
"""
if is_unicode(s):
"""
"""
return s
error_mode = "ignore" if ignore_errors else "strict"
return s.decode("utf-8", errors=error_mode)
def _escape_token(token, alphabet):
"""Escape away underscores and OOV characters and append '_'.
This allows the token to be expressed as the concatenation of a list
of subtokens from the vocabulary. The underscore acts as a sentinel
which allows us to invertibly concatenate multiple such lists.
Args:
token: A unicode string to be escaped.
alphabet: A set of all characters in the vocabulary's alphabet.
Returns:
escaped_token: An escaped unicode string.
Raises:
ValueError: If the provided token is not unicode.
"""
if not isinstance(token, six.text_type):
raise ValueError("Expected string type for token, got %s" % type(token))
token = token.replace(u"\\", u"\\\\").replace(u"_", u"\\u")
ret = [c if c in alphabet and c != u"\n" else r"\%d;" % ord(c) for c in token]
return u"".join(ret) + "_"
def _unescape_token(escaped_token):
"""Inverse of _escape_token().
Args:
escaped_token: a unicode string
Returns:
token: a unicode string
"""
def match(m):
"""
:param m:
:return:
"""
if m.group(1) is None:
return u"_" if m.group(0) == u"\\u" else u"\\"
try:
return six.unichr(int(m.group(1)))
except (ValueError, OverflowError) as _:
return u"\u3013" # Unicode for undefined character.
trimmed = escaped_token[:-1] if escaped_token.endswith("_") else escaped_token
return _UNESCAPE_REGEX.sub(match, trimmed)
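# Worked example (illustrative, assuming every character of the token is in
# the alphabet): _escape_token(u"foo_bar", alphabet) rewrites the underscore
# as the two characters "\u" and appends a trailing "_" sentinel, giving the
# Python literal u"foo\\ubar_"; _unescape_token(u"foo\\ubar_") recovers
# u"foo_bar".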
class TextEncoder(object):
"""Base class for converting from ints to/from human readable strings."""
def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
self._num_reserved_ids = num_reserved_ids
@property
def num_reserved_ids(self):
"""
:return:
"""
return self._num_reserved_ids
def encode(self, s):
"""Transform a human-readable string into a sequence of int ids.
The ids should be in the range [num_reserved_ids, vocab_size). Ids [0,
num_reserved_ids) are reserved.
EOS is not appended.
Args:
s: human-readable string to be converted.
Returns:
ids: list of integers
"""
return [int(w) + self._num_reserved_ids for w in s.split()]
def decode(self, ids, strip_extraneous=False):
"""Transform a sequence of int ids into a human-readable string.
EOS is not expected in ids.
Args:
ids: list of integers to be converted.
strip_extraneous: bool, whether to strip off extraneous tokens
(EOS and PAD).
Returns:
s: human-readable string.
"""
if strip_extraneous:
ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
return " ".join(self.decode_list(ids))
def decode_list(self, ids):
"""Transform a sequence of int ids into a their string versions.
This method supports transforming individual input/output ids to their
string versions so that sequence to/from text conversions can be visualized
in a human readable format.
Args:
ids: list of integers to be converted.
Returns:
strs: list of human-readable string.
"""
decoded_ids = []
for id_ in ids:
if 0 <= id_ < self._num_reserved_ids:
decoded_ids.append(RESERVED_TOKENS[int(id_)])
else:
decoded_ids.append(id_ - self._num_reserved_ids)
return [str(d) for d in decoded_ids]
@property
def vocab_size(self):
"""
:return:
"""
raise NotImplementedError()
class SubwordTextEncoder(TextEncoder):
"""Class for invertibly encoding text using a limited vocabulary.
Invertibly encodes a native string as a sequence of subtokens from a limited
vocabulary.
A SubwordTextEncoder is built from a corpus (so it is tailored to the text in
the corpus), and stored to a file. See text_encoder_build_subword.py.
It can then be loaded and used to encode/decode any text.
Encoding has four phases:
1. Tokenize into a list of tokens. Each token is a unicode string of either
all alphanumeric characters or all non-alphanumeric characters. We drop
tokens consisting of a single space that are between two alphanumeric
tokens.
2. Escape each token. This escapes away special and out-of-vocabulary
characters, and makes sure that each token ends with an underscore, and
has no other underscores.
3. Represent each escaped token as a the concatenation of a list of subtokens
from the limited vocabulary. Subtoken selection is done greedily from
beginning to end. That is, we construct the list in order, always picking
the longest subtoken in our vocabulary that matches a prefix of the
remaining portion of the encoded token.
4. Concatenate these lists. This concatenation is invertible due to the
fact that the trailing underscores indicate when one list is finished.
"""
def __init__(self, filename=None):
"""Initialize and read from a file, if provided.
Args:
filename: filename from which to read vocab. If None, do not load a
vocab
"""
self._alphabet = set()
self.filename = filename
if filename is not None:
self._load_from_file(filename)
super(SubwordTextEncoder, self).__init__(num_reserved_ids=None)
def encode(self, s):
"""Converts a native string to a list of subtoken ids.
Args:
s: a native string.
Returns:
a list of integers in the range [0, vocab_size)
"""
return self._tokens_to_subtoken_ids(
tokenizer_encode(native_to_unicode(s)))
def encode_without_tokenizing(self, token_text):
"""Converts string to list of subtoken ids without calling tokenizer.
This treats `token_text` as a single token and directly converts it
to subtoken ids. This may be useful when the default tokenizer doesn't
do what we want (e.g., when encoding text with tokens composed of lots of
nonalphanumeric characters). It is then up to the caller to make sure that
raw text is consistently converted into tokens. Only use this if you are
sure that `encode` doesn't suit your needs.
Args:
token_text: A native string representation of a single token.
Returns:
A list of subword token ids; i.e., integers in the range [0, vocab_size).
"""
return self._tokens_to_subtoken_ids([native_to_unicode(token_text)])
def decode(self, ids, strip_extraneous=False):
"""Converts a sequence of subtoken ids to a native string.
Args:
ids: a list of integers in the range [0, vocab_size)
strip_extraneous: bool, whether to strip off extraneous tokens
(EOS and PAD).
Returns:
a native string
"""
if strip_extraneous:
ids = strip_ids(ids, list(range(self._num_reserved_ids or 0)))
return unicode_to_native(
tokenizer_decode(self._subtoken_ids_to_tokens(ids)))
def decode_list(self, ids):
"""
:param ids:
:return:
"""
return [self._subtoken_id_to_subtoken_string(s) for s in ids]
@property
def vocab_size(self):
"""The subtoken vocabulary size."""
return len(self._all_subtoken_strings)
def _tokens_to_subtoken_ids(self, tokens):
"""Converts a list of tokens to a list of subtoken ids.
Args:
tokens: a list of strings.
Returns:
a list of integers in the range [0, vocab_size)
"""
ret = []
for token in tokens:
ret.extend(self._token_to_subtoken_ids(token))
return ret
def _token_to_subtoken_ids(self, token):
"""Converts token to a list of subtoken ids.
Args:
token: a string.
Returns:
a list of integers in the range [0, vocab_size)
"""
cache_location = hash(token) % self._cache_size
cache_key, cache_value = self._cache[cache_location]
if cache_key == token:
return cache_value
ret = self._escaped_token_to_subtoken_ids(
_escape_token(token, self._alphabet))
self._cache[cache_location] = (token, ret)
return ret
def _subtoken_ids_to_tokens(self, subtokens):
"""Converts a list of subtoken ids to a list of tokens.
Args:
subtokens: a list of integers in the range [0, vocab_size)
Returns:
a list of strings.
"""
concatenated = "".join(
[self._subtoken_id_to_subtoken_string(s) for s in subtokens])
split = concatenated.split("_")
ret = []
for t in split:
if t:
unescaped = _unescape_token(t + "_")
if unescaped:
ret.append(unescaped)
return ret
def _subtoken_id_to_subtoken_string(self, subtoken):
"""Converts a subtoken integer ID to a subtoken string."""
if 0 <= subtoken < self.vocab_size:
return self._all_subtoken_strings[subtoken]
return u""
def _escaped_token_to_subtoken_strings(self, escaped_token):
"""Converts an escaped token string to a list of subtoken strings.
Args:
escaped_token: An escaped token as a unicode string.
Returns:
A list of subtokens as unicode strings.
"""
# NOTE: This algorithm is greedy; it won't necessarily produce the "best"
# list of subtokens.
ret = []
start = 0
token_len = len(escaped_token)
while start < token_len:
for end in range(
min(token_len, start + self._max_subtoken_len), start, -1):
subtoken = escaped_token[start:end]
if subtoken in self._subtoken_string_to_id:
ret.append(subtoken)
start = end
break
else: # Did not break
# If there is no possible encoding of the escaped token then one of the
# characters in the token is not in the alphabet. This should be
# impossible and would be indicative of a bug.
assert False, "Token substring not found in subtoken vocabulary."
return ret
def _escaped_token_to_subtoken_ids(self, escaped_token):
"""Converts an escaped token string to a list of subtoken IDs.
Args:
escaped_token: An escaped token as a unicode string.
Returns:
A list of subtoken IDs as integers.
"""
return [
self._subtoken_string_to_id[subtoken]
for subtoken in self._escaped_token_to_subtoken_strings(escaped_token)
]
@classmethod
def build_from_generator(cls,
generator,
target_vocab_size,
max_subtoken_length=None,
reserved_tokens=None):
"""Builds a SubwordTextEncoder from the generated text.
Args:
generator: yields text.
target_vocab_size: int, approximate vocabulary size to create.
max_subtoken_length: Maximum length of a subtoken. If this is not set,
then the runtime and memory use of creating the vocab is quadratic in
the length of the longest token. If this is set, then it is instead
O(max_subtoken_length * length of longest token).
reserved_tokens: List of reserved tokens. The global variable
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
argument is `None`, it will use `RESERVED_TOKENS`.
Returns:
SubwordTextEncoder with `vocab_size` approximately `target_vocab_size`.
"""
token_counts = collections.defaultdict(int)
for item in generator:
for tok in tokenizer_encode(native_to_unicode(item)):
token_counts[tok] += 1
encoder = cls.build_to_target_size(
target_vocab_size, token_counts, 1, 1e3,
max_subtoken_length=max_subtoken_length,
reserved_tokens=reserved_tokens)
return encoder
@classmethod
def build_to_target_size(cls,
target_size,
token_counts,
min_val,
max_val,
max_subtoken_length=None,
reserved_tokens=None,
num_iterations=4):
"""Builds a SubwordTextEncoder that has `vocab_size` near `target_size`.
Uses simple recursive binary search to find a minimum token count that most
closely matches the `target_size`.
Args:
target_size: Desired vocab_size to approximate.
token_counts: A dictionary of token counts, mapping string to int.
min_val: An integer; lower bound for the minimum token count.
max_val: An integer; upper bound for the minimum token count.
max_subtoken_length: Maximum length of a subtoken. If this is not set,
then the runtime and memory use of creating the vocab is quadratic in
the length of the longest token. If this is set, then it is instead
O(max_subtoken_length * length of longest token).
reserved_tokens: List of reserved tokens. The global variable
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
argument is `None`, it will use `RESERVED_TOKENS`.
num_iterations: An integer; how many iterations of refinement.
Returns:
A SubwordTextEncoder instance.
Raises:
ValueError: If `min_val` is greater than `max_val`.
"""
if min_val > max_val:
raise ValueError("Lower bound for the minimum token count "
"is greater than the upper bound.")
if target_size < 1:
raise ValueError("Target size must be positive.")
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
def bisect(min_val, max_val):
"""Bisection to find the right size."""
present_count = (max_val + min_val) // 2
logging.info("Trying min_count %d" % present_count)
subtokenizer = cls()
subtokenizer.build_from_token_counts(
token_counts, present_count, num_iterations,
max_subtoken_length=max_subtoken_length,
reserved_tokens=reserved_tokens)
# Being within 1% of the target size is ok.
is_ok = abs(subtokenizer.vocab_size - target_size) * 100 < target_size
# If min_val == max_val, we can't do any better than this.
if is_ok or min_val >= max_val or present_count < 2:
return subtokenizer
if subtokenizer.vocab_size > target_size:
other_subtokenizer = bisect(present_count + 1, max_val)
else:
other_subtokenizer = bisect(min_val, present_count - 1)
if other_subtokenizer is None:
return subtokenizer
if (abs(other_subtokenizer.vocab_size - target_size) <
abs(subtokenizer.vocab_size - target_size)):
return other_subtokenizer
return subtokenizer
return bisect(min_val, max_val)
def build_from_token_counts(self,
token_counts,
min_count,
num_iterations=4,
reserved_tokens=None,
max_subtoken_length=None):
"""Train a SubwordTextEncoder based on a dictionary of word counts.
Args:
token_counts: a dictionary of Unicode strings to int.
min_count: an integer - discard subtokens with lower counts.
num_iterations: an integer. how many iterations of refinement.
reserved_tokens: List of reserved tokens. The global variable
`RESERVED_TOKENS` must be a prefix of `reserved_tokens`. If this
argument is `None`, it will use `RESERVED_TOKENS`.
max_subtoken_length: Maximum length of a subtoken. If this is not set,
then the runtime and memory use of creating the vocab is quadratic in
the length of the longest token. If this is set, then it is instead
O(max_subtoken_length * length of longest token).
Raises:
ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it
is not clear what the space is being reserved for, or when it will be
filled in.
"""
if reserved_tokens is None:
reserved_tokens = RESERVED_TOKENS
else:
# There is not complete freedom in replacing RESERVED_TOKENS.
for default, proposed in zip(RESERVED_TOKENS, reserved_tokens):
if default != proposed:
raise ValueError("RESERVED_TOKENS must be a prefix of "
"reserved_tokens.")
# Initialize the alphabet. Note, this must include reserved tokens or it can
# result in encoding failures.
alphabet_tokens = chain(six.iterkeys(token_counts),
[native_to_unicode(t) for t in reserved_tokens])
self._init_alphabet_from_tokens(alphabet_tokens)
# Bootstrap the initial list of subtokens with the characters from the
# alphabet plus the escaping characters.
self._init_subtokens_from_list(list(self._alphabet),
reserved_tokens=reserved_tokens)
# We build iteratively. On each iteration, we segment all the words,
# then count the resulting potential subtokens, keeping the ones
# with high enough counts for our new vocabulary.
if min_count < 1:
min_count = 1
for i in range(num_iterations):
logging.info("Iteration {0}".format(i))
# Collect all substrings of the encoded token that break along current
# subtoken boundaries.
subtoken_counts = collections.defaultdict(int)
for token, count in six.iteritems(token_counts):
escaped_token = _escape_token(token, self._alphabet)
subtokens = self._escaped_token_to_subtoken_strings(escaped_token)
start = 0
for subtoken in subtokens:
last_position = len(escaped_token) + 1
if max_subtoken_length is not None:
last_position = min(last_position, start + max_subtoken_length)
for end in range(start + 1, last_position):
new_subtoken = escaped_token[start:end]
subtoken_counts[new_subtoken] += count
start += len(subtoken)
# Array of sets of candidate subtoken strings, by length.
len_to_subtoken_strings = []
for subtoken_string, count in six.iteritems(subtoken_counts):
lsub = len(subtoken_string)
if count >= min_count:
while len(len_to_subtoken_strings) <= lsub:
len_to_subtoken_strings.append(set())
len_to_subtoken_strings[lsub].add(subtoken_string)
# Consider the candidates longest to shortest, so that if we accept
# a longer subtoken string, we can decrement the counts of its prefixes.
new_subtoken_strings = []
for lsub in range(len(len_to_subtoken_strings) - 1, 0, -1):
subtoken_strings = len_to_subtoken_strings[lsub]
for subtoken_string in subtoken_strings:
count = subtoken_counts[subtoken_string]
if count >= min_count:
# Exclude alphabet tokens here, as they must be included later,
# explicitly, regardless of count.
if subtoken_string not in self._alphabet:
new_subtoken_strings.append((count, subtoken_string))
for l in range(1, lsub):
subtoken_counts[subtoken_string[:l]] -= count
# Include the alphabet explicitly to guarantee all strings are encodable.
new_subtoken_strings.extend((subtoken_counts.get(a, 0), a)
for a in self._alphabet)
new_subtoken_strings.sort(reverse=True)
# Reinitialize to the candidate vocabulary.
new_subtoken_strings = [subtoken for _, subtoken in new_subtoken_strings]
if reserved_tokens:
new_subtoken_strings = reserved_tokens + new_subtoken_strings
self._init_subtokens_from_list(new_subtoken_strings)
logging.info("vocab_size = %d" % self.vocab_size)
@property
def all_subtoken_strings(self):
"""
:return:
"""
return tuple(self._all_subtoken_strings)
def dump(self):
"""Debugging dump of the current subtoken vocabulary."""
subtoken_strings = [(i, s)
for s, i in six.iteritems(self._subtoken_string_to_id)]
print(u", ".join(u"{0} : '{1}'".format(i, s)
for i, s in sorted(subtoken_strings)))
def _init_subtokens_from_list(self, subtoken_strings, reserved_tokens=None):
"""Initialize token information from a list of subtoken strings.
Args:
subtoken_strings: a list of subtokens
reserved_tokens: List of reserved tokens. We must have `reserved_tokens`
as None or the empty list, or else the global variable `RESERVED_TOKENS`
must be a prefix of `reserved_tokens`.
Raises:
ValueError: if reserved is not 0 or len(RESERVED_TOKENS). In this case, it
is not clear what the space is being reserved for, or when it will be
filled in.
"""
if reserved_tokens is None:
reserved_tokens = []
if reserved_tokens:
self._all_subtoken_strings = reserved_tokens + subtoken_strings
else:
self._all_subtoken_strings = subtoken_strings
# we remember the maximum length of any subtoken to avoid having to
# check arbitrarily long strings.
self._max_subtoken_len = max([len(s) for s in subtoken_strings])
self._subtoken_string_to_id = {
s: i + len(reserved_tokens)
for i, s in enumerate(subtoken_strings) if s
}
# Initialize the cache to empty.
self._cache_size = 2 ** 20
self._cache = [(None, None)] * self._cache_size
def _init_alphabet_from_tokens(self, tokens):
"""Initialize alphabet from an iterable of token or subtoken strings."""
# Include all characters from all tokens in the alphabet to guarantee that
# any token can be encoded. Additionally, include all escaping characters.
self._alphabet = {c for token in tokens for c in token}
self._alphabet |= _ESCAPE_CHARS
def _load_from_file_object(self, f):
"""Load from a file object.
Args:
f: File object to load vocabulary from
"""
subtoken_strings = []
for line in f:
s = line.strip()
# Some vocab files wrap words in single quotes, but others don't
if ((s.startswith("'") and s.endswith("'")) or
(s.startswith("\"") and s.endswith("\""))):
s = s[1:-1]
subtoken_strings.append(native_to_unicode(s))
self._init_subtokens_from_list(subtoken_strings)
self._init_alphabet_from_tokens(subtoken_strings)
def _load_from_file(self, filename):
"""Load from a vocab file."""
if not os.path.exists(filename):
raise ValueError("File %s not found" % filename)
with open(filename, 'r') as f:
self._load_from_file_object(f)
def store_to_file(self, filename, add_single_quotes=True):
"""
:param filename:
:param add_single_quotes:
:return:
"""
with open(filename, "w") as f:
for subtoken_string in self._all_subtoken_strings:
if add_single_quotes:
f.write("'" + unicode_to_native(subtoken_string) + "'\n")
else:
f.write(unicode_to_native(subtoken_string) + "\n")
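# Usage sketch (illustrative; the vocab path is hypothetical): a subword vocab
# produced by get_or_generate_vocab can round-trip arbitrary text, e.g.
#   enc = SubwordTextEncoder("train_data/vocab.8000")
#   ids = enc.encode("Hello world")      # list of ids in [0, vocab_size)
#   assert enc.decode(ids) == "Hello world"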
class TokenTextEncoder(TextEncoder):
"""Encoder based on a user-supplied vocabulary (file or list)."""
def __init__(self,
vocab_filename,
reverse=False,
vocab_list=None,
replace_oov="UNK",
num_reserved_ids=NUM_RESERVED_TOKENS):
"""Initialize from a file or list, one token per line.
Handling of reserved tokens works as follows:
- When initializing from a list, we add reserved tokens to the vocab.
- When initializing from a file, we do not add reserved tokens to the vocab.
- When saving vocab files, we save reserved tokens to the file.
Args:
vocab_filename: If not None, the full filename to read vocab from. If this
is not None, then vocab_list should be None.
reverse: Boolean indicating if tokens should be reversed during encoding
and decoding.
vocab_list: If not None, a list of elements of the vocabulary. If this is
not None, then vocab_filename should be None.
replace_oov: If not None, every out-of-vocabulary token seen when
encoding will be replaced by this string (which must be in vocab).
num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
"""
super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
self._reverse = reverse
self._replace_oov = replace_oov
if vocab_filename:
self._init_vocab_from_file(vocab_filename)
else:
assert vocab_list is not None
self._init_vocab_from_list(vocab_list)
@classmethod
def build_from_corpus(cls, filenames, vocab_size):
"""
:param filenames:
:param vocab_size:
:return:
"""
def create_dictionary(names, lim=0):
"""
:param name:
:param lim:
:return:
"""
global_counter = collections.Counter()
for name in names:
fd = open(name)
for line in fd:
words = line.strip().split()
words = filter(lambda x: x != "-1", words)
global_counter.update(words)
if lim <= 2:
lim = len(global_counter) + 3
vocab_count = global_counter.most_common(lim - 3)
total_counts = sum(global_counter.values())
coverage = 100.0 * sum([count for word, count in vocab_count]) / total_counts
logging.info("coverage: %s" % coverage)
vocab_table = ["<pad>", "<EOS>"]
for i, (word, count) in enumerate(vocab_count):
vocab_table.append(word)
vocab_table.append("UNK")
return vocab_table
if not isinstance(filenames, list): filenames = [filenames]
vocab = cls(None,
vocab_list=create_dictionary(filenames, vocab_size),
replace_oov="UNK")
return vocab
def encode(self, s):
"""Converts a space-separated string of tokens to a list of ids."""
sentence = s
tokens = sentence.strip().split()
if self._replace_oov is not None:
tokens = [t if t in self._token_to_id else self._replace_oov
for t in tokens]
ret = [self._token_to_id[tok] for tok in tokens]
return ret[::-1] if self._reverse else ret
def decode(self, ids, strip_extraneous=False):
"""
:param ids:
:param strip_extraneous:
:return:
"""
return " ".join(self.decode_list(ids))
def decode_list(self, ids):
"""
:param ids:
:return:
"""
seq = reversed(ids) if self._reverse else ids
return [self._safe_id_to_token(i) for i in seq]
@property
def vocab_size(self):
"""
:return:
"""
return len(self._id_to_token)
def _safe_id_to_token(self, idx):
"""
:param idx:
:return:
"""
return self._id_to_token.get(idx, "ID_%d" % idx)
def _init_vocab_from_file(self, filename):
"""Load vocab from a file.
Args:
filename: The file to load vocabulary from.
"""
with open(filename, 'r') as f:
tokens = [token.strip() for token in f.readlines()]
def token_gen():
"""token gen"""
for token in tokens:
yield token
self._init_vocab(token_gen(), add_reserved_tokens=False)
def _init_vocab_from_list(self, vocab_list):
"""Initialize tokens from a list of tokens.
It is ok if reserved tokens appear in the vocab list. They will be
removed. The set of tokens in vocab_list should be unique.
Args:
vocab_list: A list of tokens.
"""
def token_gen():
"""token gen"""
for token in vocab_list:
if token not in RESERVED_TOKENS:
yield token
self._init_vocab(token_gen())
def _init_vocab(self, token_generator, add_reserved_tokens=True):
"""Initialize vocabulary with tokens from token_generator."""
self._id_to_token = {}
non_reserved_start_index = 0
if add_reserved_tokens:
self._id_to_token.update(enumerate(RESERVED_TOKENS))
non_reserved_start_index = len(RESERVED_TOKENS)
self._id_to_token.update(
enumerate(token_generator, start=non_reserved_start_index))
# _token_to_id is the reverse of _id_to_token
self._token_to_id = dict((v, k)
for k, v in six.iteritems(self._id_to_token))
def store_to_file(self, filename):
"""Write vocab file to disk.
Vocab files have one token per line. The file ends in a newline. Reserved
tokens are written to the vocab file as well.
Args:
filename: Full path of the file to store the vocab to.
"""
with open(filename, "w") as f:
for i in range(len(self._id_to_token)):
f.write(self._id_to_token[i] + "\n")
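# Usage sketch (illustrative; the vocab list is hypothetical): TokenTextEncoder
# maps whole space-separated tokens to ids and replaces unknown tokens with "UNK"
# (which must be present in the vocab), e.g.
#   enc = TokenTextEncoder(None, vocab_list=["hello", "world", "UNK"])
#   enc.encode("hello there")   # -> [2, 4]; ids 0/1 are reserved for <pad>/<EOS>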
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import sys
import glob
import unicodedata
import six
import logging
from six.moves import range # pylint: disable=redefined-builtin
# Conversion between Unicode and UTF-8, if required (on Python2)
_native_to_unicode = (lambda s: s.decode("utf-8")) if six.PY2 else (lambda s: s)
# This set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = set(
six.unichr(i) for i in range(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
def encode(text):
"""Encode a unicode string as a list of tokens.
Args:
text: a unicode string
Returns:
a list of tokens as Unicode strings
"""
if not text:
return []
ret = []
token_start = 0
# Classify each character in the input string
is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
for pos in range(1, len(text)):
if is_alnum[pos] != is_alnum[pos - 1]:
token = text[token_start:pos]
if token != u" " or token_start == 0:
ret.append(token)
token_start = pos
final_token = text[token_start:]
ret.append(final_token)
return ret
def decode(tokens):
"""Decode a list of tokens to a unicode string.
Args:
tokens: a list of Unicode strings
Returns:
a unicode string
"""
token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
ret.append(u" ")
ret.append(token)
return "".join(ret)
def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True):
"""Reads files matching a wildcard pattern, yielding the contents.
Args:
filepattern: A wildcard pattern matching one or more files.
max_lines: If set, stop reading after reading this many lines.
split_on_newlines: A boolean. If true, then split files by lines and strip
leading and trailing whitespace from each line. Otherwise, treat each
file as a single string.
Yields:
The contents of the files as lines, if split_on_newlines is True, or
the entire contents of each file if False.
"""
filenames = sorted(glob.glob(filepattern))
lines_read = 0
for filename in filenames:
with open(filename, 'r') as f:
if split_on_newlines:
for line in f:
yield line.strip()
lines_read += 1
if max_lines and lines_read >= max_lines:
return
else:
if max_lines:
doc = []
for line in f:
doc.append(line)
lines_read += 1
if max_lines and lines_read >= max_lines:
yield "".join(doc)
return
yield "".join(doc)
else:
yield f.read()
def corpus_token_counts(
text_filepattern, corpus_max_lines, split_on_newlines=True):
"""Read the corpus and compute a dictionary of token counts.
Args:
text_filepattern: A pattern matching one or more files.
corpus_max_lines: An integer; maximum total lines to read.
split_on_newlines: A boolean. If true, then split files by lines and strip
leading and trailing whitespace from each line. Otherwise, treat each
file as a single string.
Returns:
a dictionary mapping token to count.
"""
counts = collections.Counter()
for doc in _read_filepattern(
text_filepattern,
max_lines=corpus_max_lines,
split_on_newlines=split_on_newlines):
counts.update(encode(_native_to_unicode(doc)))
return counts
def vocab_token_counts(text_filepattern, max_lines):
"""Read a vocab file and return a dictionary of token counts.
Reads a two-column CSV file of tokens and their frequency in a dataset. The
tokens are presumed to be generated by encode() or the equivalent.
Args:
text_filepattern: A pattern matching one or more files.
max_lines: An integer; maximum total lines to read.
Returns:
a dictionary mapping token to count.
"""
ret = {}
for i, line in enumerate(
_read_filepattern(text_filepattern, max_lines=max_lines)):
if "," not in line:
logging.warning("Malformed vocab line #%d '%s'", i, line)
continue
token, count = line.rsplit(",", 1)
ret[_native_to_unicode(token)] = int(count)
return ret
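# Illustrative file format for vocab_token_counts (hypothetical content): each
# line is "<token>,<count>", e.g.
#   hello,1234
#   ', ',56
# Tokens that themselves contain a comma still parse, because the count is
# split off with rsplit(",", 1).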
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import six
import os
import tarfile
import random
import numpy as np
from collections import defaultdict
def batching_scheme(batch_size,
max_length,
min_length_bucket=8,
length_bucket_step=1.1,
drop_long_sequences=False,
shard_multiplier=1,
length_multiplier=1,
min_length=0):
"""A batching scheme based on model hyperparameters.
Every batch contains a number of sequences divisible by `shard_multiplier`.
Args:
batch_size: int, total number of tokens in a batch.
max_length: int, sequences longer than this will be skipped. Defaults to
batch_size.
min_length_bucket: int
length_bucket_step: float greater than 1.0
drop_long_sequences: bool, if True, then sequences longer than
`max_length` are dropped. This prevents generating batches with
more than the usual number of tokens, which can cause out-of-memory
errors.
shard_multiplier: an integer increasing the batch_size to suit splitting
across datashards.
length_multiplier: an integer multiplier that is used to increase the
batch sizes and sequence length tolerance.
min_length: int, sequences shorter than this will be skipped.
Returns:
A dictionary with parameters that can be passed to input_pipeline:
* boundaries: list of bucket boundaries
* batch_sizes: list of batch sizes for each length bucket
* max_length: int, maximum length of an example
Raises:
ValueError: If min_length > max_length
"""
def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
assert length_bucket_step > 1.0
x = min_length
boundaries = []
while x < max_length:
boundaries.append(x)
x = max(x + 1, int(x * length_bucket_step))
return boundaries
max_length = max_length or batch_size
if max_length < min_length:
raise ValueError("max_length must be greater or equal to min_length")
boundaries = _bucket_boundaries(max_length, min_length_bucket,
length_bucket_step)
boundaries = [boundary * length_multiplier for boundary in boundaries]
max_length *= length_multiplier
batch_sizes = [
max(1, batch_size // length) for length in boundaries + [max_length]
]
max_batch_size = max(batch_sizes)
# Since the Datasets API only allows a single constant for window_size,
  # and it needs to divide all bucket_batch_sizes, we pick a highly composite
# window size and then round down all batch sizes to divisors of that window
# size, so that a window can always be divided evenly into batches.
# TODO(noam): remove this when Dataset API improves.
highly_composite_numbers = [
1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680,
2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440,
83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280,
720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480,
7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400,
36756720, 43243200, 61261200, 73513440, 110270160
]
window_size = max(
[i for i in highly_composite_numbers if i <= 3 * max_batch_size])
divisors = [i for i in xrange(1, window_size + 1) if window_size % i == 0]
batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
window_size *= shard_multiplier
batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
# The Datasets API splits one window into multiple batches, which
# produces runs of many consecutive batches of the same size. This
# is bad for training. To solve this, we will shuffle the batches
# using a queue which must be several times as large as the maximum
# number of batches per window.
max_batches_per_window = window_size // min(batch_sizes)
shuffle_queue_size = max_batches_per_window * 3
ret = {
"boundaries": boundaries,
"batch_sizes": batch_sizes,
"min_length": min_length,
"max_length": (max_length if drop_long_sequences else 10 ** 9),
"shuffle_queue_size": shuffle_queue_size,
}
return ret
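# Illustrative sketch only (never called by the training code): shows the kind of
# scheme batching_scheme() returns; the numbers below are hypothetical.
def _example_batching_scheme():
    """Demonstrate the structure of the dict returned by batching_scheme()."""
    scheme = batching_scheme(batch_size=4096, max_length=256)
    # scheme["boundaries"] grows roughly geometrically from min_length_bucket, and
    # scheme["batch_sizes"] holds one entry per bucket plus one for the final
    # bucket, each rounded down to a divisor of the shared window size.
    assert len(scheme["batch_sizes"]) == len(scheme["boundaries"]) + 1
    return scheme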
def bucket_by_sequence_length(data_reader,
example_length_fn,
bucket_boundaries,
bucket_batch_sizes,
trainer_nums,
trainer_id):
"""Bucket entries in dataset by length.
Args:
dataset: Dataset of dict<feature name, Tensor>.
example_length_fn: function from example to int, determines the length of
the example, which will determine the bucket it goes into.
bucket_boundaries: list<int>, boundaries of the buckets.
bucket_batch_sizes: list<int>, batch size per bucket.
Returns:
Dataset of padded and batched examples.
"""
def example_to_bucket_id(example):
"""
get bucket_id
"""
seq_length = example_length_fn(example)
boundaries = list(bucket_boundaries)
buckets_min = [np.iinfo(np.int32).min] + boundaries
buckets_max = boundaries + [np.iinfo(np.int32).max]
for i in range(len(buckets_min)):
if buckets_min[i] <= seq_length and seq_length < buckets_max[i]:
bucket_id = i
return bucket_id
def window_size_fn(bucket_id):
"""
get window size
"""
window_size = bucket_batch_sizes[bucket_id]
return window_size
def group_by_window(reader, key_func, window_size_func, drop_last=False):
"""
group the line by length
"""
groups = defaultdict(list)
def impl():
"""
impl
"""
for e in reader():
key = key_func(e)
window_size = window_size_func(key)
groups[key].append(e)
if len(groups[key]) == window_size:
each_size = window_size / trainer_nums
res = groups[key][trainer_id * each_size: (trainer_id + 1) * each_size]
yield res
groups[key] = []
if drop_last:
groups.clear()
return impl
reader = group_by_window(data_reader, example_to_bucket_id, window_size_fn)
return reader
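# Illustrative sketch only: buckets a toy in-memory reader by sequence length.
# The tiny boundaries/batch sizes below are hypothetical and chosen so that the
# example yields something; real values come from batching_scheme().
def _example_bucket_by_sequence_length():
    """Group toy sequences by length and emit one full bucket window at a time."""
    def toy_reader():
        for seq in [[1, 2], [3, 4, 5, 6], [7, 8], [9, 10, 11, 12]]:
            yield seq
    reader = bucket_by_sequence_length(
        toy_reader,
        example_length_fn=len,
        bucket_boundaries=[3],      # two buckets: length < 3 and length >= 3
        bucket_batch_sizes=[2, 2],  # a bucket is emitted once it holds 2 examples
        trainer_nums=1,
        trainer_id=0)
    # -> [[[1, 2], [7, 8]], [[3, 4, 5, 6], [9, 10, 11, 12]]]
    return list(reader())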
def shuffle(reader, buf_size):
"""
Creates a data reader whose data output is shuffled.
Output from the iterator that created by original reader will be
buffered into shuffle buffer, and then shuffled. The size of shuffle buffer
is determined by argument buf_size.
:param reader: the original reader whose output will be shuffled.
:type reader: callable
:param buf_size: shuffle buffer size.
:type buf_size: int
:return: the new reader whose output is shuffled.
:rtype: callable
"""
def data_reader():
"""
data_reader
"""
buf = []
for e in reader():
buf.append(e)
if len(buf) >= buf_size:
random.shuffle(buf)
for b in buf:
yield b
buf = []
if len(buf) > 0:
random.shuffle(buf)
for b in buf:
yield b
return data_reader
def sort(reader, buf_size, cmp=None, key=None, reverse=False):
"""
Creates a data reader whose data output is sorted.
Output from the iterator that created by original reader will be
buffered into sort buffer, and then sorted. The size of sort buffer
is determined by argument buf_size.
:param reader: the original reader whose output will be sorted.
:type reader: callable
    :param buf_size: sort buffer size.
:type buf_size: int
:return: the new reader whose output is sorted.
:rtype: callable
"""
def data_reader():
"""
data_reader
"""
buf = []
for e in reader():
buf.append(e)
if len(buf) >= buf_size:
buf = sorted(buf, cmp, key, reverse)
for b in buf:
yield b
buf = []
if len(buf) > 0:
            buf = sorted(buf, cmp, key, reverse)
for b in buf:
yield b
return data_reader
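# Illustrative sketch only: composes shuffle() and sort() on a toy reader so that
# nearby outputs have similar lengths; the buffer sizes here are hypothetical.
def _example_shuffle_then_sort():
    """Shuffle a toy reader, then sort each buffer by sequence length."""
    def toy_reader():
        for seq in [[1], [2, 3, 4], [5, 6], [7, 8, 9, 10]]:
            yield seq
    shuffled = shuffle(toy_reader, buf_size=4)
    ordered = sort(shuffled, buf_size=4, key=len)
    # -> the four sequences ordered by length: [[1], [5, 6], [2, 3, 4], [7, 8, 9, 10]]
    return list(ordered())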
def batch_by_token(reader, batch_size, len_fun, drop_last=False):
"""
Create a batched reader.
:param reader: the data reader to read from.
:type reader: callable
:param batch_size: size of each mini-batch
:type batch_size: int
:param drop_last: drop the last batch, if the size of last batch is not equal to batch_size.
:type drop_last: bool
:return: the batched reader.
:rtype: callable
"""
def batch_reader():
"""
batch_reader
"""
r = reader()
b = []
max_len = 0
for instance in r:
cur_len = len_fun(instance)
max_len = max(max_len, cur_len)
if max_len * (len(b) + 1) > batch_size:
yield b
b = [instance]
max_len = cur_len
else:
b.append(instance)
        if not drop_last and len(b) != 0:
yield b
# Batch size check
batch_size = int(batch_size)
if batch_size <= 0:
raise ValueError("batch_size should be a positive integeral value, "
"but got batch_size={}".format(batch_size))
return batch_reader
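# Illustrative sketch only: packs a toy reader into batches whose padded size
# (longest sequence in the batch * batch width) stays within batch_size tokens.
def _example_batch_by_token():
    """Batch toy sequences by token budget rather than by a fixed batch width."""
    def toy_reader():
        for seq in [[1, 2], [3, 4, 5], [6], [7, 8, 9, 10]]:
            yield seq
    batched = batch_by_token(toy_reader, batch_size=6, len_fun=len)
    # -> [[[1, 2], [3, 4, 5]], [[6]], [[7, 8, 9, 10]]]
    return list(batched())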
def parse_line(line, max_len, min_len=0, field_delimiter="\t", token_delimiter=" "):
"""
parse training data
"""
src, trg = line.strip("\n").split(field_delimiter)
src_ids = [int(token) for token in src.split(token_delimiter)]
trg_ids = [int(token) for token in trg.split(token_delimiter)]
reverse_trg_ids = trg_ids[::-1]
reverse_trg_ids = reverse_trg_ids[1:]
reverse_trg_ids.append(1)
inst_max_len = max(len(src_ids), len(trg_ids))
inst_min_len = min(len(src_ids), len(trg_ids))
if inst_max_len <= max_len and inst_min_len > min_len:
return src_ids, [0] + trg_ids[:-1], trg_ids, [0] + reverse_trg_ids[:-1], reverse_trg_ids
else:
return None
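# Illustrative sketch only: the training line below is hypothetical; it shows the
# tab-separated, already-numericalized "source\ttarget" format that parse_line()
# expects and the five id lists it returns (the last two feed the right-to-left agent).
def _example_parse_line():
    """Parse one toy training line into the five sequences used for training."""
    line = "12 7 934 1\t45 8 2201 1"
    src_ids, trg_in, trg_label, rev_trg_in, rev_trg_label = parse_line(line, max_len=256)
    # src_ids       == [12, 7, 934, 1]
    # trg_in        == [0, 45, 8, 2201]   (decoder input, shifted right with <bos>=0)
    # trg_label     == [45, 8, 2201, 1]   (left-to-right labels, ending with <eos>=1)
    # rev_trg_in    == [0, 2201, 8, 45]   (reversed decoder input)
    # rev_trg_label == [2201, 8, 45, 1]   (right-to-left labels, <eos> kept at the end)
    return src_ids, trg_in, trg_label, rev_trg_in, rev_trg_label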
def repeat(reader, count=-1):
"""
    Repeat the reader count times; repeat forever if count is -1.
"""
def data_reader():
"""
repeat data
"""
time = count
while time != 0:
for e in reader():
yield e
time -= 1
return data_reader
def parse_src_line(line, max_len, min_len=0, token_delimiter=" "):
"""
parse infer data
"""
src = line.strip("\n")
src_ids = [int(token) for token in src.split(token_delimiter)]
inst_max_len = inst_min_len = len(src_ids)
if inst_max_len < max_len and inst_min_len > min_len:
src_ids.append(1)
return [src_ids]
else:
src_ids = src_ids[:max_len - 1]
src_ids.append(1)
return [src_ids]
def interleave_reader(fpattern, cycle_length, block_length=1, **kwargs):
"""
    Interleave parsed instances from multiple files (mimics tf.data parallel_interleave).
"""
# refer to:
# https://www.tensorflow.org/api_docs/python/tf/contrib/data/parallel_interleave?hl=zh_cn
# https://www.tensorflow.org/api_docs/python/tf/data/Dataset?hl=zh_cn#interleave
fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths)
if 'parse_line' in kwargs:
parse_line = kwargs.pop('parse_line')
class Worker(object): # mimic a worker thread
"""
each worker wrap a file
"""
def __init__(self):
self.input = None
self.iter = None
def set_input(self, input_arg):
"""
set file reader
"""
if self.iter is not None:
self.iter.close()
self.input = input_arg
self.iter = open(input_arg, 'rb')
def get_next(self):
"""
get next data
"""
return next(self.iter)
def data_reader():
"""
generate data
"""
num_workers = cycle_length # + prefetched
workers = []
# Indices in `workers` of iterators to interleave.
interleave_indices = []
# Indices in `workers` of prefetched iterators.
staging_indices = []
# EnsureWorkerThreadsStarted
for i in range(num_workers):
if i >= len(fpaths):
break
workers.append(Worker())
workers[i].set_input(fpaths[i])
if i < cycle_length:
interleave_indices.append(i)
else:
staging_indices.append(i)
input_index = len(workers) # index for files
next_index = 0 # index for worker
block_count = 0 # counter for the number of instances from one block
#
while True: # break while when all inputs end
can_produce_elements = False
# The for loop only fully runs when all workers ending.
# Otherwise, run one step then break the for loop, or
# find the first possible unended iterator by setting next_index
# or go to the step of loop.
for i in range(len(interleave_indices)):
index = (next_index + i) % len(interleave_indices)
current_worker_index = interleave_indices[index]
current_worker = workers[current_worker_index]
try:
line = current_worker.get_next()
if six.PY3:
line = line.decode()
inst = parse_line(line, **kwargs)
if inst is not None:
yield inst
next_index = index
block_count += 1
if block_count == block_length:
# advance to the next iterator
next_index = (index + 1) % len(interleave_indices)
block_count = 0
can_produce_elements = True
break
except (StopIteration,): # This iterator has reached the end.
if input_index < len(fpaths): # get a new iterator and skip
current_worker.set_input(fpaths[input_index])
staging_indices.append(current_worker_index)
if len(staging_indices) > 0: # pop_front
interleave_indices[index] = staging_indices[0]
staging_indices = staging_indices[1:]
input_index += 1
# advance to the next iterator
next_index = (index + 1) % len(interleave_indices)
block_count = 0
can_produce_elements = True
break
# else: advance to the next iterator by loop step
if not can_produce_elements:
# all inputs end, triggered when all iterators have reached the end
break
return data_reader
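# Illustrative sketch only: writes two hypothetical shard files to a temporary
# directory and interleaves them, taking one parsed instance from each in turn.
def _example_interleave_reader():
    """Round-robin read from two toy shard files."""
    import tempfile
    tmp_dir = tempfile.mkdtemp()
    for name, rows in [("shard_0.txt", ["1 2 1\t3 4 1", "5 6 1\t7 8 1"]),
                       ("shard_1.txt", ["9 1\t10 1", "11 1\t12 1"])]:
        with open(os.path.join(tmp_dir, name), "w") as f:
            f.write("\n".join(rows) + "\n")
    reader = interleave_reader(
        os.path.join(tmp_dir, "shard_*.txt"), cycle_length=2, max_len=256)
    # instances alternate: shard_0 line 1, shard_1 line 1, shard_0 line 2, ...
    return list(reader())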
def line_reader(fpattern, batch_size, dev_count, **kwargs):
"""
    Read lines from all matched files sequentially and yield groups of
    batch_size * dev_count parsed instances.
"""
fpaths = glob.glob(fpattern)
#np.random.shuffle(fpaths)
#random.shuffle(fpaths)
if "parse_line" in kwargs:
parse_line = kwargs.pop('parse_line')
def data_reader():
"""
data_reader
"""
res = []
total_size = batch_size * dev_count
for fpath in fpaths:
if not os.path.isfile(fpath):
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f:
if six.PY3:
line = line.decode()
inst = parse_line(line, **kwargs)
res.append(inst)
if len(res) == total_size:
yield res
res = []
if len(res) > 0:
pad_count = total_size - len(res)
for index in xrange(pad_count):
res.append(res[-1])
yield res
return data_reader
def prepare_data_generator(args, is_test, count, pyreader, batch_size=None,
data_reader=None, py_reader_provider_wrapper=None):
"""
    Data generator wrapper for DataReader. If py_reader is used, set the data
    provider for it.
"""
def stack(data_reader, count, clip_last=True):
"""
Data generator for multi-devices
"""
def __impl__():
res = []
for item in data_reader():
res.append(item)
if len(res) == count:
yield res
res = []
if len(res) == count:
yield res
elif not clip_last:
data = []
for item in res:
data += item
if len(data) > count:
inst_num_per_part = len(data) // count
yield [
data[inst_num_per_part * i:inst_num_per_part * (i + 1)]
for i in range(count)
]
return __impl__
def split(data_reader, count):
"""
split for multi-gpu
"""
def __impl__():
for item in data_reader():
inst_num_per_part = len(item) // count
for i in range(count):
yield item[inst_num_per_part * i:inst_num_per_part * (i + 1
)]
return __impl__
if not args.use_token_batch:
# to make data on each device have similar token number
data_reader = split(data_reader, count)
#if args.use_py_reader:
if pyreader:
pyreader.decorate_tensor_provider(
py_reader_provider_wrapper(data_reader))
data_reader = None
else: # Data generator for multi-devices
data_reader = stack(data_reader, count)
return data_reader
def pad_batch_data(insts,
pad_idx,
n_head,
is_target=False,
is_label=False,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
max_len = max(len(inst) for inst in insts)
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array(
[inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, 1])]
if is_label: # label weight
inst_weight = np.array(
[[1.] * len(inst) + [0.] * (max_len - len(inst)) for inst in insts])
return_list += [inst_weight.astype("float32").reshape([-1, 1])]
else: # position data
inst_pos = np.array([
list(range(0, len(inst))) + [0] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, 1])]
if return_attn_bias:
if is_target:
# This is used to avoid attention on paddings and subsequent
# words.
slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
slf_attn_bias_data = np.triu(slf_attn_bias_data,
1).reshape([-1, 1, max_len, max_len])
slf_attn_bias_data = np.tile(slf_attn_bias_data,
[1, n_head, 1, 1]) * [-1e9]
else:
# This is used to avoid attention on paddings.
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst))
for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
[1, n_head, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
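# Illustrative sketch only: pads a toy batch of two sentences and shows the
# shapes pad_batch_data() produces (pad_idx and n_head below are hypothetical).
def _example_pad_batch_data():
    """Pad two toy id sequences to a common length of 3."""
    insts = [[12, 7, 1], [45, 1]]
    word, pos, attn_bias, max_len = pad_batch_data(
        insts, pad_idx=0, n_head=8, is_target=False)
    # word.shape      == (2 * 3, 1)   flattened [batch * max_len, 1] int64 ids
    # pos.shape       == (2 * 3, 1)   positions, zero over the padding
    # attn_bias.shape == (2, 8, 3, 3) -1e9 over padded key positions
    # max_len         == 3
    return word, pos, attn_bias, max_len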
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layer_helper import LayerHelper as LayerHelper
from config import *
from beam_search import BeamSearch
from attention import _dot_product_relative
INF = 1. * 1e5
def layer_norm(x, begin_norm_axis=1, epsilon=1e-6, param_attr=None, bias_attr=None):
"""
layer_norm
"""
helper = LayerHelper('layer_norm', **locals())
mean = layers.reduce_mean(x, dim=begin_norm_axis, keep_dim=True)
shift_x = layers.elementwise_sub(x=x, y=mean, axis=0)
variance = layers.reduce_mean(layers.square(shift_x), dim=begin_norm_axis, keep_dim=True)
r_stdev = layers.rsqrt(variance + epsilon)
norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0)
param_shape = [reduce(lambda x, y: x * y, norm_x.shape[begin_norm_axis:])]
param_dtype = norm_x.dtype
scale = helper.create_parameter(
attr=param_attr,
shape=param_shape,
dtype=param_dtype,
default_initializer=fluid.initializer.Constant(1.))
bias = helper.create_parameter(
attr=bias_attr,
shape=param_shape,
dtype=param_dtype,
is_bias=True,
default_initializer=fluid.initializer.Constant(0.))
out = layers.elementwise_mul(x=norm_x, y=scale, axis=-1)
out = layers.elementwise_add(x=out, y=bias, axis=-1)
    return out  # equivalent to norm_x * scale + bias
def relative_position_encoding_init(n_position, d_pos_vec):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels = d_pos_vec
position = np.arange(n_position)
num_timescales = channels // 2
log_timescale_increment = (np.log(float(1e4) / float(1)) /
(num_timescales - 1))
inv_timescales = np.exp(np.arange(
num_timescales) * -log_timescale_increment)
#num_timescales)) * -log_timescale_increment
scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
0)
signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
position_enc = signal
return position_enc.astype("float32")
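# Illustrative sketch only: a tiny sinusoid table for a hypothetical maximum
# length of 6 positions and a model width of 8 channels.
def _example_position_encoding():
    """Show the shape of the table built by relative_position_encoding_init()."""
    table = relative_position_encoding_init(n_position=6, d_pos_vec=8)
    # table.shape == (6, 8): the first 4 channels are sines and the last 4 are
    # cosines, each at a different timescale.
    return table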
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
attention_type="dot_product",
params_type = "normal"):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
    computing the softmax activation to mask certain selected positions so that
    they will not be considered in the attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
k = layers.fc(input=keys,
size=d_key * n_head,
bias_attr=False,
num_flatten_dims=2)
v = layers.fc(input=values,
size=d_value * n_head,
bias_attr=False,
num_flatten_dims=2)
return q, k, v
def __split_heads(x, n_head):
"""
        Reshape the last dimension of input tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
if n_head == 1:
return x
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
        # permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
        Transpose and then reshape the last two dimensions of input tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key ** -0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = layers.concat([cache['k'], k], axis=1)
v = layers.concat([cache['v'], v], axis=1)
layers.assign(k, cache['k'])
layers.assign(v, cache['v'])
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
    assert attention_type in ("dot_product", "dot_product_relative_encoder", "dot_product_relative_decoder")
if attention_type == "dot_product":
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key, #d_model,
dropout_rate)
elif attention_type == "dot_product_relative_encoder":
q = layers.scale(x=q, scale=d_key ** -0.5)
ctx_multiheads = _dot_product_relative(q, k, v, attn_bias, dropout=dropout_rate, params_type = params_type)
else:
q = layers.scale(x=q, scale=d_key ** -0.5)
ctx_multiheads = _dot_product_relative(q, k, v, attn_bias, dropout=dropout_rate, cache = cache, params_type = params_type)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
bias_attr=False,
num_flatten_dims=2)
return proj_out
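# Illustrative numpy sketch only (the real graph uses the fluid ops above): shows
# the reshape/transpose pattern that __split_heads and __combine_heads apply.
def _example_head_split_shapes():
    """Mirror the head splitting/combining shape logic with toy numpy arrays."""
    bs, seq_len, n_head, d_head = 2, 5, 8, 64
    x = np.zeros([bs, seq_len, n_head * d_head], dtype="float32")
    split = x.reshape([bs, seq_len, n_head, d_head]).transpose([0, 2, 1, 3])
    # split.shape    == (2, 8, 5, 64): one [seq_len, d_head] slice per head
    combined = split.transpose([0, 2, 1, 3]).reshape([bs, seq_len, n_head * d_head])
    # combined.shape == (2, 5, 512): heads concatenated back to the model width
    return split.shape, combined.shape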
def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act="relu")
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.):
"""
    Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
epsilon=1e-6,
param_attr=fluid.initializer.Constant(1.),
bias_attr=fluid.initializer.Constant(0.))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
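# Illustrative note: with the preprocess_cmd="n" / postprocess_cmd="da" convention
# used throughout this model, every sub-layer computes
#   y = x + dropout(sublayer(layer_norm(x)))
# i.e. pre-norm ("n"), then dropout ("d") and the residual add ("a").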
def prepare_encoder_decoder(src_word,
src_pos,
src_vocab_size,
src_emb_dim,
src_max_len,
dropout_rate=0.,
word_emb_param_name=None,
training=True,
pos_enc_param_name=None,
is_src=True,
params_type="normal"):
"""Add word embeddings and position encodings.
The output tensor has a shape of:
[batch_size, max_src_length_in_batch, d_model].
    This module is used at the bottom of the encoder and decoder stacks.
"""
    assert params_type in ("fixed", "normal", "new")
pre_name = "relative_positionrelative_position"
if params_type == "fixed":
pre_name = "fixed_relative_positionfixed_relative_position"
elif params_type == "new":
pre_name = "new_relative_positionnew_relative_position"
src_word_emb = layers.embedding(
src_word,
size=[src_vocab_size, src_emb_dim],
padding_idx=ModelHyperParams.bos_idx, # set embedding of bos to 0
param_attr=fluid.ParamAttr(
name = pre_name + word_emb_param_name,
initializer=fluid.initializer.Normal(0., src_emb_dim ** -0.5)))#, is_sparse=True)
if not is_src and training:
src_word_emb = layers.pad(src_word_emb, [0, 0, 1, 0, 0, 0])
src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5)
src_pos_enc = layers.embedding(
src_pos,
size=[src_max_len, src_emb_dim],
param_attr=fluid.ParamAttr(
trainable=False, name = pre_name + pos_enc_param_name))
src_pos_enc.stop_gradient = True
enc_input = src_word_emb + src_pos_enc
return layers.dropout(
enc_input,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train') if dropout_rate else enc_input
prepare_encoder = partial(
prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0], is_src=True)
prepare_decoder = partial(
prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1], is_src=False)
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da",
params_type="normal"):
"""The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention sub-layer followed by
    a position-wise feed-forward network, and both components are wrapped
    with post_process_layer to add the residual connection, layer normalization
    and dropout.
"""
attn_output = multi_head_attention(
pre_process_layer(enc_input, preprocess_cmd,
prepostprocess_dropout), None, None, attn_bias, d_key,
d_value, d_model, n_head, attention_dropout, attention_type = "dot_product_relative_encoder", params_type = params_type)
attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
prepostprocess_dropout)
ffd_output = positionwise_feed_forward(
pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout),
d_inner_hid, d_model, relu_dropout)
return post_process_layer(attn_output, ffd_output, postprocess_cmd,
prepostprocess_dropout)
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd="n",
postprocess_cmd="da",
params_type="normal"):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
params_type=params_type)
enc_input = enc_output
enc_output = pre_process_layer(enc_output, preprocess_cmd,
prepostprocess_dropout)
return enc_output
def decoder_layer(dec_input,
enc_output,
slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None,
params_type="normal"):
""" The layer to be stacked in decoder part.
    The structure of this module is similar to that in the encoder part, except
    that an extra multi-head attention sub-layer is added to implement
    encoder-decoder attention.
"""
slf_attn_output = multi_head_attention(
pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout),
None,
None,
slf_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
cache,
attention_type="dot_product_relative_decoder",
params_type=params_type)
slf_attn_output = post_process_layer(
dec_input,
slf_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
enc_attn_output = multi_head_attention(
pre_process_layer(slf_attn_output, preprocess_cmd, prepostprocess_dropout),
enc_output,
enc_output,
dec_enc_attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
params_type=params_type)
enc_attn_output = post_process_layer(
slf_attn_output,
enc_attn_output,
postprocess_cmd,
prepostprocess_dropout, )
ffd_output = positionwise_feed_forward(
pre_process_layer(enc_attn_output, preprocess_cmd,
prepostprocess_dropout),
d_inner_hid,
d_model,
relu_dropout, )
dec_output = post_process_layer(
enc_attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout, )
return dec_output
def decoder(dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=None,
params_type="normal"):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for i in range(n_layer):
dec_output = decoder_layer(
dec_input,
enc_output,
dec_slf_attn_bias,
dec_enc_attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
cache=None if caches is None else caches[i],
params_type=params_type)
dec_input = dec_output
dec_output = pre_process_layer(dec_output, preprocess_cmd,
prepostprocess_dropout)
return dec_output
def make_all_inputs(input_fields):
"""
Define the input data layers for the transformer model.
"""
inputs = []
for input_field in input_fields:
input_var = layers.data(
name=input_field,
shape=input_descs[input_field][0],
dtype=input_descs[input_field][1],
lod_level=input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0,
append_batch_size=False)
inputs.append(input_var)
return inputs
def make_all_py_reader_inputs(input_fields, is_test=False):
"""
    Define the py_reader and its output data layers for the transformer model.
"""
reader = layers.py_reader(
capacity=20,
name="test_reader" if is_test else "train_reader",
shapes=[input_descs[input_field][0] for input_field in input_fields],
dtypes=[input_descs[input_field][1] for input_field in input_fields],
lod_levels=[
input_descs[input_field][2]
if len(input_descs[input_field]) == 3 else 0
for input_field in input_fields
], use_double_buffer=True)
return layers.read_file(reader), reader
def relative_transformer(src_vocab_size,
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
label_smooth_eps,
use_py_reader=False,
is_test=False,
params_type="normal",
all_data_inputs = None):
"""
    Build the training network of the relative-position Transformer agent.
"""
if embedding_sharing:
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields + dense_bias_input_fields
if use_py_reader:
all_inputs = all_data_inputs
else:
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(decoder_data_input_fields[:-1])
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
real_label = all_inputs[enc_inputs_len + dec_inputs_len]
weights = all_inputs[enc_inputs_len + dec_inputs_len + 1]
reverse_label = all_inputs[enc_inputs_len + dec_inputs_len + 2]
enc_output = wrap_encoder(
src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
enc_inputs,
params_type=params_type)
predict = wrap_decoder(
trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
dec_inputs,
enc_output, is_train = True if not is_test else False,
params_type=params_type)
    # Padding indices do not contribute to the total loss. The weights are used
    # to cancel out padding positions when calculating the loss.
if label_smooth_eps:
label = layers.one_hot(input=real_label, depth=trg_vocab_size)
label = label * (1 - label_smooth_eps) + (1 - label) * (
label_smooth_eps / (trg_vocab_size - 1))
label.stop_gradient = True
else:
label = real_label
cost = layers.softmax_with_cross_entropy(
logits=predict,
label=label,
soft_label=True if label_smooth_eps else False)
weighted_cost = cost * weights
sum_cost = layers.reduce_sum(weighted_cost)
sum_cost.persistable = True
token_num = layers.reduce_sum(weights)
token_num.persistable = True
token_num.stop_gradient = True
avg_cost = sum_cost / token_num
sen_count = layers.shape(dec_inputs[0])[0]
batch_predict = layers.reshape(predict, shape = [sen_count, -1, ModelHyperParams.trg_vocab_size])
batch_label = layers.reshape(real_label, shape=[sen_count, -1])
batch_weights = layers.reshape(weights, shape=[sen_count, -1, 1])
return sum_cost, avg_cost, token_num, batch_predict, cost, sum_cost, batch_label, batch_weights
def wrap_encoder(src_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
enc_inputs=None,
params_type="normal"):
"""
The wrapper assembles together all needed layers for the encoder.
"""
if enc_inputs is None:
# This is used to implement independent encoder program in inference.
src_word, src_pos, src_slf_attn_bias = make_all_inputs(
encoder_data_input_fields)
else:
src_word, src_pos, src_slf_attn_bias = enc_inputs
enc_input = prepare_encoder(
src_word,
src_pos,
src_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[0],
params_type=params_type)
enc_output = encoder(
enc_input,
src_slf_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
params_type=params_type)
return enc_output
def wrap_decoder(trg_vocab_size,
max_length,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
dec_inputs=None,
enc_output=None,
caches=None, is_train=True, params_type="normal"):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if dec_inputs is None:
# This is used to implement independent decoder program in inference.
trg_word, reverse_trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = \
make_all_inputs(decoder_data_input_fields)
else:
trg_word, reverse_trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias = dec_inputs
dec_input = prepare_decoder(
trg_word,
trg_pos,
trg_vocab_size,
d_model,
max_length,
prepostprocess_dropout,
word_emb_param_name=word_emb_param_names[0]
if embedding_sharing else word_emb_param_names[1],
training=is_train,
params_type=params_type)
dec_output = decoder(
dec_input,
enc_output,
trg_slf_attn_bias,
trg_src_attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
caches=caches,
params_type=params_type)
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output = layers.reshape(
dec_output, shape=[-1, dec_output.shape[-1]], inplace=True)
    assert params_type in ("fixed", "normal", "new")
pre_name = "relative_positionrelative_position"
if params_type == "fixed":
pre_name = "fixed_relative_positionfixed_relative_position"
elif params_type == "new":
pre_name = "new_relative_positionnew_relative_position"
if weight_sharing and embedding_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
pre_name + word_emb_param_names[0]),
transpose_y=True)
elif weight_sharing:
predict = layers.matmul(
x=dec_output,
y=fluid.default_main_program().global_block().var(
pre_name + word_emb_param_names[1]),
transpose_y=True)
else:
predict = layers.fc(input=dec_output,
size=trg_vocab_size,
bias_attr=False)
if dec_inputs is None:
# Return probs for independent decoder program.
predict = layers.softmax(predict)
return predict
def get_enc_bias(source_inputs):
"""
    Build the encoder attention bias that masks padded source positions.
"""
source_inputs = layers.cast(source_inputs, 'float32')
emb_sum = layers.reduce_sum(layers.abs(source_inputs), dim=-1)
zero = layers.fill_constant([1], 'float32', value=0)
bias = layers.cast(layers.equal(emb_sum, zero), 'float32') * -1e9
return layers.unsqueeze(layers.unsqueeze(bias, axes=[1]), axes=[1])
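# Illustrative numpy sketch only: mirrors the computation in get_enc_bias(), i.e.
# a large negative bias at source positions whose input ids are all zero (treated
# as padding); the toy batch below is hypothetical.
def _example_enc_bias():
    """Compute the padding bias for one toy source sentence of length 4."""
    source = np.array([[[3], [9], [0], [0]]], dtype="float32")   # [batch, seq_len, 1]
    emb_sum = np.abs(source).sum(axis=-1)                        # (1, 4), zero where padded
    bias = (emb_sum == 0).astype("float32") * -1e9               # -1e9 on padding, 0 elsewhere
    return bias[:, np.newaxis, np.newaxis, :]                    # (1, 1, 1, 4)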
def relative_fast_decode(
src_vocab_size,
trg_vocab_size,
max_in_len,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
preprocess_cmd,
postprocess_cmd,
weight_sharing,
embedding_sharing,
beam_size,
batch_size,
max_out_len,
decode_alpha,
eos_idx,
params_type="normal"):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
assert params_type == "normal" or params_type == "new" or params_type == "fixed"
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
all_inputs = make_all_inputs(data_input_names)
enc_inputs_len = len(encoder_data_input_fields)
dec_inputs_len = len(fast_decoder_data_input_fields)
enc_inputs = all_inputs[0:enc_inputs_len]
dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]
enc_output = wrap_encoder(src_vocab_size, max_in_len, n_layer, n_head,
d_key, d_value, d_model, d_inner_hid,
prepostprocess_dropout, attention_dropout,
relu_dropout, preprocess_cmd, postprocess_cmd,
weight_sharing, embedding_sharing, enc_inputs, params_type=params_type)
enc_bias = get_enc_bias(enc_inputs[0])
source_length, = dec_inputs
def beam_search(enc_output, enc_bias, source_length):
"""
beam_search
"""
max_len = layers.fill_constant(
shape=[1], dtype='int64', value=max_out_len)
step_idx = layers.fill_constant(
shape=[1], dtype='int64', value=0)
cond = layers.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
caches_batch_size = batch_size * beam_size
init_score = np.zeros([1, beam_size]).astype('float32')
init_score[:, 1:] = -INF
initial_log_probs = layers.assign(init_score)
alive_log_probs = layers.expand(initial_log_probs, [batch_size, 1])
# alive seq [batch_size, beam_size, 1]
initial_ids = layers.zeros([batch_size, 1, 1], 'float32')
alive_seq = layers.expand(initial_ids, [1, beam_size, 1])
alive_seq = layers.cast(alive_seq, 'int64')
enc_output = layers.unsqueeze(enc_output, axes=[1])
enc_output = layers.expand(enc_output, [1, beam_size, 1, 1])
enc_output = layers.reshape(enc_output, [caches_batch_size, -1, d_model])
tgt_src_attn_bias = layers.unsqueeze(enc_bias, axes=[1])
tgt_src_attn_bias = layers.expand(tgt_src_attn_bias, [1, beam_size, n_head, 1, 1])
enc_bias_shape = layers.shape(tgt_src_attn_bias)
tgt_src_attn_bias = layers.reshape(tgt_src_attn_bias, [-1, enc_bias_shape[2],
enc_bias_shape[3], enc_bias_shape[4]])
beam_search = BeamSearch(beam_size, batch_size, decode_alpha, trg_vocab_size, d_model)
caches = [{
"k": layers.fill_constant(
shape=[caches_batch_size, 0, d_model],
dtype=enc_output.dtype,
value=0),
"v": layers.fill_constant(
shape=[caches_batch_size, 0, d_model],
dtype=enc_output.dtype,
value=0)
} for i in range(n_layer)]
finished_seq = layers.zeros_like(alive_seq)
finished_scores = layers.fill_constant([batch_size, beam_size],
dtype='float32', value=-INF)
finished_flags = layers.fill_constant([batch_size, beam_size],
dtype='float32', value=0)
with while_op.block():
pos = layers.fill_constant([caches_batch_size, 1, 1], dtype='int64', value=1)
pos = layers.elementwise_mul(pos, step_idx, axis=0)
alive_seq_1 = layers.reshape(alive_seq, [caches_batch_size, -1])
alive_seq_2 = alive_seq_1[:, -1:]
alive_seq_2 = layers.unsqueeze(alive_seq_2, axes=[1])
logits = wrap_decoder(
trg_vocab_size, max_in_len, n_layer, n_head, d_key,
d_value, d_model, d_inner_hid, prepostprocess_dropout,
attention_dropout, relu_dropout, preprocess_cmd,
postprocess_cmd, weight_sharing, embedding_sharing,
dec_inputs=(alive_seq_2, alive_seq_2, pos, None, tgt_src_attn_bias),
enc_output=enc_output, caches=caches, is_train=False, params_type=params_type)
alive_seq_2, alive_log_probs_2, finished_seq_2, finished_scores_2, finished_flags_2, caches_2 = \
beam_search.inner_func(step_idx, logits, alive_seq_1, alive_log_probs, finished_seq,
finished_scores, finished_flags, caches, enc_output,
tgt_src_attn_bias)
layers.increment(x=step_idx, value=1.0, in_place=True)
finish_cond = beam_search.is_finished(step_idx, source_length, alive_log_probs_2,
finished_scores_2, finished_flags_2)
layers.assign(alive_seq_2, alive_seq)
layers.assign(alive_log_probs_2, alive_log_probs)
layers.assign(finished_seq_2, finished_seq)
layers.assign(finished_scores_2, finished_scores)
layers.assign(finished_flags_2, finished_flags)
for i in xrange(len(caches_2)):
layers.assign(caches_2[i]["k"], caches[i]["k"])
layers.assign(caches_2[i]["v"], caches[i]["v"])
layers.logical_and(x=cond, y=finish_cond, out=cond)
finished_flags = layers.reduce_sum(finished_flags, dim=1, keep_dim=True) / beam_size
finished_flags = layers.cast(finished_flags, 'bool')
mask = layers.cast(layers.reduce_any(input=finished_flags, dim=1, keep_dim=True), 'float32')
mask = layers.expand(mask, [1, beam_size])
mask2 = 1.0 - mask
finished_seq = layers.cast(finished_seq, 'float32')
alive_seq = layers.cast(alive_seq, 'float32')
finished_seq = layers.elementwise_mul(finished_seq, mask, axis=0) + \
layers.elementwise_mul(alive_seq, mask2, axis = 0)
finished_seq = layers.cast(finished_seq, 'int32')
finished_scores = layers.elementwise_mul(finished_scores, mask, axis=0) + \
layers.elementwise_mul(alive_log_probs, mask2)
finished_seq.persistable = True
finished_scores.persistable = True
return finished_seq, finished_scores
finished_ids, finished_scores = beam_search(enc_output, enc_bias, source_length)
return finished_ids, finished_scores
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import copy
import logging
import multiprocessing
import os
import six
import sys
import time
import random
import math
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
from paddle.fluid.framework import default_main_program
import reader
from reader import *
from config import *
from forward_model import forward_transformer, forward_position_encoding_init, forward_fast_decode, make_all_py_reader_inputs
from dense_model import dense_transformer, dense_fast_decode
from relative_model import relative_transformer, relative_fast_decode
def parse_args():
"""
parse_args
"""
parser = argparse.ArgumentParser("Training for Transformer.")
parser.add_argument(
"--train_file_pattern",
type=str,
required=True,
help="The pattern to match training data files.")
parser.add_argument(
"--val_file_pattern",
type=str,
help="The pattern to match validation data files.")
parser.add_argument(
"--ckpt_path",
type=str,
help="The pattern to match training data files.")
parser.add_argument(
"--infer_batch_size",
type=int,
help="Infer batch_size")
parser.add_argument(
"--decode_alpha",
type=float,
help="decode_alpha")
parser.add_argument(
"--beam_size",
type=int,
help="Infer beam_size")
parser.add_argument(
"--use_token_batch",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to "
"produce batch data according to token number.")
parser.add_argument(
"--batch_size",
type=int,
default=4096,
help="The number of sequences contained in a mini-batch, or the maximum "
"number of tokens (include paddings) contained in a mini-batch. Note "
"that this represents the number on single device and the actual batch "
"size for multi-devices will multiply the device number.")
parser.add_argument(
"--pool_size",
type=int,
default=200000,
help="The buffer size to pool data.")
parser.add_argument(
"--num_threads",
type=int,
default=2,
help="The number of threads which executor use.")
parser.add_argument(
"--use_fp16",
type=ast.literal_eval,
default=True,
help="Use fp16 or not"
)
parser.add_argument(
"--nccl_comm_num",
type=int,
default=1,
help="The number of threads which executor use.")
parser.add_argument(
"--sort_type",
default="pool",
choices=("global", "pool", "none"),
help="The grain to sort by length: global for all instances; pool for "
"instances in pool; none for no sort.")
parser.add_argument(
"--use_hierarchical_allreduce",
default=False,
type=ast.literal_eval,
help="Use hierarchical allreduce or not.")
parser.add_argument(
"--hierarchical_allreduce_inter_nranks",
default=8,
type=int,
help="interranks.")
parser.add_argument(
"--shuffle",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to shuffle instances in each pass.")
parser.add_argument(
"--shuffle_batch",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to shuffle the data batches.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
parser.add_argument(
'--local',
type=ast.literal_eval,
default=False,
help='Whether to run as local mode.')
parser.add_argument(
'--device',
type=str,
default='GPU',
choices=['CPU', 'GPU'],
help="The device type.")
parser.add_argument(
'--sync', type=ast.literal_eval, default=True, help="sync mode.")
parser.add_argument(
"--enable_ce",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to run the task "
"for continuous evaluation.")
parser.add_argument(
"--use_mem_opt",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use memory optimization.")
parser.add_argument(
"--use_py_reader",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use py_reader.")
parser.add_argument(
"--fetch_steps",
type=int,
default=100,
help="The frequency to fetch and print output.")
parser.add_argument(
"--use_delay_load",
type=ast.literal_eval,
default=True,
        help="The flag indicating whether to delay loading vocabulary files; "
        "if False, the source and target dictionaries are loaded into memory "
        "at startup.")
parser.add_argument(
"--src_vocab_size",
type=str,
required=True,
help="Size of src Vocab.")
parser.add_argument(
"--tgt_vocab_size",
type=str,
required=True,
help="Size of tgt Vocab.")
parser.add_argument(
"--restore_step",
type=int,
default=0,
help="The step number of checkpoint to restore training.")
parser.add_argument(
"--fuse",
type=int,
default=0,
help="Use fusion or not.")
args = parser.parse_args()
src_voc_size = args.src_vocab_size
trg_voc_size = args.tgt_vocab_size
if args.use_delay_load:
dict_args = [
"src_vocab_size", src_voc_size,
"trg_vocab_size", trg_voc_size,
"bos_idx", str(0),
"eos_idx", str(1),
"unk_idx", str(int(src_voc_size) - 1)
]
else:
src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
dict_args = [
"src_vocab_size", str(len(src_dict)), "trg_vocab_size",
str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
"eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
str(src_dict[args.special_token[2]])
]
merge_cfg_from_list(args.opts + dict_args,
[TrainTaskConfig, ModelHyperParams])
return args
def prepare_batch_input(insts, data_input_names, src_pad_idx, trg_pad_idx,
n_head, d_model):
"""
Put all padded data needed by training into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len, 1)
trg_word = trg_word[:, 1:, :]
trg_pos = trg_pos.reshape(-1, trg_max_len, 1)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
# reverse_target
reverse_trg_word, _, _, _ = pad_batch_data(
[inst[3] for inst in insts], trg_pad_idx, n_head, is_target=True)
reverse_trg_word = reverse_trg_word.reshape(-1, trg_max_len, 1)
reverse_trg_word = reverse_trg_word[:, 1:, :]
reverse_lbl_word, _, _ = pad_batch_data(
[inst[4] for inst in insts],
trg_pad_idx,
n_head,
is_target=False,
is_label=True,
return_attn_bias=False,
return_max_len=False,
return_num_token=True)
eos_position = []
meet_eos = False
for word_id in reverse_lbl_word:
if word_id[0] == 1 and not meet_eos:
meet_eos = True
eos_position.append([1])
elif word_id[0] == 1 and meet_eos:
eos_position.append([0])
else:
meet_eos = False
eos_position.append([0])
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, trg_word, reverse_trg_word, trg_pos,
trg_slf_attn_bias, trg_src_attn_bias, lbl_word, lbl_weight, reverse_lbl_word, np.asarray(eos_position, dtype = "int64")
]))
return data_input_dict, np.asarray([num_token], dtype="float32")
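# Illustrative note: the arrays above are zipped with data_input_names in this
# order: source ids, source positions, source self-attention bias, forward and
# reversed decoder inputs, target positions, target self-attention bias,
# target-to-source attention bias, labels, label weights, reversed labels, and
# the <eos> position markers. The second return value, num_token, is the
# unpadded token count used to normalize the loss.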
def prepare_feed_dict_list(data_generator, count, num_tokens=None, num_insts=None):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list = []
eos_idx = ModelHyperParams.eos_idx
n_head = ModelHyperParams.n_head
d_model = ModelHyperParams.d_model
max_length = ModelHyperParams.max_length
dense_n_head = DenseModelHyperParams.n_head
dense_d_model = DenseModelHyperParams.d_model
if data_generator is not None: # use_py_reader == False
dense_data_input_names = dense_encoder_data_input_fields + \
dense_decoder_data_input_fields[:-1] + dense_label_data_input_fields
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict, num_token = prepare_batch_input(
data_buffer, data_input_names, eos_idx,
eos_idx, n_head,
d_model)
dense_data_input_dict, _ = prepare_batch_input(
data_buffer, dense_data_input_names, eos_idx,
eos_idx, dense_n_head,
dense_d_model)
data_input_dict.update(dense_data_input_dict) # merge dict
feed_dict_list.append(data_input_dict)
if isinstance(num_tokens, list): num_tokens.append(num_token)
if isinstance(num_insts, list): num_insts.append(len(data_buffer))
return feed_dict_list if len(feed_dict_list) == count else None
def py_reader_provider_wrapper(data_reader):
"""
Data provider needed by fluid.layers.py_reader.
"""
def py_reader_provider():
"""
py_reader_provider
"""
eos_idx = ModelHyperParams.eos_idx
n_head = ModelHyperParams.n_head
d_model = ModelHyperParams.d_model
max_length = ModelHyperParams.max_length
dense_n_head = DenseModelHyperParams.n_head
dense_d_model = DenseModelHyperParams.d_model
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields
dense_data_input_names = dense_encoder_data_input_fields + \
dense_decoder_data_input_fields[:-1] + label_data_input_fields
new_data_input_names = data_input_names + dense_bias_input_fields
for batch_id, data in enumerate(data_reader()):
data_input_dict, num_token = prepare_batch_input(
data, data_input_names, eos_idx,
eos_idx, n_head,
d_model)
dense_data_input_dict, _ = prepare_batch_input(
data, dense_data_input_names, eos_idx,
eos_idx, dense_n_head,
dense_d_model)
data_input_dict["dense_src_slf_attn_bias"] = dense_data_input_dict["dense_src_slf_attn_bias"]
data_input_dict["dense_trg_slf_attn_bias"] = dense_data_input_dict["dense_trg_slf_attn_bias"]
data_input_dict["dense_trg_src_attn_bias"] = dense_data_input_dict["dense_trg_src_attn_bias"]
total_dict = dict(data_input_dict.items())
yield [total_dict[item] for item in new_data_input_names]
return py_reader_provider
from infer import prepare_feed_dict_list as infer_prepare_feed_dict_list
from infer import prepare_dense_feed_dict_list as infer_prepare_dense_feed_dict_list
def test_context(exe, train_exe, dev_count, agent_name, args):
# Context to do validation.
test_prog = fluid.Program()
startup_prog = fluid.Program()
if args.enable_ce:
test_prog.random_seed = 1000
startup_prog.random_seed = 1000
with fluid.program_guard(test_prog, startup_prog):
if agent_name == "new_forward":
with fluid.unique_name.guard("new_forward"):
out_ids1, out_scores1 = forward_fast_decode(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
args.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
elif agent_name == "new_relative_position":
with fluid.unique_name.guard("new_relative_position"):
out_ids2, out_scores2 = relative_fast_decode(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
args.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
elif agent_name == "new_dense":
DenseModelHyperParams.src_vocab_size = ModelHyperParams.src_vocab_size
DenseModelHyperParams.trg_vocab_size = ModelHyperParams.trg_vocab_size
DenseModelHyperParams.weight_sharing = ModelHyperParams.weight_sharing
DenseModelHyperParams.embedding_sharing = ModelHyperParams.embedding_sharing
with fluid.unique_name.guard("new_dense"):
out_ids3, out_scores3 = dense_fast_decode(
DenseModelHyperParams.src_vocab_size,
DenseModelHyperParams.trg_vocab_size,
DenseModelHyperParams.max_length + 50,
DenseModelHyperParams.n_layer,
DenseModelHyperParams.enc_n_layer,
DenseModelHyperParams.n_head,
DenseModelHyperParams.d_key,
DenseModelHyperParams.d_value,
DenseModelHyperParams.d_model,
DenseModelHyperParams.d_inner_hid,
DenseModelHyperParams.prepostprocess_dropout,
DenseModelHyperParams.attention_dropout,
DenseModelHyperParams.relu_dropout,
DenseModelHyperParams.preprocess_cmd,
DenseModelHyperParams.postprocess_cmd,
DenseModelHyperParams.weight_sharing,
DenseModelHyperParams.embedding_sharing,
args.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
test_prog = test_prog.clone(for_test=True)
dev_count = 1
file_pattern = "%s" % (args.val_file_pattern)
lines_cnt = len(open(file_pattern, 'r').readlines())
data_reader = line_reader(file_pattern, args.infer_batch_size, dev_count,
token_delimiter=args.token_delimiter,
max_len=ModelHyperParams.max_length,
parse_line=parse_src_line)
test_data = prepare_data_generator(args, is_test=True, count=dev_count, pyreader=None,
batch_size=args.infer_batch_size, data_reader=data_reader)
def test(step_id, exe=exe):
f = ""
if agent_name == "new_relative_position":
f = open("./output/new_relative_position_iter_%d.trans" % (step_id), 'w')
elif agent_name == "new_forward":
f = open("./output/new_forward_iter_%d.trans" % (step_id), 'w')
elif agent_name == "new_dense":
f = open("./output/new_dense_iter_%d.trans" % (step_id), 'w')
data_generator = test_data()
trans_list = []
while True:
try:
feed_dict_list = infer_prepare_feed_dict_list(data_generator, 1) if agent_name != "new_dense" else infer_prepare_dense_feed_dict_list(data_generator, 1)
if agent_name == "new_forward":
seq_ids, seq_scores = exe.run(
fetch_list=[out_ids1.name, out_scores1.name],
feed=feed_dict_list,
program=test_prog,
return_numpy=True)
elif agent_name == "new_relative_position":
seq_ids, seq_scores = exe.run(
fetch_list=[out_ids2.name, out_scores2.name],
feed=feed_dict_list,
program=test_prog,
return_numpy=True)
elif agent_name == "new_dense":
seq_ids, seq_scores = exe.run(
fetch_list=[out_ids3.name, out_scores3.name],
feed=feed_dict_list,
program=test_prog,
return_numpy=True)
seq_ids = seq_ids.tolist()
for index in xrange(args.infer_batch_size):
seq = seq_ids[index][0]
if 1 not in seq:
res = seq[1:-1]
else:
res = seq[1: seq.index(1)]
res = map(str, res)
trans_list.append(" ".join(res))
except (StopIteration, fluid.core.EOFException):
# The current pass is over.
break
trans_list = trans_list[:lines_cnt]
for trans in trans_list:
f.write("%s\n" % trans)
f.flush()
f.close()
return test
def get_tensor_by_prefix(pre_name, param_name_list):
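    """Collect the parameter tensors whose names contain pre_name, plus the
    relative-position key/value tables for the corresponding agent."""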
tensors_list = []
for param_name in param_name_list:
if pre_name in param_name:
tensors_list.append(fluid.global_scope().find_var(param_name).get_tensor())
if pre_name == "fixed_relative_positionfixed_relative_position":
tensors_list.append(fluid.global_scope().find_var("fixed_relative_positions_keys").get_tensor())
tensors_list.append(fluid.global_scope().find_var("fixed_relative_positions_values").get_tensor())
elif pre_name == "new_relative_positionnew_relative_position":
tensors_list.append(fluid.global_scope().find_var("new_relative_positions_keys").get_tensor())
tensors_list.append(fluid.global_scope().find_var("new_relative_positions_values").get_tensor())
return tensors_list
def train_loop(exe,
train_prog,
startup_prog,
args,
dev_count,
avg_cost,
teacher_cost,
single_model_sum_cost,
single_model_avg_cost,
token_num,
pyreader, place,
nccl2_num_trainers=1,
nccl2_trainer_id=0,
scaled_cost=None,
loss_scaling=None
):
"""
train_loop
"""
# Initialize the parameters.
if TrainTaskConfig.ckpt_path:
exe.run(startup_prog)
logging.info("load checkpoint from {}".format(TrainTaskConfig.ckpt_path))
fluid.io.load_params(exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
else:
logging.info("init fluid.framework.default_startup_program")
exe.run(startup_prog)
param_list = train_prog.block(0).all_parameters()
param_name_list = [p.name for p in param_list ]
logging.info("begin reader")
batch_scheme = batching_scheme(args.batch_size, 256, shard_multiplier=nccl2_num_trainers)
tf_data = bucket_by_sequence_length(
repeat(
interleave_reader(
args.train_file_pattern,
cycle_length=8,
token_delimiter=args.token_delimiter,
max_len=ModelHyperParams.max_length,
parse_line=parse_line,
), -1),
lambda x:max(len(x[0]), len(x[1])),
batch_scheme["boundaries"],
batch_scheme["batch_sizes"],
nccl2_num_trainers,
nccl2_trainer_id
)
args.use_token_batch = False
train_data = prepare_data_generator(
args, is_test=False, count=dev_count, pyreader=pyreader, data_reader=tf_data, \
py_reader_provider_wrapper=py_reader_provider_wrapper)
# For faster executor
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = 20
build_strategy = fluid.BuildStrategy()
build_strategy.enable_inplace = True
build_strategy.fuse_all_optimizer_ops = False
build_strategy.fuse_all_reduce_ops = False
build_strategy.enable_backward_optimizer_op_deps = True
if args.fuse:
build_strategy.fuse_all_reduce_ops = True
trainer_id = nccl2_trainer_id
train_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu,
loss_name=avg_cost.name,
main_program=train_prog,
build_strategy=build_strategy,
exec_strategy=exec_strategy,
num_trainers=nccl2_num_trainers,
trainer_id=nccl2_trainer_id)
if args.val_file_pattern is not None:
new_forward_test = test_context(exe, train_exe, dev_count, "new_forward", args)
new_dense_test = test_context(exe, train_exe, dev_count, "new_dense", args)
new_relative_position_test = test_context(exe, train_exe, dev_count, "new_relative_position", args)
# the best cross-entropy value with label smoothing
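# With smoothing eps and target vocabulary size V, the smallest attainable cross entropy
# is -(1 - eps) * log(1 - eps) - eps * log(eps / (V - 1)); subtracting it gives the
# "normalized loss" reported in the training log below.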
loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
(1. - TrainTaskConfig.label_smooth_eps
)) + TrainTaskConfig.label_smooth_eps *
np.log(TrainTaskConfig.label_smooth_eps / (
ModelHyperParams.trg_vocab_size - 1) + 1e-20))
# set recovery step
step_idx = args.restore_step if args.restore_step else 0
if step_idx != 0:
var = fluid.global_scope().find_var("@LR_DECAY_COUNTER@").get_tensor()
recovery_step = np.array([step_idx]).astype("int64")
var.set(recovery_step, fluid.CPUPlace())
step = np.array(var)[0]
# set pos encoding
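# Each agent owns its own positional-encoding parameters under a doubled name prefix;
# fill them all with the sinusoidal table, using the dense agents' dimensions for the
# dense copies handled in the second loop.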
model_prefix = ["fixed_forward", "fixed_relative_position",
"new_forward", "new_relative_position"]
for pos_enc_param_name in pos_enc_param_names:
for prefix in model_prefix:
pos_name = prefix * 2 + pos_enc_param_name
pos_enc_param = fluid.global_scope().find_var(
pos_name).get_tensor()
pos_enc_param.set(
forward_position_encoding_init(
ModelHyperParams.max_length + 50,
ModelHyperParams.d_model), place)
model_prefix_2 = ["fixed_dense", "new_dense"]
for pos_enc_param_name in pos_enc_param_names:
for prefix in model_prefix_2:
pos_name = prefix * 2 + pos_enc_param_name
pos_enc_param = fluid.global_scope().find_var(
pos_name).get_tensor()
pos_enc_param.set(
forward_position_encoding_init(
DenseModelHyperParams.max_length + 50,
DenseModelHyperParams.d_model), place)
logging.info("begin train")
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
avg_batch_time = time.time()
pyreader.start()
data_generator = None
batch_id = 0
while True:
try:
num_tokens = []
num_insts = []
feed_dict_list = prepare_feed_dict_list(data_generator,
dev_count, num_tokens, num_insts)
num_token = np.sum(num_tokens).reshape([-1])
num_inst = np.sum(num_insts).reshape([-1])
outs = train_exe.run(
fetch_list=[avg_cost.name, token_num.name, teacher_cost.name]
if (step_idx == 0 or step_idx % args.fetch_steps == (args.fetch_steps - 1)) else [],
feed=feed_dict_list)
if (step_idx == 0 or step_idx % args.fetch_steps == (args.fetch_steps - 1)):
single_model_total_avg_cost, token_num_val = np.array(outs[0]), np.array(outs[1])
teacher = np.array(outs[2])
if step_idx == 0:
logging.info(
("step_idx: %d, epoch: %d, batch: %d, teacher loss: %f, avg loss: %f, "
"normalized loss: %f, ppl: %f" + (", batch size: %d" if num_inst else "")) %
((step_idx, pass_id, batch_id, teacher, single_model_total_avg_cost,
single_model_total_avg_cost - loss_normalizer,
np.exp([min(single_model_total_avg_cost, 100)])) + ((num_inst,) if num_inst else ())))
else:
logging.info(
("step_idx: %d, epoch: %d, batch: %d, teacher loss: %f, avg loss: %f, "
"normalized loss: %f, ppl: %f, speed: %.2f step/s" + \
(", batch size: %d" if num_inst else "")) %
((step_idx, pass_id, batch_id, teacher, single_model_total_avg_cost,
single_model_total_avg_cost - loss_normalizer,
np.exp([min(single_model_total_avg_cost, 100)]),
args.fetch_steps / (time.time() - avg_batch_time)) + ((num_inst,) if num_inst else ())))
avg_batch_time = time.time()
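# Periodically refresh the teacher: copy each agent's trainable "new_" parameters into
# its "fixed_" copy, whose logits are the ones averaged into the teacher distribution.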
if step_idx % TrainTaskConfig.fixed_freq == (TrainTaskConfig.fixed_freq - 1):
logging.info("copy parameters to fixed parameters when step_idx is {}".format(step_idx))
fixed_forward_tensors = get_tensor_by_prefix("fixed_forwardfixed_forward", param_name_list)
new_forward_tensors = get_tensor_by_prefix("new_forwardnew_forward", param_name_list)
fixed_dense_tensors = get_tensor_by_prefix("fixed_densefixed_dense", param_name_list)
new_dense_tensors = get_tensor_by_prefix("new_densenew_dense", param_name_list)
fixed_relative_tensors = get_tensor_by_prefix("fixed_relative_positionfixed_relative_position", param_name_list)
new_relative_tensors = get_tensor_by_prefix("new_relative_positionnew_relative_position", param_name_list)
for (fixed_tensor, new_tensor) in zip(fixed_forward_tensors, new_forward_tensors):
fixed_tensor.set(np.array(new_tensor), place)
for (fixed_tensor, new_tensor) in zip(fixed_relative_tensors, new_relative_tensors):
fixed_tensor.set(np.array(new_tensor), place)
for (fixed_tensor, new_tensor) in zip(fixed_dense_tensors, new_dense_tensors):
fixed_tensor.set(np.array(new_tensor), place)
if step_idx % TrainTaskConfig.save_freq == (TrainTaskConfig.save_freq - 1):
if trainer_id == 0:
fluid.io.save_params(
exe,
os.path.join(TrainTaskConfig.model_dir,
"iter_" + str(step_idx) + ".infer.model"),train_prog)
if args.val_file_pattern is not None:
train_exe.drop_local_exe_scopes()
new_dense_test(step_idx)
new_forward_test(step_idx)
new_relative_position_test(step_idx)
batch_id += 1
step_idx += 1
except (StopIteration, fluid.core.EOFException):
break
def train(args):
"""
train
"""
is_local = os.getenv("PADDLE_IS_LOCAL", "1")
if is_local == '0':
args.local = False
print(args)
if args.device == 'CPU':
TrainTaskConfig.use_gpu = False
training_role = os.getenv("TRAINING_ROLE", "TRAINER")
gpus = os.getenv("FLAGS_selected_gpus", "0").split(",")
gpu_id = int(gpus[0])
if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu):
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else:
place = fluid.CUDAPlace(gpu_id)
dev_count = len(gpus)
exe = fluid.Executor(place)
train_prog = fluid.Program()
startup_prog = fluid.Program()
if args.enable_ce:
train_prog.random_seed = 1000
startup_prog.random_seed = 1000
with fluid.program_guard(train_prog, startup_prog):
logits_list = []
data_input_names = encoder_data_input_fields + \
decoder_data_input_fields[:-1] + label_data_input_fields + dense_bias_input_fields
all_data_inputs, pyreader = make_all_py_reader_inputs(data_input_names)
with fluid.unique_name.guard("new_forward"):
new_forward_sum_cost, new_forward_avg_cost, new_forward_token_num, new_forward_logits, new_forward_xent, new_forward_loss, new_forward_label, new_forward_non_zeros = forward_transformer(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=True,
is_test=False,
params_type="new",
all_data_inputs=all_data_inputs)
with fluid.unique_name.guard("new_relative_position"):
new_relative_position_sum_cost, new_relative_position_avg_cost, new_relative_position_token_num, new_relative_position_logits, new_relative_position_xent, new_relative_position_loss, new_relative_position_label, new_relative_position_non_zeros = relative_transformer(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader,
is_test=False,
params_type="new",
all_data_inputs=all_data_inputs)
DenseModelHyperParams.src_vocab_size = ModelHyperParams.src_vocab_size
DenseModelHyperParams.trg_vocab_size = ModelHyperParams.trg_vocab_size
DenseModelHyperParams.weight_sharing = ModelHyperParams.weight_sharing
DenseModelHyperParams.embedding_sharing = ModelHyperParams.embedding_sharing
with fluid.unique_name.guard("new_dense"):
new_dense_sum_cost, new_dense_avg_cost, new_dense_token_num, new_dense_logits, new_dense_xent, new_dense_loss, new_dense_label, _ = dense_transformer(
DenseModelHyperParams.src_vocab_size,
DenseModelHyperParams.trg_vocab_size,
DenseModelHyperParams.max_length + 50,
DenseModelHyperParams.n_layer,
DenseModelHyperParams.enc_n_layer,
DenseModelHyperParams.n_head,
DenseModelHyperParams.d_key,
DenseModelHyperParams.d_value,
DenseModelHyperParams.d_model,
DenseModelHyperParams.d_inner_hid,
DenseModelHyperParams.prepostprocess_dropout,
DenseModelHyperParams.attention_dropout,
DenseModelHyperParams.relu_dropout,
DenseModelHyperParams.preprocess_cmd,
DenseModelHyperParams.postprocess_cmd,
DenseModelHyperParams.weight_sharing,
DenseModelHyperParams.embedding_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader,
is_test=False,
params_type="new",
all_data_inputs=all_data_inputs)
with fluid.unique_name.guard("fixed_forward"):
fixed_forward_sum_cost, fixed_forward_avg_cost, fixed_forward_token_num, fixed_forward_logits, fixed_forward_xent, fixed_forward_loss, fixed_forward_label, fixed_forward_non_zeros = forward_transformer(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader,
is_test=False,
params_type="fixed",
all_data_inputs=all_data_inputs)
logits_list.append(fixed_forward_logits)
DenseModelHyperParams.src_vocab_size = ModelHyperParams.src_vocab_size
DenseModelHyperParams.trg_vocab_size = ModelHyperParams.trg_vocab_size
DenseModelHyperParams.weight_sharing = ModelHyperParams.weight_sharing
DenseModelHyperParams.embedding_sharing = ModelHyperParams.embedding_sharing
with fluid.unique_name.guard("fixed_dense"):
fixed_dense_sum_cost, fixed_dense_avg_cost, fixed_dense_token_num, fixed_dense_logits, fixed_dense_xent, fixed_dense_loss, fixed_dense_label, _ = dense_transformer(
DenseModelHyperParams.src_vocab_size,
DenseModelHyperParams.trg_vocab_size,
DenseModelHyperParams.max_length + 50,
DenseModelHyperParams.n_layer,
DenseModelHyperParams.enc_n_layer,
DenseModelHyperParams.n_head,
DenseModelHyperParams.d_key,
DenseModelHyperParams.d_value,
DenseModelHyperParams.d_model,
DenseModelHyperParams.d_inner_hid,
DenseModelHyperParams.prepostprocess_dropout,
DenseModelHyperParams.attention_dropout,
DenseModelHyperParams.relu_dropout,
DenseModelHyperParams.preprocess_cmd,
DenseModelHyperParams.postprocess_cmd,
DenseModelHyperParams.weight_sharing,
DenseModelHyperParams.embedding_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader,
is_test=False,
params_type="fixed",
all_data_inputs=all_data_inputs)
logits_list.append(fixed_dense_logits)
with fluid.unique_name.guard("fixed_relative_position"):
fixed_relative_sum_cost, fixed_relative_avg_cost, fixed_relative_token_num, fixed_relative_logits, fixed_relative_xent, fixed_relative_loss, fixed_relative_label, _ = relative_transformer(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
TrainTaskConfig.label_smooth_eps,
use_py_reader=args.use_py_reader,
is_test=False,
params_type="fixed",
all_data_inputs=all_data_inputs)
logits_list.append(fixed_relative_logits)
# normalizing
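# Entropy of a label-smoothed target distribution; it is subtracted from each soft-label
# cross entropy below so the distillation terms stay on roughly the same scale as the
# label-smoothed NLL losses.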
confidence = 1.0 - TrainTaskConfig.label_smooth_eps
low_confidence = (1.0 - confidence) / (ModelHyperParams.trg_vocab_size - 1)
normalizing = -(confidence * math.log(confidence) + (ModelHyperParams.trg_vocab_size - 1) *
low_confidence * math.log(low_confidence + 1e-20))
batch_size = layers.shape(new_forward_logits)[0]
seq_length = layers.shape(new_forward_logits)[1]
trg_voc_size = layers.shape(new_forward_logits)[2]
# ensemble
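# The teacher is the average of the three fixed agents' logits; its softmax becomes the
# soft target, and gradients are blocked so only the "new_" student agents learn from it.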
teacher_logits = logits_list[0]
for index in xrange(1, len(logits_list)):
teacher_logits += logits_list[index]
teacher_logits = teacher_logits / len(logits_list)
# new_target
new_target = layers.softmax(teacher_logits)
new_target.stop_gradient = True
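# Per-agent distillation loss: soft-label cross entropy against the teacher distribution,
# masked and normalized with new_forward_non_zeros / new_forward_token_num (all agents
# read the same batch, so the padding mask and token count are shared).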
# agent_1: forward
fdistill_xent = layers.softmax_with_cross_entropy(
logits=new_forward_logits,
label=new_target,
soft_label=True)
fdistill_xent -= normalizing
fdistill_loss = layers.reduce_sum(fdistill_xent * new_forward_non_zeros) / new_forward_token_num
# agent_2: relative
rdistill_xent = layers.softmax_with_cross_entropy(
logits=new_relative_position_logits,
label=new_target,
soft_label=True)
rdistill_xent -= normalizing
rdistill_loss = layers.reduce_sum(rdistill_xent * new_forward_non_zeros) / new_forward_token_num
# agent_3: dense
ddistill_xent = layers.softmax_with_cross_entropy(
logits=new_dense_logits,
label=new_target,
soft_label=True)
ddistill_xent -= normalizing
ddistill_loss = layers.reduce_sum(ddistill_xent * new_forward_non_zeros) / new_forward_token_num
teacher_loss = fixed_forward_avg_cost + fixed_dense_avg_cost + fixed_relative_avg_cost
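# Joint objective: each student agent interpolates its own label-smoothed NLL with its
# distillation loss using weight beta, and the fixed agents' NLL (teacher_loss) is added
# on top.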
avg_cost = (TrainTaskConfig.beta * new_forward_avg_cost
            + (1.0 - TrainTaskConfig.beta) * fdistill_loss
            + TrainTaskConfig.beta * new_relative_position_avg_cost
            + (1.0 - TrainTaskConfig.beta) * rdistill_loss
            + TrainTaskConfig.beta * new_dense_avg_cost
            + (1.0 - TrainTaskConfig.beta) * ddistill_loss
            + teacher_loss)
avg_cost.persistable = True
teacher_loss.persistable = True
optimizer = None
if args.sync:
lr_decay = fluid.layers.learning_rate_scheduler.noam_decay(
ModelHyperParams.d_model, TrainTaskConfig.warmup_steps)
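# Noam schedule: linear warmup for warmup_steps, then inverse-square-root decay scaled by
# d_model ** -0.5; TrainTaskConfig.learning_rate is applied as an extra multiplier below.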
logging.info("before adam")
with fluid.default_main_program()._lr_schedule_guard():
learning_rate = lr_decay * TrainTaskConfig.learning_rate
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate,
beta1=TrainTaskConfig.beta1,
beta2=TrainTaskConfig.beta2,
epsilon=TrainTaskConfig.eps)
else:
optimizer = fluid.optimizer.SGD(0.003)
if args.use_fp16:
#black_varnames={"src_slf_attn_bias", "trg_slf_attn_bias", "trg_src_attn_bias", "dense_src_slf_attn_bias", "dense_trg_slf_attn_bias", "dense_trg_src_attn_bias"}
#amp_lists=fluid.contrib.mixed_precision.AutoMixedPrecisionLists(custom_black_varnames=black_varnames,
# custom_black_list=["dropout"])
#optimizer = fluid.contrib.mixed_precision.decorate(optimizer, amp_lists=amp_lists,
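# Mixed-precision training with dynamic loss scaling: the scale starts at 32768, is raised
# after sufficiently long runs of overflow-free steps, and is cut back when an overflow is
# detected.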
optimizer = fluid.contrib.mixed_precision.decorate(optimizer,
init_loss_scaling=32768, incr_every_n_steps=2000,
use_dynamic_loss_scaling=True)
optimizer.minimize(avg_cost)
loss_scaling=None
scaled_cost=None
if args.use_fp16:
scaled_cost = optimizer.get_scaled_loss()
loss_scaling = optimizer.get_loss_scaling()
if args.local:
logging.info("local start_up:")
train_loop(exe, train_prog, startup_prog, args, dev_count, avg_cost, teacher_loss, new_relative_position_sum_cost, new_relative_position_avg_cost,
new_relative_position_token_num, pyreader, place)
else:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
worker_endpoints_env = os.getenv("PADDLE_TRAINER_ENDPOINTS")
current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
worker_endpoints = worker_endpoints_env.split(",")
trainers_num = len(worker_endpoints)
logging.info("worker_endpoints:{} trainers_num:{} current_endpoint:{} \
trainer_id:{}".format(worker_endpoints, trainers_num,
current_endpoint, trainer_id))
config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
if args.nccl_comm_num > 1:
config.nccl_comm_num = args.nccl_comm_num
if args.use_hierarchical_allreduce and trainers_num > args.hierarchical_allreduce_inter_nranks:
logging.info("use_hierarchical_allreduce")
config.use_hierarchical_allreduce = args.use_hierarchical_allreduce
config.hierarchical_allreduce_inter_nranks = args.hierarchical_allreduce_inter_nranks
assert config.hierarchical_allreduce_inter_nranks > 1
assert trainers_num % config.hierarchical_allreduce_inter_nranks == 0
config.hierarchical_allreduce_exter_nranks = \
trainers_num / config.hierarchical_allreduce_inter_nranks
t = fluid.DistributeTranspiler(config=config)
t.transpile(
trainer_id, trainers=worker_endpoints_env,
current_endpoint=current_endpoint, program=train_prog,
startup_program=startup_prog)
train_loop(exe, train_prog, startup_prog, args, dev_count, avg_cost, teacher_loss,
new_relative_position_sum_cost, new_relative_position_avg_cost, new_relative_position_token_num, pyreader, place, trainers_num, trainer_id, scaled_cost=scaled_cost, loss_scaling=loss_scaling)
if __name__ == "__main__":
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(
stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
logging.getLogger().setLevel(logging.INFO)
args = parse_args()
train(args)
#!/bin/bash
source ./env/env.sh
source ./env/utils.sh
source ./env/cloud_job_conf.conf
iplist=$1
#iplist=`echo $nodelist | xargs | sed 's/ /,/g'`
if [ ! -d log ]
then
mkdir log
fi
export GLOG_vmodule=fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10
if [[ ${FUSE} == "1" ]]; then
export FLAGS_fuse_parameter_memory_size=64 #MB
fi
set -ux
check_iplist
distributed_args=""
if [[ ${NUM_CARDS} == "1" ]]; then
distributed_args="--selected_gpus 0"
fi
node_ips=${PADDLE_TRAINERS}
distributed_args="--node_ips ${PADDLE_TRAINERS} --node_id ${PADDLE_TRAINER_ID} --current_node_ip ${POD_IP} --nproc_per_node 8 --selected_gpus 0,1,2,3,4,5,6,7"
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.1
export NCCL_IB_GID_INDEX=3
export NCCL_IB_RETRY_CNT=10
export FLAGS_sync_nccl_allreduce=0
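# Token budget per batch handed to the bucketed batching scheme in src/train.py.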
BATCH_SIZE=1250
python -u ./src/launch.py ${distributed_args} \
./src/train.py \
--src_vocab_size 37007 \
--tgt_vocab_size 37007 \
--train_file_pattern 'data/translate-train-*' \
--token_delimiter ' ' \
--batch_size ${BATCH_SIZE} \
--use_py_reader True \
--use_delay_load True \
--nccl_comm_num ${NCCL_COMM_NUM} \
--use_hierarchical_allreduce ${USE_HIERARCHICAL_ALLREDUCE} \
--fetch_steps 50 \
--fuse ${FUSE} \
--val_file_pattern 'testset/testfile' \
--infer_batch_size 32 \
--decode_alpha 0.3 \
--beam_size 4 \
--use_fp16 True \
learning_rate 2.0 \
warmup_steps 8000 \
beta2 0.997 \
d_model 1024 \
d_inner_hid 4096 \
n_head 16 \
prepostprocess_dropout 0.3 \
attention_dropout 0.1 \
relu_dropout 0.1 \
embedding_sharing True \
pass_num 100 \
max_length 256 \
save_freq 5000 \
model_dir 'output'