Unverified commit 497faa61, authored by: J jiangbojian, committed by: GitHub

Mal (#4172)

* EMNLP2019-MAL

* EMNLP2019-MAL, update README.md

* EMNLP2019-MAL, update README.md

* EMNLP2019-MAL, update license
Parent a7ef0aef
# Multi-agent Learning for Neural Machine Translation (MAL)
## Introduction
MAL is the first multi-agent end-to-end joint learning framework, recently proposed by the Baidu Translate team. It significantly improves the learning ability of individual agents and sets new state-of-the-art results on several machine translation test sets. The work was accepted at EMNLP 2019: [Multi-agent Learning for Neural Machine Translation](https://www.aclweb.org/anthology/D19-1079.pdf). The overall architecture is shown below:
<p align="center">
<img src="images/arch.png" width = "340" height = "300" /> <br />
Overall framework of MAL
</p>
This repo contains the PaddlePaddle implementation of MAL. The framework makes a few modifications on top of the paper and reaches a BLEU score of 30.04 on the WMT 2014 English-German test set, surpassing the result reported in the paper and setting a new SOTA without changing the model structure.
### Results
#### WMT English-German
| Models | En-De |
| :------------- | :---------: |
| [ConvS2S](https://pdfs.semanticscholar.org/bb3e/bc09b65728d6eced04929df72a006fb5210b.pdf) | 25.20 |
| [Transformer](https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf) | 28.40 |
| [Rel](https://www.aclweb.org/anthology/N18-2074.pdf) | 29.20 |
| [DynamicConv](https://openreview.net/pdf?id=SkVhlh09tX) | 29.70 |
| L2R | 28.88 |
| MAL-L2R | **30.04** |
## Running
### Environment
The runtime environment must meet the following requirements:
+ python 2.7
+ paddlepaddle-gpu (1.6.1)
+ CUDA, CuDNN and NCCL (CUDA 9.0, CuDNN v7 and NCCL 2.3.5)
Reproducing the WMT English-German result requires 56 V100 32GB GPUs and roughly 300K training steps.
### Data preparation
Run the get_data.sh script to download the raw data and preprocess it into the file format required for training:
```
sh get_data.sh
```
### Running the model
Before running, set the paths to CUDA, CuDNN and NCCL; they are configured in env/env.sh.
Run train.sh to start MAL; the resulting models are saved under output. The model decodes while it trains, and evaluate.sh can be used to compute BLEU on the files decoded during training.
train.sh contains a distributed_args parameter that must be adapted to your machines: nproc_per_node is the number of GPUs to use on each machine, and selected_gpus lists the GPU ids. For example, to train on all 8 GPUs of an 8-card V100 machine, set nproc_per_node to 8 and selected_gpus to 0, 1, 2, 3, 4, 5, 6, 7.
```
sh train.sh ip1,ip2,ip3... (IP addresses of the machines; do not use 127.0.0.1, use the output of hostname -i) &
sh evaluate.sh file_path (the decoded file, located under the output directory)
```
### Reproducing the results in the paper
We provide a model trained with MAL on the English-German task; run infer.sh to see the final result. Because the test set must be generated in advance, run get_data.sh before calling infer.sh, and make sure the CUDA and CuDNN paths are set.
```
sh infer.sh
```
### Code structure
The main code lives in the src folder:
+ train.py: training entry point
+ infer.py: inference entry point
+ config.py: model configuration, including the model variants and their hyperparameters
+ reader.py: data reading utilities
+ bleu_hook.py: BLEU computation script
####
################################## User Define Configuration ###########################
################################## Data Configuration ##################################
#type of storage cluster
#storage_type = "afs"
#attention: files for training should be put on hdfs
##the list contains all file locations should be specified here
#fs_name = "afs://xingtian.afs.baidu.com:9902"
##If force_reuse_output_path is True, paddle will remove output_path without checking whether output_path exists
#force_reuse_output_path = "True"
##ugi of hdfs
#fs_ugi = "NLP_KM_Data,NLP_km_2018"
#the initial model path on hdfs used to init parameters
#init_model_path=
#the initial model path for pservers
#pserver_model_dir=
#which pass
#pserver_model_pass=
#example of above 2 args:
#if set pserver_model_dir to /app/paddle/models
#and set pserver_model_pass to 123
#then rank 0 will download model from /app/paddle/models/rank-00000/pass-00123/
#and rank 1 will download model from /app/paddle/models/rank-00001/pass-00123/, etc.
##train data path on hdfs
#train_data_path = "/user/NLP_KM_Data/gongweibao/transformer/paddle_training_data/train_data"
##test data path on hdfs, can be null or left unset
#test_data_path = "/app/inf/mpi/bml-guest/paddle-platform/dataset/mnist/data/test/"
#the output directory on hdfs
#output_path = "/user/NLP_KM_Data/gongweibao/transformer/output"
#add datareader to thirdparty
#thirdparty_path = "/user/NLP_KM_Data/gongweibao/transformer/thirdparty"
FLAGS_rpc_deadline=3000000
#whl_name=paddlepaddle_ab57d3_post97_gpu-0.0.0-cp27-cp27mu-linux_x86_64.whl
#dataset_path=/user/NLP_KM_Data/gongweibao/transformer/small/paddle_training_data
PROFILE=0
FUSE=1
NCCL_COMM_NUM=2
NUM_THREADS=3
USE_HIERARCHICAL_ALLREDUCE=True
NUM_CARDS=8
NUM_EPOCHS=100
BATCH_SIZE=4096
#!/bin/bash
export BASE_PATH="$PWD"
#NCCL
export NCCL_DEBUG=INFO
export NCCL_IB_GID_INDEX=3
#export NCCL_IB_RETRY_CNT=0
#PADDLE
export FLAGS_fraction_of_gpu_memory_to_use=0.98
export FLAGS_sync_nccl_allreduce=0
export FLAGS_eager_delete_tensor_gb=0.0
#Cudnn
#export FLAGS_cudnn_exhaustive_search=1
export LD_LIBRARY_PATH=/home/work/cuda-9.0/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/home/work/cudnn/cudnn_v7/cuda/lib64:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH="${BASE_PATH}/nccl_2.3.5/lib/:$LD_LIBRARY_PATH"
#proxy
unset https_proxy http_proxy
# GLOG
export GLOG_v=1
#export GLOG_vmodule=fused_all_reduce_op_handle=10,all_reduce_op_handle=10,alloc_continuous_space_op=10,fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10,fast_threaded_ssa_graph_executor=10,threaded_ssa_graph_executor=10,backward_op_deps_pass=10,graph=10
export GLOG_logtostderr=1
#!/bin/bash
set -u
function check_iplist() {
if [ ${iplist:-} ]; then
#paddle envs
export PADDLE_PSERVER_PORT=9184
export PADDLE_TRAINER_IPS=${iplist}
#export PADDLE_CURRENT_IP=`/sbin/ip a | grep inet | grep global | awk '{print $2}' | sed 's/\/[0-9][0-9].*$//g'`
export PADDLE_CURRENT_IP=`hostname -i`
iparray=(${iplist//,/ })
for i in "${!iparray[@]}"; do
echo $i
if [ ${iparray[$i]} == ${PADDLE_CURRENT_IP} ]; then
export PADDLE_TRAINER_ID=$i
fi
done
export TRAINING_ROLE=TRAINER
#export PADDLE_PSERVERS=127.0.0.1
export PADDLE_INIT_TRAINER_COUNT=${#iparray[@]}
export PADDLE_PORT=${PADDLE_PSERVER_PORT}
export PADDLE_TRAINERS=${PADDLE_TRAINER_IPS}
export POD_IP=${PADDLE_CURRENT_IP}
export PADDLE_TRAINERS_NUM=${PADDLE_INIT_TRAINER_COUNT}
#is local
export PADDLE_IS_LOCAL=0
echo "****************************************************"
#paddle debug envs
export GLOG_v=0
export GLOG_logtostderr=1
#nccl debug envs
export NCCL_DEBUG=INFO
#export NCCL_IB_DISABLE=1
#export NCCL_IB_GDR_LEVEL=4
export NCCL_IB_GID_INDEX=3
#export NCCL_SOCKET_IFNAME=eth2
fi
}
#! /bin/sh
path=$1
python ./src/id2word.py data/vocab.source.32000 < ${path} > ${path}_word
head -n 3003 ${path}_word > ${path}_word_tmp
mv ${path}_word_tmp ${path}_word
cat ${path}_word | sed 's/@@ //g' > ${path}.trans.post
python ./src/bleu_hook.py --reference wmt16_en_de/newstest2014.tok.de --translation ${path}.trans.post
#! /bin/sh
tmp_dir=wmt16_en_de
data_dir=data
source_file=train.tok.clean.bpe.32000.en
target_file=train.tok.clean.bpe.32000.de
source_vocab_size=32000
target_vocab_size=32000
num_shards=100
if [ ! -d wmt16_en_de ]
then
mkdir wmt16_en_de
fi
wget https://baidu-nlp.bj.bcebos.com/EMNLP2019-MAL/wmt16_en_de.tar.gz -O wmt16_en_de/wmt16_en_de.tar.gz
tar -zxf wmt16_en_de/wmt16_en_de.tar.gz -C wmt16_en_de
if [ ! -d $data_dir ]
then
mkdir data
fi
if [ ! -d testset ]
then
mkdir testset
fi
cp wmt16_en_de/vocab.bpe.32000 data/vocab.source.32000
python ./src/gen_records.py --tmp_dir ${tmp_dir} --data_dir ${data_dir} --source_train_files ${source_file} --target_train_files ${target_file} --source_vocab_size ${source_vocab_size} --target_vocab_size ${target_vocab_size} --num_shards ${num_shards} --token True --onevocab True
python ./src/preprocess/gen_utils.py --vocab $data_dir/vocab.source.${source_vocab_size} --testset ${tmp_dir}/newstest2014.tok.bpe.32000.en --output ./testset/testfile
#! /bin/sh
export LD_LIBRARY_PATH=/home/work/cuda-9.0/lib64:/home/work/cudnn/cudnn_v7/cuda/lib64:/home/work/cuda-9.0/extras/CUPTI/lib64:$LD_LIBRARY_PATH
wget https://baidu-nlp.bj.bcebos.com/EMNLP2019-MAL/checkpoint.best.tgz
tar -zxf checkpoint.best.tgz
infer(){
CUDA_VISIBLE_DEVICES=$1 python -u src/infer.py \
--val_file_pattern $3 \
--vocab_size $4 \
--special_token '<s>' '<e>' '<unk>' \
--use_mem_opt True \
--use_delay_load True \
--infer_batch_size 16 \
--decode_alpha 0.3 \
d_model 1024 \
d_inner_hid 4096 \
n_head 16 \
prepostprocess_dropout 0.0 \
attention_dropout 0.0 \
relu_dropout 0.0 \
model_path $2 \
beam_size 4 \
max_out_len 306 \
max_length 256
}
infer 0 checkpoint.best testset/testfile 37007
sh evaluate.sh trans/forward_checkpoint.best
grep "BLEU_cased" trans/*
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import argparse
def str2bool(v):
"""
argparse does not directly parse strings such as "True"/"False" into Python
booleans, so this helper performs the conversion
"""
return v.lower() in ("true", "t", "1")
class ArgumentGroup(object):
"""
ArgumentGroup
"""
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, positional_arg=False, **kwargs):
"""
add_arg
"""
prefix = "" if positional_arg else "--"
type = str2bool if type == bool else type
self._group.add_argument(
prefix + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def print_arguments(args):
"""
print_arguments
"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def inv_arguments(args):
"""
inv_arguments
"""
print('[Warning] Only keyword argument type is supported.')
args_list = []
for arg, value in sorted(six.iteritems(vars(args))):
args_list.extend(['--' + str(arg), str(value)])
return args_list
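# Illustrative usage, a minimal sketch added for documentation; it is not part
# of the original args.py. It assumes this module is importable as `args`, which
# is how the multi-process launcher below uses it
# (`from args import ArgumentGroup, print_arguments, inv_arguments`).
if __name__ == "__main__":
    demo_parser = argparse.ArgumentParser("ArgumentGroup demo")
    demo_group = ArgumentGroup(demo_parser, "demo", "demo options.")
    # type=bool is routed through str2bool, so "--use_gpu False" parses as expected.
    demo_group.add_arg("use_gpu", bool, True, "Whether to run on GPU.")
    demo_group.add_arg("batch_size", int, 32, "Mini-batch size.")
    demo_args = demo_parser.parse_args()
    print_arguments(demo_args)  # prints each parsed argument, one per line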
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layer_helper import LayerHelper as LayerHelper
# config provides ModelHyperParams.dropout_seed, used by scaled_dot_product_attention below
from config import ModelHyperParams
def generate_relative_positions_matrix(length, max_relative_position, cache=False):
if not cache:
range_vec = layers.range(0, length, 1, 'int32')
range_vec.stop_gradient = True
shapes = layers.shape(range_vec)
range_vec = layers.reshape(range_vec, shape=[1, shapes[0]])
range_mat = layers.expand(range_vec, [shapes[0], 1])
distance_mat = range_mat - layers.transpose(range_mat, [1, 0])
else:
distance_mat = layers.range(-1 * length+1, 1, 1, 'int32')
distance_mat.stop_gradient = True
shapes = layers.shape(distance_mat)
distance_mat = layers.reshape(distance_mat, [1, shapes[0]])
distance_mat_clipped = layers.clip(layers.cast(distance_mat, dtype="float32"), float(-max_relative_position), float(max_relative_position))
final_mat = layers.cast(distance_mat_clipped, dtype = 'int32') + max_relative_position
return final_mat
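# Equivalent numpy sketch of the cache=False branch above (illustration only;
# numpy is not used by the Paddle graph). For length=4 and max_relative_position=2
# each entry is an index into the relative-position embedding table:
#   import numpy as np
#   r = np.arange(4)
#   dist = r[None, :] - r[:, None]        # dist[i][j] = j - i
#   idx = np.clip(dist, -2, 2) + 2
#   # idx == [[2, 3, 4, 4],
#   #         [1, 2, 3, 4],
#   #         [0, 1, 2, 3],
#   #         [0, 0, 1, 2]]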
def generate_relative_positions_embeddings(length, depth, max_relative_position, name, cache=False):
relative_positions_matrix = generate_relative_positions_matrix(
length, max_relative_position, cache=cache)
y = layers.reshape(relative_positions_matrix, [-1])
y.stop_gradient = True
vocab_size = max_relative_position * 2 + 1
#embeddings_table = layers.create_parameter(shape=[vocab_size, depth], dtype='float32', default_initializer=fluid.initializer.Constant(1.2345), name=name)
embeddings_table = layers.create_parameter(shape=[vocab_size, depth], dtype='float32', name=name)
#layers.Print(embeddings_table, message = "embeddings_table=====")
embeddings_1 = layers.gather(embeddings_table, y)
embeddings = layers.reshape(embeddings_1, [-1, length, depth])
return embeddings
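# Relative attention helper in the style of Shaw et al. (2018): the result is the
# ordinary term q x k(^T) plus a term in which every position of q is additionally
# matched against its own [length, depth] slice of the relative-position embeddings v.
# Here q is the query (or attention-weight) tensor [batch, heads, length, *],
# k is the key (or value) tensor, and v holds the relative embeddings
# [length, length, depth] produced by generate_relative_positions_embeddings.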
def _relative_attention_inner(q, k, v, transpose):
batch_size = layers.shape(q)[0]
heads = layers.shape(q)[1]
length = layers.shape(q)[2]
xy_matmul = layers.matmul(q, k, transpose_y=transpose)
x_t = layers.transpose(q, [2, 0, 1, 3])
x_t_r = layers.reshape(x_t, [length, batch_size * heads, -1])
x_tz_matmul = layers.matmul(x_t_r, v, transpose_y = transpose)
x_tz_matmul_r = layers.reshape(x_tz_matmul, [length, batch_size, heads, -1])
x_tz_matmul_r_t = layers.transpose(x_tz_matmul_r, [1, 2, 0, 3])
return xy_matmul + x_tz_matmul_r_t
def _dot_product_relative(q, k, v, bias, dropout=0.1, cache=None, params_type="normal"):
depth_constant = int(k.shape[3])
heads = layers.shape(k)[1]
length = layers.shape(k)[2]
max_relative_position = 4
pre_name = "relative_positions_"
if params_type == "fixed":
pre_name = "fixed_relative_positions_"
elif params_type == "new":
pre_name = "new_relative_positions_"
relations_keys = generate_relative_positions_embeddings(
length, depth_constant, max_relative_position, name=pre_name + "keys",
cache=cache is not None)
relations_values = generate_relative_positions_embeddings(
length, depth_constant, max_relative_position,
name = pre_name + "values",
cache=cache is not None)
logits = _relative_attention_inner(q, k, relations_keys, True)
if bias is not None: logits += bias
weights = layers.softmax(logits, name = "attention_weights")
weights = layers.dropout(weights, dropout_prob=float(dropout))
output = _relative_attention_inner(weights, v, relations_values, False)
return output
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key ** -0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
seed=ModelHyperParams.dropout_seed,
is_test=False, dropout_implementation='upscale_in_train')
out = layers.matmul(weights, v)
return out
if __name__ == "__main__":
batch_size = 2
heads = 8
length = 5
depth = 3
cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)
startup_prog = fluid.Program()
train_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
with fluid.unique_name.guard("forward"):
x = layers.reshape(layers.cast(layers.range(0, 18, 1, "int32"), dtype = "float32"), shape =[-1, 3, 3])
y = layers.reshape(layers.cast(layers.range(0, 2, 1, "int32"), dtype = "float32"), shape =[-1, 1])
z = x * y
exe.run(startup_prog)
outs = exe.run(train_prog, fetch_list=[x, y, z])
print outs[0]
print outs[1]
print outs[2]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers as layers
INF = 1. * 1e9
class BeamSearch(object):
"""
beam_search class
"""
def __init__(self, beam_size, batch_size, alpha, vocab_size, hidden_size):
self.beam_size = beam_size
self.batch_size = batch_size
self.alpha = alpha
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.gather_top2k_append_index = layers.range(0, 2 * self.batch_size * beam_size, 1, 'int64') // \
(2 * self.beam_size) * (self.beam_size)
self.gather_topk_append_index = layers.range(0, self.batch_size * beam_size, 1, 'int64') // \
self.beam_size * (2 * self.beam_size)
self.gather_finish_topk_append_index = layers.range(0, self.batch_size * beam_size, 1, 'int64') // \
self.beam_size * (3 * self.beam_size)
self.eos_id = layers.fill_constant([self.batch_size, 2 * self.beam_size], 'int64', value=1)
self.get_alive_index = layers.range(0, self.batch_size, 1, 'int64') * self.beam_size
def gather_cache(self, kv_caches, select_id):
"""
gather cache
"""
for index in xrange(len(kv_caches)):
kv_cache = kv_caches[index]
select_k = layers.gather(kv_cache['k'], [select_id])
select_v = layers.gather(kv_cache['v'], [select_id])
layers.assign(select_k, kv_caches[index]['k'])
layers.assign(select_v, kv_caches[index]['v'])
# topk_seq, topk_scores, topk_log_probs, topk_finished, cache
def compute_topk_scores_and_seq(self, sequences, scores, scores_to_gather, flags, pick_finish=False, cache=None):
"""
compute_topk_scores_and_seq
"""
topk_scores, topk_indexes = layers.topk(scores, k=self.beam_size) #[batch_size, beam_size]
if not pick_finish:
flat_topk_indexes = layers.reshape(topk_indexes, [-1]) + self.gather_topk_append_index
flat_sequences = layers.reshape(sequences, [2 * self.batch_size * self.beam_size, -1])
else:
flat_topk_indexes = layers.reshape(topk_indexes, [-1]) + self.gather_finish_topk_append_index
flat_sequences = layers.reshape(sequences, [3 * self.batch_size * self.beam_size, -1])
topk_seq = layers.gather(flat_sequences, [flat_topk_indexes])
topk_seq = layers.reshape(topk_seq, [self.batch_size, self.beam_size, -1])
flat_flags = layers.reshape(flags, [-1])
topk_flags = layers.gather(flat_flags, [flat_topk_indexes])
topk_flags = layers.reshape(topk_flags, [-1, self.beam_size])
flat_scores = layers.reshape(scores_to_gather, [-1])
topk_gathered_scores = layers.gather(flat_scores, [flat_topk_indexes])
topk_gathered_scores = layers.reshape(topk_gathered_scores, [-1, self.beam_size])
if cache:
self.gather_cache(cache, flat_topk_indexes)
return topk_seq, topk_gathered_scores, topk_flags, cache
def grow_topk(self, i, logits, alive_seq, alive_log_probs, cache, enc_output, enc_bias):
"""
grow_topk
"""
logits = layers.reshape(logits, [self.batch_size, self.beam_size, -1])
candidate_log_probs = layers.log(layers.softmax(logits, axis=2))
log_probs = candidate_log_probs + layers.unsqueeze(alive_log_probs, axes=[2])
base_1 = layers.cast(i, 'float32') + 6.0
base_1 /= 6.0
length_penalty = layers.pow(base_1, self.alpha)
#length_penalty = layers.pow(((5.0 + layers.cast(i+1, 'float32')) / 6.0), self.alpha)
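#this matches the GNMT-style length normalization ((5 + step + 1) / 6) ** alpha;
#dividing the log-probabilities by it keeps longer candidates competitive.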
curr_scores = log_probs / length_penalty
flat_curr_scores = layers.reshape(curr_scores, [self.batch_size, self.beam_size * self.vocab_size])
topk_scores, topk_ids = layers.topk(flat_curr_scores, k=self.beam_size * 2)
topk_log_probs = topk_scores * length_penalty
select_beam_index = topk_ids // self.vocab_size
select_id = topk_ids % self.vocab_size
#layers.Print(select_id, message="select_id", summarize=1024)
#layers.Print(topk_scores, message="topk_scores", summarize=10000000)
flat_select_beam_index = layers.reshape(select_beam_index, [-1]) + self.gather_top2k_append_index
topk_seq = layers.gather(alive_seq, [flat_select_beam_index])
topk_seq = layers.reshape(topk_seq, [self.batch_size, 2 * self.beam_size, -1])
#concat with current ids
topk_seq = layers.concat([topk_seq, layers.unsqueeze(select_id, axes=[2])], axis=2)
topk_finished = layers.cast(layers.equal(select_id, self.eos_id), 'float32')
#gather cache
self.gather_cache(cache, flat_select_beam_index)
#topk_seq: [batch_size, 2*beam_size, i+1]
#topk_log_probs, topk_scores, topk_finished: [batch_size, 2*beam_size]
return topk_seq, topk_log_probs, topk_scores, topk_finished, cache
def grow_alive(self, curr_seq, curr_scores, curr_log_probs, curr_finished, cache):
"""
grow_alive
"""
finish_float_flag = layers.cast(curr_finished, 'float32')
finish_float_flag = finish_float_flag * -INF
curr_scores += finish_float_flag
return self.compute_topk_scores_and_seq(curr_seq, curr_scores,
curr_log_probs, curr_finished, cache=cache)
def grow_finished(self, i, finished_seq, finished_scores, finished_flags, curr_seq,
curr_scores, curr_finished):
"""
grow_finished
"""
finished_seq = layers.concat([finished_seq,
layers.fill_constant([self.batch_size, self.beam_size, 1], dtype='int64', value=0)],
axis=2)
curr_scores = curr_scores + (1.0 - layers.cast(curr_finished, 'int64')) * -INF
curr_finished_seq = layers.concat([finished_seq, curr_seq], axis=1)
curr_finished_scores = layers.concat([finished_scores, curr_scores], axis=1)
curr_finished_flags = layers.concat([finished_flags, curr_finished], axis=1)
return self.compute_topk_scores_and_seq(curr_finished_seq, curr_finished_scores,
curr_finished_scores, curr_finished_flags,
pick_finish=True)
def inner_func(self, i, logits, alive_seq, alive_log_probs, finished_seq, finished_scores,
finished_flags, cache, enc_output, enc_bias):
"""
inner_func
"""
topk_seq, topk_log_probs, topk_scores, topk_finished, cache = self.grow_topk(
i, logits, alive_seq, alive_log_probs, cache, enc_output, enc_bias)
alive_seq, alive_log_probs, _, cache = self.grow_alive(
topk_seq, topk_scores, topk_log_probs, topk_finished, cache)
#layers.Print(alive_seq, message="alive_seq", summarize=1024)
finished_seq, finished_scores, finished_flags, _ = self.grow_finished(
i, finished_seq, finished_scores, finished_flags, topk_seq, topk_scores, topk_finished)
return alive_seq, alive_log_probs, finished_seq, finished_scores, finished_flags, cache
def is_finished(self, step_idx, source_length, alive_log_probs, finished_scores, finished_in_finished):
"""
is_finished
"""
base_1 = layers.cast(source_length, 'float32') + 55.0
base_1 /= 6.0
max_length_penalty = layers.pow(base_1, self.alpha)
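#max_length_penalty is the length normalization of the longest hypothesis the
#decoder may still emit (decode_length = source_length + 50 below), i.e.
#((5 + source_length + 50) / 6) ** alpha; it lower-bounds the final score the
#best alive hypothesis could still achieve.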
flat_alive_log_probs = layers.reshape(alive_log_probs, [-1])
lower_bound_alive_scores_1 = layers.gather(flat_alive_log_probs, [self.get_alive_index])
lower_bound_alive_scores = lower_bound_alive_scores_1 / max_length_penalty
lowest_score_of_finished_in_finish = layers.reduce_min(finished_scores * finished_in_finished, dim=1)
finished_in_finished = layers.cast(finished_in_finished, 'bool')
lowest_score_of_finished_in_finish += \
((1.0 - layers.cast(layers.reduce_any(finished_in_finished, 1), 'float32')) * -INF)
#print lowest_score_of_finished_in_finish
bound_is_met = layers.reduce_all(layers.greater_than(lowest_score_of_finished_in_finish,
lower_bound_alive_scores))
decode_length = source_length + 50
length_cond = layers.less_than(x=step_idx, y=decode_length)
return layers.logical_and(x=layers.logical_not(bound_is_met), y=length_cond)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
import collections
import math
import os
import re
import sys
import time
import unicodedata
# Dependency imports
import numpy as np
import six
from six.moves import range
from six.moves import zip
from preprocess import text_encoder
def _get_ngrams(segment, max_order):
"""Extracts all n-grams up to a given maximum order from an input segment.
Args:
segment: text segment from which n-grams will be extracted.
max_order: maximum length in tokens of the n-grams returned by this
methods.
Returns:
The Counter containing all n-grams up to max_order in segment
with a count of how many times each n-gram occurred.
"""
ngram_counts = collections.Counter()
for order in range(1, max_order + 1):
for i in range(0, len(segment) - order + 1):
ngram = tuple(segment[i:i + order])
ngram_counts[ngram] += 1
return ngram_counts
def compute_bleu(reference_corpus,
translation_corpus,
max_order=4,
use_bp=True):
"""Computes BLEU score of translated segments against one or more references.
Args:
reference_corpus: list of references for each translation. Each
reference should be tokenized into a list of tokens.
translation_corpus: list of translations to score. Each translation
should be tokenized into a list of tokens.
max_order: Maximum n-gram order to use when computing BLEU score.
use_bp: boolean, whether to apply brevity penalty.
Returns:
BLEU score.
"""
reference_length = 0
translation_length = 0
bp = 1.0
geo_mean = 0
matches_by_order = [0] * max_order
possible_matches_by_order = [0] * max_order
precisions = []
for (references, translations) in zip(reference_corpus, translation_corpus):
reference_length += len(references)
translation_length += len(translations)
ref_ngram_counts = _get_ngrams(references, max_order)
translation_ngram_counts = _get_ngrams(translations, max_order)
overlap = dict((ngram,
min(count, translation_ngram_counts[ngram]))
for ngram, count in ref_ngram_counts.items())
for ngram in overlap:
matches_by_order[len(ngram) - 1] += overlap[ngram]
for ngram in translation_ngram_counts:
possible_matches_by_order[len(ngram)-1] += translation_ngram_counts[ngram]
precisions = [0] * max_order
smooth = 1.0
for i in range(0, max_order):
if possible_matches_by_order[i] > 0:
precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
if matches_by_order[i] > 0:
precisions[i] = matches_by_order[i] / possible_matches_by_order[i]
else:
smooth *= 2
precisions[i] = 1.0 / (smooth * possible_matches_by_order[i])
else:
precisions[i] = 0.0
if max(precisions) > 0:
p_log_sum = sum(math.log(p) for p in precisions if p)
geo_mean = math.exp(p_log_sum / max_order)
if use_bp:
ratio = (translation_length + 1e-6) / reference_length
bp = math.exp(1 - 1. / ratio) if ratio < 1.0 else 1.0
bleu = geo_mean * bp
return np.float32(bleu)
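# Usage sketch (illustration, not part of the original file): inputs are token
# lists per sentence; identical reference and hypothesis give BLEU 1.0.
#   compute_bleu([["the", "cat", "sat"]], [["the", "cat", "sat"]])  # ~1.0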
class UnicodeRegex(object):
"""Ad-hoc hack to recognize all punctuation and symbols."""
def __init__(self):
punctuation = self.property_chars("P")
self.nondigit_punct_re = re.compile(r"([^\d])([" + punctuation + r"])")
self.punct_nondigit_re = re.compile(r"([" + punctuation + r"])([^\d])")
self.symbol_re = re.compile("([" + self.property_chars("S") + "])")
def property_chars(self, prefix):
"""
get unicode of specified chars
"""
return "".join(six.unichr(x) for x in range(sys.maxunicode)
if unicodedata.category(six.unichr(x)).startswith(prefix))
uregex = UnicodeRegex()
def bleu_tokenize(string):
r"""Tokenize a string following the official BLEU implementation.
See https://github.com/moses-smt/mosesdecoder/"
"blob/master/scripts/generic/mteval-v14.pl#L954-L983
In our case, the input string is expected to be just one line
and no HTML entities de-escaping is needed.
So we just tokenize on punctuation and symbols,
except when a punctuation is preceded and followed by a digit
(e.g. a comma/dot as a thousand/decimal separator).
Note that a number (e.g. a year) followed by a dot at the end of sentence
is NOT tokenized,
i.e. the dot stays with the number because `s/(\p{P})(\P{N})/ $1 $2/g`
does not match this case (unless we add a space after each sentence).
However, this error is already in the original mteval-v14.pl
and we want to be consistent with it.
Args:
string: the input string
Returns:
a list of tokens
"""
string = uregex.nondigit_punct_re.sub(r"\1 \2 ", string)
string = uregex.punct_nondigit_re.sub(r" \1 \2", string)
string = uregex.symbol_re.sub(r" \1 ", string)
return string.split()
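# Usage sketch (illustration, not part of the original file):
#   bleu_tokenize("Hello, world: it cost 1,000.")
#   # -> ['Hello', ',', 'world', ':', 'it', 'cost', '1,000.']
# The trailing "1,000." stays one token, matching the mteval-v14 behaviour
# described in the docstring above.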
def bleu_wrapper(ref_filename, hyp_filename, case_sensitive=False):
"""Compute BLEU for two files (reference and hypothesis translation)."""
ref_lines = text_encoder.native_to_unicode(
open(ref_filename, "r").read()).splitlines()
hyp_lines = text_encoder.native_to_unicode(
open(hyp_filename, "r").read()).splitlines()
assert len(ref_lines) == len(hyp_lines)
if not case_sensitive:
ref_lines = [x.lower() for x in ref_lines]
hyp_lines = [x.lower() for x in hyp_lines]
ref_tokens = [bleu_tokenize(x) for x in ref_lines]
hyp_tokens = [bleu_tokenize(x) for x in hyp_lines]
return compute_bleu(ref_tokens, hyp_tokens)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser("Calc BLEU.")
parser.add_argument(
"--reference",
type=str,
required=True,
help="path of reference.")
parser.add_argument(
"--translation",
type=str,
required=True,
help="path of translation.")
args = parser.parse_args()
bleu_uncased = 100 * bleu_wrapper(args.reference, args.translation, case_sensitive=False)
bleu_cased = 100 * bleu_wrapper(args.reference, args.translation, case_sensitive=True)
f = open("%s.bleu" % args.translation, 'w')
f.write("BLEU_uncased = %6.2f\n" % (bleu_uncased))
f.write("BLEU_cased = %6.2f\n" % (bleu_cased))
f.close()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
"""
TrainTaskConfig
"""
# support both CPU and GPU now.
use_gpu = True
# the epoch number to train.
pass_num = 30
# the number of sequences contained in a mini-batch.
# deprecated, set batch_size in args.
batch_size = 32
# the hyper parameters for Adam optimizer.
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate = 4.0
beta1 = 0.9
beta2 = 0.997
eps = 1e-9
# the parameters for learning rate scheduling.
warmup_steps = 8000
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
# the directory for saving trained models.
model_dir = "trained_models"
# the directory for saving checkpoints.
ckpt_dir = "trained_ckpts"
# the directory for loading checkpoint.
# If provided, continue training from the checkpoint.
ckpt_path = None
# the parameter to initialize the learning rate scheduler.
# It should be provided if use checkpoints, since the checkpoint doesn't
# include the training step counter currently.
start_step = 0
# the frequency to save trained models.
save_freq = 5000
# the frequency to copy unfixed parameters to fixed parameters
fixed_freq = 50000
beta = 0.7
class InferTaskConfig(object):
"""
InferTaskConfig
"""
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 10
# the parameters for beam search.
beam_size = 5
max_out_len = 256
# the number of decoded sentences to output.
n_best = 1
# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = True
# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"
decode_alpha = 0.6
class ModelHyperParams(object):
"""
ModelHyperParams
"""
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 10000
# size of target word dictionary
trg_vocab_size = 10000
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences deciding the size of position encoding table.
max_length = 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 1024
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 4096
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 16
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# random seed used in dropout for CE.
dropout_seed = None
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = True
embedding_sharing = True
class DenseModelHyperParams(object):
"""
DenseModelHyperParams
"""
# These following five vocabularies related configurations will be set
# automatically according to the passed vocabulary path and special tokens.
# size of source word dictionary.
src_vocab_size = 37007
# size of target word dictionary
trg_vocab_size = 37007
# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2
# max length of sequences deciding the size of position encoding table.
max_length = 256
# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.
d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 2048
# the dimension that keys are projected to for dot-product attention.
d_key = 64
# the dimension that values are projected to for dot-product attention.
d_value = 64
# number of head used in multi-head attention.
n_head = 8
# number of sub-layers to be stacked in the encoder and decoder.
n_layer = 6
enc_n_layer = 25
# dropout rates of different modules.
prepostprocess_dropout = 0.1
attention_dropout = 0.1
relu_dropout = 0.1
# to process before each sub-layer
preprocess_cmd = "n" # layer normalization
# to process after each sub-layer
postprocess_cmd = "da" # dropout + residual connection
# random seed used in dropout for CE.
dropout_seed = None
# the flag indicating whether to share embedding and softmax weights.
# vocabularies in source and target should be same for weight sharing.
weight_sharing = True
embedding_sharing = True
def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except Exception: # for file path
pass
setattr(g_cfg, key, value)
break
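# Usage sketch (illustration, not part of the original file): the extra
# positional "opts" passed by train.sh / infer.sh are flat key-value pairs, e.g.
#   merge_cfg_from_list(["d_model", "1024", "beam_size", "4"],
#                       [ModelHyperParams, InferTaskConfig])
# sets ModelHyperParams.d_model = 1024 and InferTaskConfig.beam_size = 4; values
# are eval()'d, so plain strings such as file paths are kept unchanged.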
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for squence length in compile time.
seq_len = ModelHyperParams.max_length
# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size, max_src_len_in_batch, 1]
"src_word": [(batch_size, seq_len, 1), "int64", 2],
# The actual data shape of src_pos is:
# [batch_size, max_src_len_in_batch, 1]
"src_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
"dense_src_slf_attn_bias": [(batch_size, DenseModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# The actual data shape of trg_word is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
"reverse_trg_word": [(batch_size, seq_len, 1), "int64",
2], # lod_level is only used in fast decoder.
# The actual data shape of trg_pos is:
# [batch_size, max_trg_len_in_batch, 1]
"trg_pos": [(batch_size, seq_len, 1), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
"dense_trg_slf_attn_bias": [(batch_size, DenseModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(batch_size, ModelHyperParams.n_head, seq_len,
seq_len), "float32"],
"dense_trg_src_attn_bias": [(batch_size, DenseModelHyperParams.n_head, seq_len,
seq_len), "float32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(batch_size, seq_len, ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(batch_size * seq_len, 1), "int64"],
"reverse_lbl_word": [(batch_size * seq_len, 1), "int64"],
"eos_position": [(batch_size * seq_len, 1), "int64"],
# This input is used to mask out the loss of padding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(batch_size * seq_len, 1), "float32"],
# This input is used in beam-search decoder.
"init_score": [(batch_size, 1), "float32"],
# This input is used in beam-search decoder for the first gather
# (cell states update)
"init_idx": [(batch_size, ), "int32"],
"decode_length": [(batch_size, ), "int64"],
}
# Names of word embedding table which might be reused for weight sharing.
dense_word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
dense_pos_enc_param_names = (
"dense_src_pos_enc_table",
"dense_trg_pos_enc_table", )
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
"src_word_emb_table",
"trg_word_emb_table", )
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias", )
# separated inputs for different usages.
dense_encoder_data_input_fields = (
"src_word",
"src_pos",
"dense_src_slf_attn_bias", )
decoder_data_input_fields = (
"trg_word",
"reverse_trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
dense_decoder_data_input_fields = (
"trg_word",
"reverse_trg_word",
"trg_pos",
"dense_trg_slf_attn_bias",
"dense_trg_src_attn_bias",
"enc_output", )
label_data_input_fields = (
"lbl_word",
"lbl_weight",
"reverse_lbl_word",
"eos_position")
dense_bias_input_fields = (
"dense_src_slf_attn_bias",
"dense_trg_slf_attn_bias",
"dense_trg_src_attn_bias")
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"dense_src_slf_attn_bias", )
fast_decoder_data_input_fields = (
"decode_length", )
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os, sys
import random
import six
import ast
class TRDataGen(object):
"""record data generator
"""
def __init__(self, num_shards, data_dir):
self.num_shards = num_shards
self.data_dir = data_dir
def gen_data_fnames(self, is_train=True):
"""generate filenames for train and valid
return:
train_filenames, valid_filenames
"""
if not os.path.isdir(self.data_dir):
try:
os.mkdir(self.data_dir)
except Exception as e:
raise ValueError("%s already exists as a file" % self.data_dir)
if is_train:
train_prefix = os.path.join(self.data_dir, "translate-train-%05d-of_unshuffle")
return [train_prefix % i for i in xrange(self.num_shards)]
return [os.path.join(self.data_dir, "translate-dev-00000-of_unshuffle")]
def generate(self, data_list, is_train=True, is_shuffle=True):
"""generating record file
:param data_list:
:param is_train:
:return:
"""
output_filename = self.gen_data_fnames(is_train)
#writers = [tf.python_io.TFRecordWriter(fname) for fname in output_filename]
writers = [open(fname, 'w') for fname in output_filename]
ct = 0
shard = 0
for case in data_list:
ct += 1
if ct % 10000 == 0:
logging.info("Generating case %s ." % ct)
example = self.to_example(case)
writers[shard].write(example.strip() + "\n")
if is_train:
shard = (shard + 1) % self.num_shards
logging.info("Generating case %s ." % ct)
for writer in writers:
writer.close()
if is_shuffle:
self.shuffle_dataset(output_filename)
def to_example(self, dictionary):
"""
:param source:
:param target:
:return:
"""
if "inputs" not in dictionary or "targets" not in dictionary:
raise ValueError("Empty generated field: inputs or target")
inputs = " ".join(str(x) for x in dictionary["inputs"])
targets = " ".join(str(x) for x in dictionary["targets"])
return inputs + "\t" + targets
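# e.g. to_example({"inputs": [3, 7, 1], "targets": [4, 1]}) returns "3 7 1\t4 1",
# one tab-separated source/target id sequence per line of a shard file.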
def shuffle_dataset(self, filenames):
"""
:return:
"""
logging.info("Shuffling data...")
for fname in filenames:
records = self.read_records(fname)
random.shuffle(records)
out_fname = fname.replace("_unshuffle", "-shuffle")
self.write_records(records, out_fname)
os.remove(fname)
def read_records(self, filename):
"""
:param filename:
:return:
"""
records = []
with open(filename, 'r') as reader:
for record in reader:
records.append(record)
if len(records) % 100000 == 0:
logging.info("read: %d", len(records))
return records
def write_records(self, records, out_filename):
"""
:param records:
:param out_filename:
:return:
"""
with open(out_filename, 'w') as f:
for count, record in enumerate(records):
f.write(record)
if count > 0 and count % 100000 == 0:
logging.info("write: %d", count)
if __name__ == "__main__":
from preprocess.problem import SubwordVocabProblem
from preprocess.problem import TokenVocabProblem
import argparse
parser = argparse.ArgumentParser("Tips for generating subword.")
parser.add_argument(
"--tmp_dir",
type=str,
required=True,
help="dir that includes original corpus.")
parser.add_argument(
"--data_dir",
type=str,
required=True,
help="dir that generates training files")
parser.add_argument(
"--source_train_files",
type=str,
required=True,
help="train file for source")
parser.add_argument(
"--target_train_files",
type=str,
required=True,
help="train file for target")
parser.add_argument(
"--source_vocab_size",
type=int,
required=True,
help="source_vocab_size")
parser.add_argument(
"--target_vocab_size",
type=int,
required=True,
help="target_vocab_size")
parser.add_argument(
"--num_shards",
type=int,
default=100,
help="number of shards")
parser.add_argument(
"--subword",
type=ast.literal_eval,
default=False,
help="subword")
parser.add_argument(
"--token",
type=ast.literal_eval,
default=False,
help="token")
parser.add_argument(
"--onevocab",
type=ast.literal_eval,
default=False,
help="share vocab")
args = parser.parse_args()
print args
gen = TRDataGen(args.num_shards, args.data_dir)
source_train_files = args.source_train_files.split(",")
target_train_files = args.target_train_files.split(",")
if args.token == args.subword:
print "one of subword or token is True"
import sys
sys.exit(1)
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT)
if args.subword:
problem = SubwordVocabProblem(args.source_vocab_size,
args.target_vocab_size,
source_train_files,
target_train_files,
None,
None,
args.onevocab)
else:
problem = TokenVocabProblem(args.source_vocab_size,
args.target_vocab_size,
source_train_files,
target_train_files,
None,
None,
args.onevocab)
gen.generate(problem.generate_data(args.data_dir, args.tmp_dir, True), True, True)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
id2word = {}
ln = sys.stdin
def load_vocab(file_path):
start_index = 0
f = open(file_path, 'r')
for line in f:
line = line.strip()
id2word[start_index] = line
start_index += 1
f.close()
if __name__=="__main__":
load_vocab(sys.argv[1])
while True:
line = ln.readline().strip()
if not line:
break
split_res = line.split(" ")
output_str = ""
for item in split_res:
output_str += id2word[int(item.strip())]
output_str += " "
output_str = output_str.strip()
print output_str
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import multiprocessing
import numpy as np
import os
from functools import partial
import contextlib
import time
import paddle.fluid.profiler as profiler
import paddle
import paddle.fluid as fluid
import forward_model
import reader
import sys
from config import *
from forward_model import wrap_encoder as encoder
from forward_model import wrap_decoder as decoder
from forward_model import forward_fast_decode
from dense_model import dense_fast_decode
from relative_model import relative_fast_decode
from forward_model import forward_position_encoding_init
from reader import *
def parse_args():
"""
parse_args
"""
parser = argparse.ArgumentParser("Training for Transformer.")
parser.add_argument(
"--val_file_pattern",
type=str,
required=True,
help="The pattern to match test data files.")
parser.add_argument(
"--batch_size",
type=int,
default=50,
help="The number of examples in one run for sequence generation.")
parser.add_argument(
"--pool_size",
type=int,
default=10000,
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
"--use_mem_opt",
type=ast.literal_eval,
default=True,
help="The flag indicating whether to use memory optimization.")
parser.add_argument(
"--use_py_reader",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use py_reader.")
parser.add_argument(
"--use_parallel_exe",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use ParallelExecutor.")
parser.add_argument(
"--use_candidate",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use candidates.")
parser.add_argument(
"--common_ids",
type=str,
default="",
help="The file path of common ids.")
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
parser.add_argument(
"--use_delay_load",
type=ast.literal_eval,
default=True,
help=
"The flag indicating whether to load all data into memories at once.")
parser.add_argument(
"--vocab_size",
type=str,
required=True,
help="Size of Vocab.")
parser.add_argument(
"--infer_batch_size",
type=int,
help="Infer batch_size")
parser.add_argument(
"--decode_alpha",
type=float,
help="decode_alpha")
args = parser.parse_args()
# Append args related to dict
#src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
#trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
#dict_args = [
# "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
# str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
# "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
# str(src_dict[args.special_token[2]])
#]
voc_size = args.vocab_size
dict_args = [
"src_vocab_size", voc_size,
"trg_vocab_size", voc_size,
"bos_idx", str(0),
"eos_idx", str(1),
"unk_idx", str(int(voc_size) - 1)
]
merge_cfg_from_list(args.opts + dict_args,
[InferTaskConfig, ModelHyperParams])
return args
def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = [
idx for idx in seq[:eos_pos + 1]
if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
]
return seq
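# Usage sketch (illustration, not part of the original file): with the default
# output flags,
#   post_process_seq([0, 5, 7, 1, 9], bos_idx=0, eos_idx=1)  # -> [5, 7]
# the sequence is truncated at the first <eos> and <bos>/<eos> are dropped.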
def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
d_model):
"""
Put all padded data needed by beam search decoder into a dict.
"""
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
source_length = np.asarray([src_max_len], dtype="int64")
src_word = src_word.reshape(-1, src_max_len, 1)
src_pos = src_pos.reshape(-1, src_max_len, 1)
data_input_dict = dict(
zip(data_input_names, [
src_word, src_pos, src_slf_attn_bias, source_length
]))
return data_input_dict
def prepare_feed_dict_list(data_generator, count):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict = prepare_batch_input(
data_buffer, data_input_names, ModelHyperParams.bos_idx,
ModelHyperParams.bos_idx, ModelHyperParams.n_head,
ModelHyperParams.d_model)
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def prepare_dense_feed_dict_list(data_generator, count):
"""
Prepare the list of feed dict for multi-devices.
"""
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = dense_encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
data_input_dict = prepare_batch_input(
data_buffer, data_input_names, DenseModelHyperParams.bos_idx,
DenseModelHyperParams.bos_idx, DenseModelHyperParams.n_head,
DenseModelHyperParams.d_model)
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def prepare_infer_feed_dict_list(data_generator, count):
feed_dict_list = []
if data_generator is not None: # use_py_reader == False
data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields
dense_data_input_names = dense_encoder_data_input_fields + fast_decoder_data_input_fields
data = next(data_generator)
for idx, data_buffer in enumerate(data):
dense_data_input_dict = prepare_batch_input(
data_buffer, dense_data_input_names, DenseModelHyperParams.bos_idx,
DenseModelHyperParams.bos_idx, DenseModelHyperParams.n_head,
DenseModelHyperParams.d_model)
data_input_dict = prepare_batch_input(data_buffer, data_input_names,
ModelHyperParams.bos_idx, ModelHyperParams.bos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
for key in dense_data_input_dict:
if key not in data_input_dict:
data_input_dict[key] = dense_data_input_dict[key]
feed_dict_list.append(data_input_dict)
return feed_dict_list if len(feed_dict_list) == count else None
def get_trans_res(batch_size, out_list, final_list):
"""
Get trans
"""
for index in xrange(batch_size):
seq = out_list[index][0] #top1 seq
if 1 not in seq:
res = seq[1:-1]
else:
res = seq[1:seq.index(1)]
res = map(str, res)
final_list.append(" ".join(res))
def fast_infer(args):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
test_prog = fluid.Program()
startup_prog = fluid.Program()
#with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard("new_forward"):
out_ids1, out_scores1 = forward_fast_decode(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
InferTaskConfig.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
with fluid.unique_name.guard("new_relative_position"):
out_ids2, out_scores2 = relative_fast_decode(
ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size,
ModelHyperParams.max_length + 50,
ModelHyperParams.n_layer,
ModelHyperParams.n_head,
ModelHyperParams.d_key,
ModelHyperParams.d_value,
ModelHyperParams.d_model,
ModelHyperParams.d_inner_hid,
ModelHyperParams.prepostprocess_dropout,
ModelHyperParams.attention_dropout,
ModelHyperParams.relu_dropout,
ModelHyperParams.preprocess_cmd,
ModelHyperParams.postprocess_cmd,
ModelHyperParams.weight_sharing,
ModelHyperParams.embedding_sharing,
InferTaskConfig.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
DenseModelHyperParams.src_vocab_size = ModelHyperParams.src_vocab_size
DenseModelHyperParams.trg_vocab_size = ModelHyperParams.trg_vocab_size
DenseModelHyperParams.weight_sharing = ModelHyperParams.weight_sharing
DenseModelHyperParams.embedding_sharing = ModelHyperParams.embedding_sharing
with fluid.unique_name.guard("new_dense"):
out_ids3, out_scores3 = dense_fast_decode(
DenseModelHyperParams.src_vocab_size,
DenseModelHyperParams.trg_vocab_size,
DenseModelHyperParams.max_length + 50,
DenseModelHyperParams.n_layer,
DenseModelHyperParams.enc_n_layer,
DenseModelHyperParams.n_head,
DenseModelHyperParams.d_key,
DenseModelHyperParams.d_value,
DenseModelHyperParams.d_model,
DenseModelHyperParams.d_inner_hid,
DenseModelHyperParams.prepostprocess_dropout,
DenseModelHyperParams.attention_dropout,
DenseModelHyperParams.relu_dropout,
DenseModelHyperParams.preprocess_cmd,
DenseModelHyperParams.postprocess_cmd,
DenseModelHyperParams.weight_sharing,
DenseModelHyperParams.embedding_sharing,
InferTaskConfig.beam_size,
args.infer_batch_size,
InferTaskConfig.max_out_len,
args.decode_alpha,
ModelHyperParams.eos_idx,
params_type="new"
)
test_prog = fluid.default_main_program().clone(for_test=True)
# This is used here to set dropout to the test mode.
if InferTaskConfig.use_gpu:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_params(
exe,
InferTaskConfig.model_path,
main_program=test_prog)
if args.use_mem_opt:
fluid.memory_optimize(test_prog)
exec_strategy = fluid.ExecutionStrategy()
# For faster executor
exec_strategy.use_experimental_executor = True
exec_strategy.num_threads = 1
build_strategy = fluid.BuildStrategy()
# data reader settings for inference
args.use_token_batch = False
#args.sort_type = reader.SortType.NONE
args.shuffle = False
args.shuffle_batch = False
dev_count = 1
with open(args.val_file_pattern, 'r') as val_file:
    lines_cnt = len(val_file.readlines())
data_reader = line_reader(args.val_file_pattern, args.infer_batch_size, dev_count,
token_delimiter=args.token_delimiter,
max_len=ModelHyperParams.max_length,
parse_line=parse_src_line)
test_data = prepare_data_generator(
args,
is_test=True,
count=dev_count,
pyreader=None,
batch_size=args.infer_batch_size, data_reader=data_reader)
data_generator = test_data()
iter_num = 0
if not os.path.exists("trans"):
os.mkdir("trans")
model_name = InferTaskConfig.model_path.split("/")[-1]
forward_res = open(os.path.join("trans", "forward_%s" % model_name), 'w')
relative_res = open(os.path.join("trans", "relative_%s" % model_name), 'w')
dense_res = open(os.path.join("trans", "dense_%s" % model_name), 'w')
forward_list = []
relative_list = []
dense_list = []
with profile_context(False):
while True:
try:
feed_dict_list = prepare_infer_feed_dict_list(data_generator, dev_count)
forward_seq_ids, relative_seq_ids, dense_seq_ids = exe.run(
program=test_prog,
fetch_list=[out_ids1.name, out_ids2.name, out_ids3.name],
feed=feed_dict_list[0]
if feed_dict_list is not None else None,
return_numpy=False,
use_program_cache=False)
fseq_ids = np.asarray(forward_seq_ids).tolist()
rseq_ids = np.asarray(relative_seq_ids).tolist()
dseq_ids = np.asarray(dense_seq_ids).tolist()
get_trans_res(args.infer_batch_size, fseq_ids, forward_list)
get_trans_res(args.infer_batch_size, rseq_ids, relative_list)
get_trans_res(args.infer_batch_size, dseq_ids, dense_list)
except (StopIteration, fluid.core.EOFException):
break
forward_list = forward_list[:lines_cnt]
relative_list = relative_list[:lines_cnt]
dense_list = dense_list[:lines_cnt]
forward_res.writelines("\n".join(forward_list))
forward_res.flush()
forward_res.close()
relative_res.writelines("\n".join(relative_list))
relative_res.flush()
relative_res.close()
dense_res.writelines("\n".join(dense_list))
dense_res.flush()
dense_res.close()
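# Note: the three decoding graphs built above (forward_fast_decode,
# relative_fast_decode and dense_fast_decode) are the three Transformer
# variants decoded jointly here; each one writes its hypotheses to its own
# file under ./trans.  A minimal usage sketch, assuming config.py and a
# trained checkpoint at InferTaskConfig.model_path are already in place:
#
#     args = parse_args()
#     fast_infer(args)   # results appear in trans/forward_*, trans/relative_*, trans/dense_*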
@contextlib.contextmanager
def profile_context(profile=True):
"""
profile_context
"""
if profile:
with profiler.profiler('All', 'total', './profile_dir/profile_file_tmp'):
yield
else:
yield
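# Usage sketch: profile_context is a no-op with profile=False (as used in
# fast_infer above); passing True wraps the block in the Fluid profiler and
# dumps timings to ./profile_dir/profile_file_tmp, e.g.
#
#     with profile_context(True):
#         exe.run(program=test_prog, fetch_list=[...])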
if __name__ == "__main__":
args = parse_args()
fast_infer(args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import subprocess
import commands
import os
import six
import copy
import argparse
import time
from args import ArgumentGroup, print_arguments, inv_arguments
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
multip_g = ArgumentGroup(parser, "multiprocessing",
"start paddle training using multi-processing mode.")
multip_g.add_arg("node_ips", str, None,
"paddle trainer ips")
multip_g.add_arg("node_id", int, None,
"the trainer id of the node for multi-node distributed training.")
multip_g.add_arg("print_config", bool, True,
"print the config of multi-processing mode.")
multip_g.add_arg("current_node_ip", str, None,
"the ip of current node.")
multip_g.add_arg("split_log_path", str, "log",
"log path for each trainer.")
multip_g.add_arg("log_prefix", str, "",
"the prefix name of job log.")
multip_g.add_arg("nproc_per_node", int, 8,
"the number of process to use on each node.")
multip_g.add_arg("selected_gpus", str, "0,1,2,3,4,5,6,7",
"the gpus selected to use.")
multip_g.add_arg("training_script", str, None, "the program/script to be lauched "
"in parallel followed by all the arguments", positional_arg=True)
multip_g.add_arg("training_script_args", str, None,
"training script args", positional_arg=True, nargs=argparse.REMAINDER)
# yapf: enable
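# Example invocation of this launcher (IPs, node id and the trailing script
# arguments are illustrative; train.sh builds the real command from the
# cluster environment variables):
#
#     python -u src/launch.py \
#         --node_ips 10.1.1.1,10.1.1.2 --node_id 0 --current_node_ip 10.1.1.1 \
#         --nproc_per_node 8 --selected_gpus 0,1,2,3,4,5,6,7 \
#         src/train.py --batch_size 1250 ...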
def start_procs(args):
"""
start_procs
"""
procs = []
log_fns = []
default_env = os.environ.copy()
node_id = args.node_id
node_ips = [x.strip() for x in args.node_ips.split(',')]
current_ip = args.current_node_ip
num_nodes = len(node_ips)
selected_gpus = [x.strip() for x in args.selected_gpus.split(',')]
selected_gpu_num = len(selected_gpus)
all_trainer_endpoints = ""
for ip in node_ips:
for i in range(args.nproc_per_node):
if all_trainer_endpoints != "":
all_trainer_endpoints += ","
all_trainer_endpoints += "%s:617%d" % (ip, i)
nranks = num_nodes * args.nproc_per_node
# number of GPUs bound to each trainer process (floor division keeps the
# result an integer under both Python 2 and 3)
if selected_gpu_num % args.nproc_per_node == 0:
    gpus_per_proc = selected_gpu_num // args.nproc_per_node
else:
    gpus_per_proc = selected_gpu_num // args.nproc_per_node + 1
selected_gpus_per_proc = [selected_gpus[i:i + gpus_per_proc]
for i in range(0, len(selected_gpus), gpus_per_proc)]
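# Worked example: with nproc_per_node=8 and selected_gpus="0,1,2,3,4,5,6,7"
# this yields gpus_per_proc=1 and selected_gpus_per_proc=[['0'], ['1'], ...,
# ['7']], i.e. one GPU per trainer process.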
if args.print_config:
print("all_trainer_endpoints: ", all_trainer_endpoints,
", node_id: ", node_id,
", current_ip: ", current_ip,
", num_nodes: ", num_nodes,
", node_ips: ", node_ips,
", gpus_per_proc: ", gpus_per_proc,
", selected_gpus_per_proc: ", selected_gpus_per_proc,
", nranks: ", nranks)
current_env = copy.copy(default_env)
procs = []
cmds = []
log_fns = []
for i in range(0, args.nproc_per_node):
trainer_id = node_id * args.nproc_per_node + i
current_env.update({
"FLAGS_selected_gpus": "%s" % ",".join([str(s) for s in selected_gpus_per_proc[i]]),
"PADDLE_TRAINER_ID": "%d" % trainer_id,
"PADDLE_CURRENT_ENDPOINT": "%s:617%d" % (current_ip, i),
"PADDLE_TRAINERS_NUM": "%d" % nranks,
"PADDLE_TRAINER_ENDPOINTS": all_trainer_endpoints,
"PADDLE_NODES_NUM": "%d" % num_nodes
})
cmd = [sys.executable, "-u",
args.training_script] + args.training_script_args
cmds.append(cmd)
if args.split_log_path:
fn = open("%s/%sjob.log.%d" % (args.split_log_path, args.log_prefix, trainer_id), "a")
log_fns.append(fn)
process = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
else:
process = subprocess.Popen(cmd, env=current_env)
procs.append(process)
for i in range(len(procs)):
proc = procs[i]
proc.wait()
if len(log_fns) > 0:
log_fns[i].close()
if proc.returncode != 0:
raise subprocess.CalledProcessError(returncode=procs[i].returncode,
cmd=cmds[i])
else:
print("proc %d finsh" % i)
def main(args):
"""
main_func
"""
if args.print_config:
print_arguments(args)
start_procs(args)
if __name__ == "__main__":
lanch_args = parser.parse_args()
main(lanch_args)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import argparse
from text_encoder import SubwordTextEncoder, TokenTextEncoder
from text_encoder import EOS_ID
def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
sources, file_byte_budget=1e6):
"""Generate a vocabulary from the datasets in sources."""
def generate():
"""Generate lines for vocabulary generation."""
logging.info("Generating vocab from: %s", str(sources))
for source in sources:
for lang_file in source[1]:
logging.info("Reading file: %s" % lang_file)
filepath = os.path.join(tmp_dir, lang_file)
with open(filepath, mode="r") as source_file:
file_byte_budget_ = file_byte_budget
counter = 0
countermax = int(os.path.getsize(filepath) / file_byte_budget_ / 2)
logging.info("countermax: %d" % countermax)
for line in source_file:
if counter < countermax:
counter += 1
else:
if file_byte_budget_ <= 0:
break
line = line.strip()
file_byte_budget_ -= len(line)
counter = 0
yield line
return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generate())
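# Hedged usage sketch (file and directory names are illustrative): "sources"
# is a list of [prefix, [filenames]] pairs whose files live under tmp_dir;
# the resulting SubwordTextEncoder is cached under data_dir/vocab_filename
# and reloaded on subsequent calls.
#
#     vocab = get_or_generate_vocab("train_data", "raw_data", "vocab.37007",
#                                   37007, [["", ["train.src", "train.tgt"]]],
#                                   file_byte_budget=1e8)
#     ids = vocab.encode("an example sentence")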
def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generator, max_subtoken_length=None,
reserved_tokens=None):
"""Inner implementation for vocab generators.
Args:
data_dir: The base directory where data and vocab files are stored. If None,
then do not save the vocab even if it doesn't exist.
vocab_filename: relative filename where vocab file is stored
vocab_size: target size of the vocabulary constructed by SubwordTextEncoder
generator: a generator that produces tokens from the vocabulary
max_subtoken_length: an optional integer. Set this to a finite value to
avoid quadratic costs during vocab building.
reserved_tokens: List of reserved tokens. `text_encoder.RESERVED_TOKENS`
should be a prefix of `reserved_tokens`. If `None`, defaults to
`RESERVED_TOKENS`.
Returns:
A SubwordTextEncoder vocabulary object.
"""
if data_dir and vocab_filename:
vocab_filepath = os.path.join(data_dir, vocab_filename)
if os.path.exists(vocab_filepath):
logging.info("Found vocab file: %s", vocab_filepath)
return SubwordTextEncoder(vocab_filepath)
else:
vocab_filepath = None
logging.info("Generating vocab file: %s", vocab_filepath)
vocab = SubwordTextEncoder.build_from_generator(
generator, vocab_size, max_subtoken_length=max_subtoken_length,
reserved_tokens=reserved_tokens)
if vocab_filepath:
if not os.path.exists(data_dir):
os.makedirs(data_dir)
vocab.store_to_file(vocab_filepath)
return vocab
def txt_line_iterator(fname):
"""
generator for line
:param fname:
:return:
"""
with open(fname, 'r') as f:
for line in f:
yield line.strip()
def txt2txt_generator(source_fname, target_fname):
"""
:param source_fname:
:param target_fname:
:return:
"""
for source, target in zip(
txt_line_iterator(source_fname),
txt_line_iterator(target_fname)
):
yield {"inputs": source, "targets": target}
def txt2txt_encoder(sample_generator, vocab, target_vocab=None):
"""
:param sample_generator:
:param vocab:
:param target_vocab:
:return:
"""
target_vocab = target_vocab or vocab
for sample in sample_generator:
sample["inputs"] = vocab.encode(sample["inputs"])
sample["inputs"].append(EOS_ID)
sample["targets"] = target_vocab.encode(sample["targets"])
sample["targets"].append(EOS_ID)
yield sample
def txt_encoder(filename, batch_size=1, vocab=None):
"""
Encode a text file into padded mini-batches of token ids.
:param filename: path of the input text file
:param batch_size: number of samples per mini-batch
:param vocab: optional encoder; if None, each line is assumed to hold space-separated ids
:return: a flat list of padded samples
"""
def pad_mini_batch(batch):
"""
:param batch:
:return:
"""
lens = [len(x) for x in batch]
max_len = max(lens)
for i in range(len(batch)):
batch[i] = batch[i] + [0] * (max_len - lens[i])
return batch
fp = open(filename, 'r')
samples = []
batches = []
ct = 0
for sample in fp:
sample = sample.strip()
if vocab:
sample = vocab.encode(sample)
else:
sample = [int(s) for s in sample.split()]
#sample.append(EOS_ID)
batches.append(sample)
ct += 1
if ct % batch_size == 0:
batches = pad_mini_batch(batches)
samples.extend(batches)
batches = []
if ct % batch_size != 0:
batches += [batches[-1]] * (batch_size - ct % batch_size)
batches = pad_mini_batch(batches)
samples.extend(batches)
return samples
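# Example (vocab path is illustrative): each line of `filename` is encoded,
# batches are right-padded with 0 up to the longest sample in the batch, and
# the last partial batch is filled by repeating its final sample.
#
#     token_vocab = TokenTextEncoder("train_data/vocab.source.37007")
#     samples = txt_encoder("testset/test.src", batch_size=8, vocab=token_vocab)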
if __name__ == "__main__":
parser = argparse.ArgumentParser("Tips for generating testset")
parser.add_argument(
"--vocab",
type=str,
required=True,
help="The path of source vocab.")
parser.add_argument(
"--testset",
type=str,
required=True,
help="The path of testset.")
parser.add_argument(
"--output",
type=str,
required=True,
help="The path of result.")
args = parser.parse_args()
token = TokenTextEncoder(args.vocab)
samples = txt_encoder(args.testset, 1, token)
with open(args.output, 'w') as f:
for sample in samples:
res = [str(item) for item in sample]
f.write("%s\n" % " ".join(res))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from gen_utils import get_or_generate_vocab
from gen_utils import txt_line_iterator
import os, sys
from gen_utils import txt2txt_encoder
from gen_utils import txt2txt_generator
from text_encoder import TokenTextEncoder
import logging
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=LOG_FORMAT)
class GenSubword(object):
"""
gen subword
"""
def __init__(self,
vocab_size=8000,
training_dataset_filenames="train.txt"):
"""
:param vocab_size:
:param vocab_name:
:param training_dataset_filenames: list
"""
self.vocab_size = vocab_size
self.vocab_name = "vocab.%s" % self.vocab_size
if not isinstance(training_dataset_filenames, list):
training_dataset_filenames = [training_dataset_filenames]
self.training_dataset_filenames = training_dataset_filenames
def generate_data(self, data_dir, tmp_dir):
"""
:param data_dir: target dir(includes vocab file)
:param tmp_dir: original dir(includes training dataset filenames)
:return:
"""
data_set = [["", self.training_dataset_filenames]]
source_vocab = get_or_generate_vocab(
data_dir,
tmp_dir,
self.vocab_name,
self.vocab_size,
data_set,
file_byte_budget=1e8)
source_vocab.store_to_file(os.path.join(data_dir, self.vocab_name))
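# Minimal sketch (directory names are illustrative): build an 8k shared
# subword vocabulary from ./raw/train.txt and store it as
# ./train_data/vocab.8000:
#
#     GenSubword(vocab_size=8000,
#                training_dataset_filenames="train.txt").generate_data("train_data", "raw")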
class SubwordVocabProblem(object):
"""subword input"""
def __init__(self,
source_vocab_size=8000,
target_vocab_size=8000,
source_train_filenames="train.src",
target_train_filenames="train.tgt",
source_dev_filenames="dev.src",
target_dev_filenames="dev.tgt",
one_vocab=False):
"""
:param source_vocab_size:
:param target_vocab_size:
:param source_train_filenames:
:param target_train_filenames:
:param source_dev_filenames:
:param target_dev_filenames:
"""
self.source_vocab_size = source_vocab_size
self.target_vocab_size = target_vocab_size
self.source_vocab_name = "vocab.source.%s" % self.source_vocab_size
self.target_vocab_name = "vocab.target.%s" % self.target_vocab_size
if not isinstance(source_train_filenames, list):
source_train_filenames = [source_train_filenames]
if not isinstance(target_train_filenames, list):
target_train_filenames = [target_train_filenames]
if not isinstance(source_dev_filenames, list):
source_dev_filenames = [source_dev_filenames]
if not isinstance(target_dev_filenames, list):
target_dev_filenames = [target_dev_filenames]
self.source_train_filenames = source_train_filenames
self.target_train_filenames = target_train_filenames
self.source_dev_filenames = source_dev_filenames
self.target_dev_filenames = target_dev_filenames
self.one_vocab = one_vocab
def generate_data(self, data_dir, tmp_dir, is_train=True):
"""
:param data_dir:
:param tmp_dir:
:return:
"""
self.source_train_ds = [["", self.source_train_filenames]]
self.target_train_ds = [["", self.target_train_filenames]]
logging.info("building source vocab ...")
logging.info(self.one_vocab)
if not self.one_vocab:
source_vocab = get_or_generate_vocab(data_dir, tmp_dir,
self.source_vocab_name,
self.source_vocab_size,
self.source_train_ds,
file_byte_budget=1e8)
logging.info("building target vocab ...")
target_vocab = get_or_generate_vocab(data_dir, tmp_dir,
self.target_vocab_name,
self.target_vocab_size,
self.target_train_ds,
file_byte_budget=1e8)
else:
train_ds = [["", self.source_train_filenames + self.target_train_filenames]]
source_vocab = get_or_generate_vocab(data_dir, tmp_dir,
self.source_vocab_name,
self.source_vocab_size,
train_ds,
file_byte_budget=1e8)
target_vocab = source_vocab
target_vocab.store_to_file(os.path.join(data_dir, self.target_vocab_name))
pair_filenames = [self.source_train_filenames, self.target_train_filenames]
if not is_train:
pair_filenames = [self.source_dev_filenames, self.target_dev_filenames]
self.compile_data(tmp_dir, pair_filenames, is_train)
source_fname = "train.lang1" if is_train else "dev.lang1"
target_fname = "train.lang2" if is_train else "dev.lang2"
source_fname = os.path.join(tmp_dir, source_fname)
target_fname = os.path.join(tmp_dir, target_fname)
return txt2txt_encoder(txt2txt_generator(source_fname, target_fname),
source_vocab,
target_vocab)
def compile_data(self, tmp_dir, pair_filenames, is_train=True):
"""
combine the input files
:param tmp_dir:
:param pair_filenames:
:param is_train:
:return:
"""
filename = "train.lang1" if is_train else "dev.lang1"
out_file_1 = open(os.path.join(tmp_dir, filename), "w")
filename = "train.lang2" if is_train else "dev.lang2"
out_file_2 = open(os.path.join(tmp_dir, filename), "w")
for file1, file2 in zip(pair_filenames[0], pair_filenames[1]):
for line in txt_line_iterator(os.path.join(tmp_dir, file1)):
out_file_1.write(line + "\n")
for line in txt_line_iterator(os.path.join(tmp_dir, file2)):
out_file_2.write(line + "\n")
out_file_2.close()
out_file_1.close()
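# Minimal sketch (sizes and paths are illustrative): generate_data builds or
# reloads the subword vocabs from tmp_dir and returns a generator of samples
# of the form {"inputs": [...ids..., EOS_ID], "targets": [...ids..., EOS_ID]}:
#
#     problem = SubwordVocabProblem(37007, 37007, "train.src", "train.tgt",
#                                   "dev.src", "dev.tgt", one_vocab=True)
#     for sample in problem.generate_data("train_data", "raw", is_train=True):
#         pass  # consume the encoded samples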
class TokenVocabProblem(object):
"""token input"""
def __init__(self,
source_vocab_size=8000,
target_vocab_size=8000,
source_train_filenames="train.src",
target_train_filenames="train.tgt",
source_dev_filenames="dev.src",
target_dev_filenames="dev.tgt",
one_vocab=False):
"""
:param source_vocab_size:
:param target_vocab_size:
:param source_train_filenames:
:param target_train_filenames:
:param source_dev_filenames:
:param target_dev_filenames:
"""
self.source_vocab_size = source_vocab_size
self.target_vocab_size = target_vocab_size
self.source_vocab_name = "vocab.source.%s" % self.source_vocab_size
self.target_vocab_name = "vocab.target.%s" % self.target_vocab_size
if not isinstance(source_train_filenames, list):
source_train_filenames = [source_train_filenames]
if not isinstance(target_train_filenames, list):
target_train_filenames = [target_train_filenames]
if not isinstance(source_dev_filenames, list):
source_dev_filenames = [source_dev_filenames]
if not isinstance(target_dev_filenames, list):
target_dev_filenames = [target_dev_filenames]
self.source_train_filenames = source_train_filenames
self.target_train_filenames = target_train_filenames
self.source_dev_filenames = source_dev_filenames
self.target_dev_filenames = target_dev_filenames
self.one_vocab = one_vocab
def add_exsits_vocab(self, filename):
"""
Load an existing vocab file (one token per line) and append "UNK".
:param filename: path of the vocab file
:return: list of tokens
"""
token_list = []
with open(filename) as f:
for line in f:
line = line.strip()
token_list.append(line)
token_list.append("UNK")
return token_list
def generate_data(self, data_dir, tmp_dir, is_train=True):
"""
:param data_dir:
:param tmp_dir:
:return:
"""
self.source_train_ds = [["", self.source_train_filenames]]
self.target_train_ds = [["", self.target_train_filenames]]
pair_filenames = [self.source_train_filenames, self.target_train_filenames]
if not is_train:
pair_filenames = [self.source_dev_filenames, self.target_dev_filenames]
self.compile_data(tmp_dir, pair_filenames, is_train)
source_fname = "train.lang1" if is_train else "dev.lang1"
target_fname = "train.lang2" if is_train else "dev.lang2"
source_fname = os.path.join(tmp_dir, source_fname)
target_fname = os.path.join(tmp_dir, target_fname)
if is_train:
source_vocab_path = os.path.join(data_dir, self.source_vocab_name)
target_vocab_path = os.path.join(data_dir, self.target_vocab_name)
if not self.one_vocab:
if os.path.exists(source_vocab_path) and os.path.exists(target_vocab_path):
logging.info("found source vocab ...")
source_vocab = TokenTextEncoder(None, vocab_list=self.add_exsits_vocab(source_vocab_path))
logging.info("found target vocab ...")
target_vocab = TokenTextEncoder(None, vocab_list=self.add_exsits_vocab(target_vocab_path))
else:
logging.info("building source vocab ...")
source_vocab = TokenTextEncoder.build_from_corpus(source_fname,
self.source_vocab_size)
if not os.path.exists(data_dir):
    os.makedirs(data_dir)
logging.info("building target vocab ...")
target_vocab = TokenTextEncoder.build_from_corpus(target_fname,
self.target_vocab_size)
else:
if os.path.exists(source_vocab_path):
logging.info("found source vocab ...")
source_vocab = TokenTextEncoder(None, vocab_list=self.add_exsits_vocab(source_vocab_path))
else:
source_vocab = TokenTextEncoder.build_from_corpus([source_fname, target_fname],
self.source_vocab_size)
logging.info("building target vocab ...")
target_vocab = source_vocab
source_vocab.store_to_file(source_vocab_path)
target_vocab.store_to_file(target_vocab_path)
else:
source_vocab = TokenTextEncoder(os.path.join(data_dir, self.source_vocab_name))
target_vocab = TokenTextEncoder(os.path.join(data_dir, self.target_vocab_name))
return txt2txt_encoder(txt2txt_generator(source_fname, target_fname),
source_vocab,
target_vocab)
def compile_data(self, tmp_dir, pair_filenames, is_train=True):
"""
combine the input files
:param tmp_dir:
:param pair_filenames:
:param is_train:
:return:
"""
filename = "train.lang1" if is_train else "dev.lang1"
out_file_1 = open(os.path.join(tmp_dir, filename), "w")
filename = "train.lang2" if is_train else "dev.lang2"
out_file_2 = open(os.path.join(tmp_dir, filename), "w")
for file1, file2 in zip(pair_filenames[0], pair_filenames[1]):
for line in txt_line_iterator(os.path.join(tmp_dir, file1)):
out_file_1.write(line + "\n")
for line in txt_line_iterator(os.path.join(tmp_dir, file2)):
out_file_2.write(line + "\n")
out_file_2.close()
out_file_1.close()
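# TokenVocabProblem mirrors SubwordVocabProblem but works on whole-word
# tokens via TokenTextEncoder: it reuses an existing vocab file when present
# (appending "UNK"), otherwise it builds one from the training corpus.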
if __name__ == "__main__":
gen_sub = GenSubword().generate_data("train_data", "../asr/")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import argparse
from text_encoder import SubwordTextEncoder
from text_encoder import EOS_ID
def get_or_generate_vocab(data_dir, tmp_dir, vocab_filename, vocab_size,
sources, file_byte_budget=1e6):
"""Generate a vocabulary from the datasets in sources."""
def generate():
"""Generate lines for vocabulary generation."""
logging.info("Generating vocab from: %s", str(sources))
for source in sources:
for lang_file in source[1]:
logging.info("Reading file: %s" % lang_file)
filepath = os.path.join(tmp_dir, lang_file)
with open(filepath, mode="r") as source_file:
file_byte_budget_ = file_byte_budget
counter = 0
countermax = int(os.path.getsize(filepath) / file_byte_budget_ / 2)
logging.info("countermax: %d" % countermax)
for line in source_file:
if counter < countermax:
counter += 1
else:
if file_byte_budget_ <= 0:
break
line = line.strip()
file_byte_budget_ -= len(line)
counter = 0
yield line
return get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generate())
def get_or_generate_vocab_inner(data_dir, vocab_filename, vocab_size,
generator, max_subtoken_length=None,
reserved_tokens=None):
"""Inner implementation for vocab generators.
Args:
data_dir: The base directory where data and vocab files are stored. If None,
then do not save the vocab even if it doesn't exist.
vocab_filename: relative filename where vocab file is stored
vocab_size: target size of the vocabulary constructed by SubwordTextEncoder
generator: a generator that produces tokens from the vocabulary
max_subtoken_length: an optional integer. Set this to a finite value to
avoid quadratic costs during vocab building.
reserved_tokens: List of reserved tokens. `text_encoder.RESERVED_TOKENS`
should be a prefix of `reserved_tokens`. If `None`, defaults to
`RESERVED_TOKENS`.
Returns:
A SubwordTextEncoder vocabulary object.
"""
if data_dir and vocab_filename:
vocab_filepath = os.path.join(data_dir, vocab_filename)
if os.path.exists(vocab_filepath):
logging.info("Found vocab file: %s", vocab_filepath)
return SubwordTextEncoder(vocab_filepath)
else:
vocab_filepath = None
logging.info("Generating vocab file: %s", vocab_filepath)
vocab = SubwordTextEncoder.build_from_generator(
generator, vocab_size, max_subtoken_length=max_subtoken_length,
reserved_tokens=reserved_tokens)
if vocab_filepath:
if not os.path.exists(data_dir):
os.makedirs(data_dir)
vocab.store_to_file(vocab_filepath)
return vocab
def txt_line_iterator(fname):
"""
generator for line
:param fname:
:return:
"""
with open(fname, 'r') as f:
for line in f:
yield line.strip()
def txt2txt_generator(source_fname, target_fname):
"""
:param source_fname:
:param target_fname:
:return:
"""
for source, target in zip(
txt_line_iterator(source_fname),
txt_line_iterator(target_fname)
):
yield {"inputs": source, "targets": target}
def txt2txt_encoder(sample_generator, vocab, target_vocab=None):
"""
:param sample_generator:
:param vocab:
:param target_vocab:
:return:
"""
target_vocab = target_vocab or vocab
for sample in sample_generator:
sample["inputs"] = vocab.encode(sample["inputs"])
sample["inputs"].append(EOS_ID)
sample["targets"] = target_vocab.encode(sample["targets"])
sample["targets"].append(EOS_ID)
yield sample
def txt_encoder(filename, batch_size=1, vocab=None):
"""
Encode a text file into padded mini-batches of token ids.
:param filename: path of the input text file
:param batch_size: number of samples per mini-batch
:param vocab: optional encoder; if None, each line is assumed to hold space-separated ids
:return: a flat list of padded samples
"""
def pad_mini_batch(batch):
"""
:param batch:
:return:
"""
lens = [len(x) for x in batch]
max_len = max(lens)
for i in range(len(batch)):
batch[i] = batch[i] + [0] * (max_len - lens[i])
return batch
fp = open(filename, 'r')
samples = []
batches = []
ct = 0
for sample in fp:
sample = sample.strip()
if vocab:
sample = vocab.encode(sample)
else:
sample = [int(s) for s in sample.split()]
#sample.append(EOS_ID)
batches.append(sample)
ct += 1
if ct % batch_size == 0:
batches = pad_mini_batch(batches)
samples.extend(batches)
batches = []
if ct % batch_size != 0:
batches += [batches[-1]] * (batch_size - ct % batch_size)
batches = pad_mini_batch(batches)
samples.extend(batches)
return samples
if __name__ == "__main__":
parser = argparse.ArgumentParser("Tips for generating testset")
parser.add_argument(
"--vocab",
type=str,
required=True,
help="The path of source vocab.")
parser.add_argument(
"--input",
type=str,
required=True,
help="The path of the file containing subword id sequences to decode.")
parser.add_argument(
"--output",
type=str,
required=True,
help="The path of result.")
args = parser.parse_args()
subword = SubwordTextEncoder(args.vocab)
samples = []
with open(args.input, 'r') as f:
for line in f:
line = line.strip()
ids_list = [int(num) for num in line.split(" ")]
samples.append(ids_list)
with open(args.output, 'w') as f:
for sample in samples:
ret = subword.decode(sample)
f.write("%s\n" % ret)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import sys
import glob
import unicodedata
import six
import logging
from six.moves import range # pylint: disable=redefined-builtin
# Conversion between Unicode and UTF-8, if required (on Python2)
_native_to_unicode = (lambda s: s.decode("utf-8")) if six.PY2 else (lambda s: s)
# This set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = set(
six.unichr(i) for i in range(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
def encode(text):
"""Encode a unicode string as a list of tokens.
Args:
text: a unicode string
Returns:
a list of tokens as Unicode strings
"""
if not text:
return []
ret = []
token_start = 0
# Classify each character in the input string
is_alnum = [c in _ALPHANUMERIC_CHAR_SET for c in text]
for pos in range(1, len(text)):
if is_alnum[pos] != is_alnum[pos - 1]:
token = text[token_start:pos]
if token != u" " or token_start == 0:
ret.append(token)
token_start = pos
final_token = text[token_start:]
ret.append(final_token)
return ret
def decode(tokens):
"""Decode a list of tokens to a unicode string.
Args:
tokens: a list of Unicode strings
Returns:
a unicode string
"""
token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
ret.append(u" ")
ret.append(token)
return "".join(ret)
def _read_filepattern(filepattern, max_lines=None, split_on_newlines=True):
"""Reads files matching a wildcard pattern, yielding the contents.
Args:
filepattern: A wildcard pattern matching one or more files.
max_lines: If set, stop reading after reading this many lines.
split_on_newlines: A boolean. If true, then split files by lines and strip
leading and trailing whitespace from each line. Otherwise, treat each
file as a single string.
Yields:
The contents of the files as lines, if split_on_newlines is True, or
the entire contents of each file if False.
"""
filenames = sorted(glob.glob(filepattern))
lines_read = 0
for filename in filenames:
with open(filename, 'r') as f:
if split_on_newlines:
for line in f:
yield line.strip()
lines_read += 1
if max_lines and lines_read >= max_lines:
return
else:
if max_lines:
doc = []
for line in f:
doc.append(line)
lines_read += 1
if max_lines and lines_read >= max_lines:
yield "".join(doc)
return
yield "".join(doc)
else:
yield f.read()
def corpus_token_counts(
text_filepattern, corpus_max_lines, split_on_newlines=True):
"""Read the corpus and compute a dictionary of token counts.
Args:
text_filepattern: A pattern matching one or more files.
corpus_max_lines: An integer; maximum total lines to read.
split_on_newlines: A boolean. If true, then split files by lines and strip
leading and trailing whitespace from each line. Otherwise, treat each
file as a single string.
Returns:
a dictionary mapping token to count.
"""
counts = collections.Counter()
for doc in _read_filepattern(
text_filepattern,
max_lines=corpus_max_lines,
split_on_newlines=split_on_newlines):
counts.update(encode(_native_to_unicode(doc)))
return counts
def vocab_token_counts(text_filepattern, max_lines):
"""Read a vocab file and return a dictionary of token counts.
Reads a two-column CSV file of tokens and their frequency in a dataset. The
tokens are presumed to be generated by encode() or the equivalent.
Args:
text_filepattern: A pattern matching one or more files.
max_lines: An integer; maximum total lines to read.
Returns:
a dictionary mapping token to count.
"""
ret = {}
for i, line in enumerate(
_read_filepattern(text_filepattern, max_lines=max_lines)):
if "," not in line:
logging.warning("Malformed vocab line #%d '%s'", i, line)
continue
token, count = line.rsplit(",", 1)
ret[_native_to_unicode(token)] = int(count)
return ret
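# Example (pattern is illustrative): count token frequencies over a corpus
# and inspect the most common tokens:
#
#     counts = corpus_token_counts("raw/train.*", corpus_max_lines=1000000)
#     print(counts.most_common(10))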
#!/bin/bash
source ./env/env.sh
source ./env/utils.sh
source ./env/cloud_job_conf.conf
iplist=$1
#iplist=`echo $nodelist | xargs | sed 's/ /,/g'`
if [ ! -d log ]
then
mkdir log
fi
export GLOG_vmodule=fuse_all_reduce_op_pass=10,alloc_continuous_space_for_grad_pass=10
if [[ ${FUSE} == "1" ]]; then
export FLAGS_fuse_parameter_memory_size=64 #MB
fi
set -ux
check_iplist
distributed_args=""
if [[ ${NUM_CARDS} == "1" ]]; then
distributed_args="--selected_gpus 0"
fi
node_ips=${PADDLE_TRAINERS}
distributed_args="--node_ips ${PADDLE_TRAINERS} --node_id ${PADDLE_TRAINER_ID} --current_node_ip ${POD_IP} --nproc_per_node 8 --selected_gpus 0,1,2,3,4,5,6,7"
export FLAGS_eager_delete_tensor_gb=0.0
export FLAGS_fraction_of_gpu_memory_to_use=0.1
export NCCL_IB_GID_INDEX=3
export NCCL_IB_RETRY_CNT=10
export FLAGS_sync_nccl_allreduce=0
BATCH_SIZE=1250
python -u ./src/launch.py ${distributed_args} \
./src/train.py \
--src_vocab_size 37007 \
--tgt_vocab_size 37007 \
--train_file_pattern 'data/translate-train-*' \
--token_delimiter ' ' \
--batch_size ${BATCH_SIZE} \
--use_py_reader True \
--use_delay_load True \
--nccl_comm_num ${NCCL_COMM_NUM} \
--use_hierarchical_allreduce ${USE_HIERARCHICAL_ALLREDUCE} \
--fetch_steps 50 \
--fuse ${FUSE} \
--val_file_pattern 'testset/testfile' \
--infer_batch_size 32 \
--decode_alpha 0.3 \
--beam_size 4 \
--use_fp16 True \
learning_rate 2.0 \
warmup_steps 8000 \
beta2 0.997 \
d_model 1024 \
d_inner_hid 4096 \
n_head 16 \
prepostprocess_dropout 0.3 \
attention_dropout 0.1 \
relu_dropout 0.1 \
embedding_sharing True \
pass_num 100 \
max_length 256 \
save_freq 5000 \
model_dir 'output'