未验证 提交 e93f9d5d 编写于 作者: P pkpk 提交者: GitHub

Merge pull request #29 from guoshengCS/fix-transformer

Update transformer
...@@ -521,7 +521,15 @@ class DynamicDecode(Layer): ...@@ -521,7 +521,15 @@ class DynamicDecode(Layer):
(step_outputs, next_states, next_inputs, (step_outputs, next_states, next_inputs,
next_finished) = self.decoder.step(step_idx_tensor, inputs, next_finished) = self.decoder.step(step_idx_tensor, inputs,
states, **kwargs) states, **kwargs)
next_finished = layers.logical_or(next_finished, finished) if not self.decoder.tracks_own_finished:
# BeamSearchDecoder would track it own finished, since
# beams would be reordered and the finished status of each
# entry might change. Otherwise, perform logical OR which
# would not change the already finished.
next_finished = layers.logical_or(next_finished, finished)
# To confirm states.finished/finished be consistent with
# next_finished.
layers.assign(next_finished, finished)
next_sequence_lengths = layers.elementwise_add( next_sequence_lengths = layers.elementwise_add(
sequence_lengths, sequence_lengths,
layers.cast( layers.cast(
......
...@@ -34,8 +34,8 @@ ...@@ -34,8 +34,8 @@
克隆代码库到本地 克隆代码库到本地
```shell ```shell
git clone https://github.com/PaddlePaddle/models.git git clone https://github.com/PaddlePaddle/hapi
cd models/dygraph/transformer cd hapi/transformer
``` ```
3. 环境依赖 3. 环境依赖
...@@ -62,7 +62,7 @@ ...@@ -62,7 +62,7 @@
### 单机训练 ### 单机训练
### 单机单卡 #### 单机单卡
以提供的英德翻译数据为例,可以执行以下命令进行模型训练: 以提供的英德翻译数据为例,可以执行以下命令进行模型训练:
...@@ -100,28 +100,24 @@ python -u train.py \ ...@@ -100,28 +100,24 @@ python -u train.py \
--prepostprocess_dropout 0.3 --prepostprocess_dropout 0.3
``` ```
另外,如果在执行训练时若提供了 `save_model`(默认为 trained_models),则每隔一定 iteration 后(通过参数 `save_step` 设置,默认为10000)将保存当前训练的到相应目录(会保存分别记录了模型参数和优化器状态的 `transformer.pdparams``transformer.pdopt` 两个文件),每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出: 另外,如果在执行训练时若提供了 `save_model`(默认为 trained_models),则每个 epoch 将保存当前训练的到相应目录(会保存分别记录了模型参数和优化器状态的 `epoch_id.pdparams``epoch_id.pdopt` 两个文件),每隔一定数目的 iteration (通过参数 `print_step` 设置,默认为100)将打印如下的日志到标准输出:
```txt ```txt
[2019-08-02 15:30:51,656 INFO train.py:262] step_idx: 150100, epoch: 32, batch: 1364, avg loss: 2.880427, normalized loss: 1.504687, ppl: 17.821888, speed: 3.34 step/s step 100/1 - loss: 9.165776 - normalized loss: 7.790036 - ppl: 9564.142578 - 247ms/step
[2019-08-02 15:31:19,824 INFO train.py:262] step_idx: 150200, epoch: 32, batch: 1464, avg loss: 2.955965, normalized loss: 1.580225, ppl: 19.220257, speed: 3.55 step/s step 200/1 - loss: 8.037900 - normalized loss: 6.662160 - ppl: 3096.104492 - 227ms/step
[2019-08-02 15:31:48,151 INFO train.py:262] step_idx: 150300, epoch: 32, batch: 1564, avg loss: 2.951180, normalized loss: 1.575439, ppl: 19.128502, speed: 3.53 step/s step 300/1 - loss: 7.668307 - normalized loss: 6.292567 - ppl: 2139.457031 - 221ms/step
[2019-08-02 15:32:16,401 INFO train.py:262] step_idx: 150400, epoch: 32, batch: 1664, avg loss: 3.027281, normalized loss: 1.651540, ppl: 20.641024, speed: 3.54 step/s step 400/1 - loss: 7.598633 - normalized loss: 6.222893 - ppl: 1995.466797 - 218ms/step
[2019-08-02 15:32:44,764 INFO train.py:262] step_idx: 150500, epoch: 32, batch: 1764, avg loss: 3.069125, normalized loss: 1.693385, ppl: 21.523066, speed: 3.53 step/s
[2019-08-02 15:33:13,199 INFO train.py:262] step_idx: 150600, epoch: 32, batch: 1864, avg loss: 2.869379, normalized loss: 1.493639, ppl: 17.626074, speed: 3.52 step/s
[2019-08-02 15:33:41,601 INFO train.py:262] step_idx: 150700, epoch: 32, batch: 1964, avg loss: 2.980905, normalized loss: 1.605164, ppl: 19.705633, speed: 3.52 step/s
[2019-08-02 15:34:10,079 INFO train.py:262] step_idx: 150800, epoch: 32, batch: 2064, avg loss: 3.047716, normalized loss: 1.671976, ppl: 21.067181, speed: 3.51 step/s
[2019-08-02 15:34:38,598 INFO train.py:262] step_idx: 150900, epoch: 32, batch: 2164, avg loss: 2.956475, normalized loss: 1.580735, ppl: 19.230072, speed: 3.51 step/s
``` ```
也可以使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度较慢。 也可以使用 CPU 训练(通过参数 `--use_cuda False` 设置),训练速度较慢。
#### 单机多卡 #### 单机多卡
Paddle动态图支持多进程多卡进行模型训练,启动训练的方式如下: 支持多进程多卡进行模型训练,启动训练的方式如下:
```sh ```sh
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \ export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 train.py \
--epoch 30 \ --epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \ --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
...@@ -129,25 +125,27 @@ python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3, ...@@ -129,25 +125,27 @@ python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \ --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \ --batch_size 4096 \
--print_step 100 \ --print_step 100
--use_cuda True \
--save_step 10000
``` ```
此时,程序会将每个进程的输出log导入到`./mylog`路径下,只有第一个工作进程会保存模型。 #### 静态图训练
默认使用动态图模式进行训练,可以通过设置 `eager_run` 参数为False来以静态图模式进行训练,如下:
```sh
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 train.py \
--epoch 30 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 100 \
--eager_run False
``` ```
.
├── mylog
│   ├── workerlog.0
│   ├── workerlog.1
│   ├── workerlog.2
│   ├── workerlog.3
│   ├── workerlog.4
│   ├── workerlog.5
│   ├── workerlog.6
│   └── workerlog.7
```
### 模型推断 ### 模型推断
...@@ -163,13 +161,13 @@ python -u predict.py \ ...@@ -163,13 +161,13 @@ python -u predict.py \
--special_token '<s>' '<e>' '<unk>' \ --special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \ --batch_size 32 \
--init_from_params trained_params/step_100000 \ --init_from_params base_model_dygraph/step_100000/transformer \
--beam_size 5 \ --beam_size 5 \
--max_out_len 255 \ --max_out_len 255 \
--output_file predict.txt --output_file predict.txt
``` ```
`predict_file` 指定的文件中文本的翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型所在目录,更多参数的使用可以在 `transformer.yaml` 文件中查阅注释说明并进行更改设置。注意若在执行预测时设置了模型超参数,应与模型训练时的设置一致,如若训练时使用 big model 的参数设置,则预测时对应类似如下命令: `predict_file` 指定的文件中文本的翻译结果会输出到 `output_file` 指定的文件。执行预测时需要设置 `init_from_params` 来给出模型文件路径(不包含扩展名),更多参数的使用可以在 `transformer.yaml` 文件中查阅注释说明并进行更改设置。注意若在执行预测时设置了模型超参数,应与模型训练时的设置一致,如若训练时使用 big model 的参数设置,则预测时对应类似如下命令:
```sh ```sh
# setting visible devices for prediction # setting visible devices for prediction
...@@ -181,7 +179,7 @@ python -u predict.py \ ...@@ -181,7 +179,7 @@ python -u predict.py \
--special_token '<s>' '<e>' '<unk>' \ --special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \ --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \ --batch_size 32 \
--init_from_params trained_params/step_100000 \ --init_from_params base_model_dygraph/step_100000/transformer \
--beam_size 5 \ --beam_size 5 \
--max_out_len 255 \ --max_out_len 255 \
--output_file predict.txt \ --output_file predict.txt \
...@@ -191,6 +189,24 @@ python -u predict.py \ ...@@ -191,6 +189,24 @@ python -u predict.py \
--prepostprocess_dropout 0.3 --prepostprocess_dropout 0.3
``` ```
和训练类似,预测时同样可以以静态图模式进行,如下:
```sh
# setting visible devices for prediction
export CUDA_VISIBLE_DEVICES=0
python -u predict.py \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 32 \
--init_from_params base_model_dygraph/step_100000/transformer \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--eager_run False
```
### 模型评估 ### 模型评估
......
#! /usr/bin/env bash
set -e
OUTPUT_DIR=$PWD/gen_data
###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt16_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt16_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://www.statmt.org/europarl/v7/de-en.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz'
'news-commentary-v11.de-en.en' 'news-commentary-v11.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
DEV_TEST_DATA=(
'http://data.statmt.org/wmt16/translation-task/dev.tgz'
'newstest201[45]-deen-ref.en.sgm' 'newstest201[45]-deen-src.de.sgm'
'http://data.statmt.org/wmt16/translation-task/test.tgz'
'newstest2016-deen-ref.en.sgm' 'newstest2016-deen-src.de.sgm'
)
###############################################################################
###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: ata_url data_tgz data_file
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################
mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
data_url=${TRAIN_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${TRAIN_DATA[i+1]}
data_lang2=${TRAIN_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
# concatenate all training data
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
f_base=${f%.*}
f_ext=${f##*.}
if [ $f_ext == "gz" ]; then
gunzip $f
l=${f_base##*.}
f_base=${f_base%.*}
else
l=${f_ext}
fi
if [ $i -eq 0 ]; then
cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
else
cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
fi
done
done
done
# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
echo "Cloning moses for data processing"
git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi
# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
data_url=${DEV_TEST_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${DEV_TEST_DATA[i+1]}
data_lang2=${DEV_TEST_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
l=`echo ${data_file} | cut -d '.' -f 2` # en
dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
< $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
fi
done
done
done
# Tokenize data
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.$l$"`; do
f_base=${f%.*} # dir/train dir/newstest2016
f_out=$f_base.tok.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l $l -threads 8 < $f > $f_out
fi
done
done
# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
f_base=${f%.*} # dir/train dir/train.tok
f_out=${f_base}.clean
if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
echo "Clean "${f_base}
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 80
fi
done
# Clone subword-nmt and generate BPE data
if [ ! -d ${OUTPUT_DIR}/subword-nmt ]; then
git clone https://github.com/rsennrich/subword-nmt.git ${OUTPUT_DIR}/subword-nmt
fi
# Generate BPE data and vocabulary
for num_operations in 32000; do
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
echo "Learn BPE with ${num_operations} merge operations"
cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
fi
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
if [ ! -e $f_out ]; then
echo "Apply BPE to "$f
${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
fi
done
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
echo "Create vocabulary for BPE data"
cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
fi
done
# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
f_out=${f_base}.${LANG1}-${LANG2}
if [ ! -e $f_out ]; then
paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi
echo "All done."
...@@ -77,11 +77,12 @@ def do_predict(args): ...@@ -77,11 +77,12 @@ def do_predict(args):
token_delimiter=args.token_delimiter, token_delimiter=args.token_delimiter,
start_mark=args.special_token[0], start_mark=args.special_token[0],
end_mark=args.special_token[1], end_mark=args.special_token[1],
unk_mark=args.special_token[2]) unk_mark=args.special_token[2],
byte_data=True)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary() args.unk_idx = dataset.get_vocab_summary()
trg_idx2word = Seq2SeqDataset.load_dict( trg_idx2word = Seq2SeqDataset.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True) dict_path=args.trg_vocab_fpath, reverse=True, byte_data=True)
batch_sampler = Seq2SeqBatchSampler( batch_sampler = Seq2SeqBatchSampler(
dataset=dataset, dataset=dataset,
use_token_batch=False, use_token_batch=False,
...@@ -91,10 +92,12 @@ def do_predict(args): ...@@ -91,10 +92,12 @@ def do_predict(args):
dataset=dataset, dataset=dataset,
batch_sampler=batch_sampler, batch_sampler=batch_sampler,
places=device, places=device,
feed_list=None
if fluid.in_dygraph_mode() else [x.forward() for x in inputs],
collate_fn=partial( collate_fn=partial(
prepare_infer_input, src_pad_idx=args.eos_idx, n_head=args.n_head), prepare_infer_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
src_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, num_workers=0,
return_list=True) return_list=True)
...@@ -124,7 +127,7 @@ def do_predict(args): ...@@ -124,7 +127,7 @@ def do_predict(args):
# load the trained model # load the trained model
assert args.init_from_params, ( assert args.init_from_params, (
"Please set init_from_params to load the infer model.") "Please set init_from_params to load the infer model.")
transformer.load(os.path.join(args.init_from_params, "transformer")) transformer.load(args.init_from_params)
# TODO: use model.predict when support variant length # TODO: use model.predict when support variant length
f = open(args.output_file, "wb") f = open(args.output_file, "wb")
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
import glob import glob
import six import six
import os import os
import tarfile import io
import itertools import itertools
from functools import partial
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -24,16 +25,67 @@ from paddle.fluid.dygraph.parallel import ParallelEnv ...@@ -24,16 +25,67 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset from paddle.fluid.io import BatchSampler, DataLoader, Dataset
def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): def create_data_loader(args, device):
data_loaders = [None, None]
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
byte_data=True)
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length,
distribute_mode=True
if i == 0 else False) # every device eval all data
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
collate_fn=partial(
prepare_train_input,
bos_idx=args.bos_idx,
eos_idx=args.eos_idx,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
return data_loaders
def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
n_head):
""" """
Put all padded data needed by training into a list. Put all padded data needed by training into a list.
""" """
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) [inst[0] + [eos_idx] for inst in insts],
src_pad_idx,
n_head,
is_target=False)
src_word = src_word.reshape(-1, src_max_len) src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len) src_pos = src_pos.reshape(-1, src_max_len)
trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data(
[inst[1] for inst in insts], trg_pad_idx, n_head, is_target=True) [[bos_idx] + inst[1] for inst in insts],
trg_pad_idx,
n_head,
is_target=True)
trg_word = trg_word.reshape(-1, trg_max_len) trg_word = trg_word.reshape(-1, trg_max_len)
trg_pos = trg_pos.reshape(-1, trg_max_len) trg_pos = trg_pos.reshape(-1, trg_max_len)
...@@ -41,7 +93,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): ...@@ -41,7 +93,7 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
[1, 1, trg_max_len, 1]).astype("float32") [1, 1, trg_max_len, 1]).astype("float32")
lbl_word, lbl_weight, num_token = pad_batch_data( lbl_word, lbl_weight, num_token = pad_batch_data(
[inst[2] for inst in insts], [inst[1] + [eos_idx] for inst in insts],
trg_pad_idx, trg_pad_idx,
n_head, n_head,
is_target=False, is_target=False,
...@@ -60,20 +112,21 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head): ...@@ -60,20 +112,21 @@ def prepare_train_input(insts, src_pad_idx, trg_pad_idx, n_head):
return data_inputs return data_inputs
def prepare_infer_input(insts, src_pad_idx, n_head): def prepare_infer_input(insts, bos_idx, eos_idx, src_pad_idx, n_head):
""" """
Put all padded data needed by beam search decoder into a list. Put all padded data needed by beam search decoder into a list.
""" """
src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
[inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) [inst[0] + [eos_idx] for inst in insts],
src_pad_idx,
n_head,
is_target=False)
trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
[1, 1, 1, 1]).astype("float32") [1, 1, 1, 1]).astype("float32")
src_word = src_word.reshape(-1, src_max_len) src_word = src_word.reshape(-1, src_max_len)
src_pos = src_pos.reshape(-1, src_max_len) src_pos = src_pos.reshape(-1, src_max_len)
data_inputs = [ data_inputs = [src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias]
src_word, src_pos, src_slf_attn_bias, trg_src_attn_bias
]
return data_inputs return data_inputs
...@@ -142,29 +195,30 @@ class SortType(object): ...@@ -142,29 +195,30 @@ class SortType(object):
class Converter(object): class Converter(object):
def __init__(self, vocab, beg, end, unk, delimiter, add_beg): def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
self._vocab = vocab self._vocab = vocab
self._beg = beg self._beg = beg
self._end = end self._end = end
self._unk = unk self._unk = unk
self._delimiter = delimiter self._delimiter = delimiter
self._add_beg = add_beg self._add_beg = add_beg
self._add_end = add_end
def __call__(self, sentence): def __call__(self, sentence):
return ([self._beg] if self._add_beg else []) + [ return ([self._beg] if self._add_beg else []) + [
self._vocab.get(w, self._unk) self._vocab.get(w, self._unk)
for w in sentence.split(self._delimiter) for w in sentence.split(self._delimiter)
] + [self._end] ] + ([self._end] if self._add_end else [])
class ComposedConverter(object): class ComposedConverter(object):
def __init__(self, converters): def __init__(self, converters):
self._converters = converters self._converters = converters
def __call__(self, parallel_sentence): def __call__(self, fields):
return [ return [
self._converters[i](parallel_sentence[i]) converter(field)
for i in range(len(self._converters)) for field, converter in zip(fields, self._converters)
] ]
...@@ -201,10 +255,11 @@ class TokenBatchCreator(object): ...@@ -201,10 +255,11 @@ class TokenBatchCreator(object):
class SampleInfo(object): class SampleInfo(object):
def __init__(self, i, max_len, min_len): def __init__(self, i, lens):
self.i = i self.i = i
self.min_len = min_len # take bos and eos into account
self.max_len = max_len self.min_len = min(lens[0] + 1, lens[1] + 2)
self.max_len = max(lens[0] + 1, lens[1] + 2)
class MinMaxFilter(object): class MinMaxFilter(object):
...@@ -229,98 +284,109 @@ class Seq2SeqDataset(Dataset): ...@@ -229,98 +284,109 @@ class Seq2SeqDataset(Dataset):
src_vocab_fpath, src_vocab_fpath,
trg_vocab_fpath, trg_vocab_fpath,
fpattern, fpattern,
tar_fname=None,
field_delimiter="\t", field_delimiter="\t",
token_delimiter=" ", token_delimiter=" ",
start_mark="<s>", start_mark="<s>",
end_mark="<e>", end_mark="<e>",
unk_mark="<unk>", unk_mark="<unk>",
only_src=False): only_src=False,
# convert str to bytes, and use byte data trg_fpattern=None,
field_delimiter = field_delimiter.encode("utf8") byte_data=False):
token_delimiter = token_delimiter.encode("utf8") if byte_data:
start_mark = start_mark.encode("utf8") # The WMT16 bpe data used here seems including bytes can not be
end_mark = end_mark.encode("utf8") # decoded by utf8. Thus convert str to bytes, and use byte data
unk_mark = unk_mark.encode("utf8") field_delimiter = field_delimiter.encode("utf8")
self._src_vocab = self.load_dict(src_vocab_fpath) token_delimiter = token_delimiter.encode("utf8")
self._trg_vocab = self.load_dict(trg_vocab_fpath) start_mark = start_mark.encode("utf8")
end_mark = end_mark.encode("utf8")
unk_mark = unk_mark.encode("utf8")
self._byte_data = byte_data
self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
self._bos_idx = self._src_vocab[start_mark] self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark] self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark] self._unk_idx = self._src_vocab[unk_mark]
self._only_src = only_src
self._field_delimiter = field_delimiter self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter self._token_delimiter = token_delimiter
self.load_src_trg_ids(fpattern, tar_fname) self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, tar_fname): def load_src_trg_ids(self, fpattern, trg_fpattern=None):
converters = [ src_converter = Converter(
Converter(vocab=self._src_vocab, vocab=self._src_vocab,
beg=self._bos_idx, beg=self._bos_idx,
end=self._eos_idx, end=self._eos_idx,
unk=self._unk_idx, unk=self._unk_idx,
delimiter=self._token_delimiter, delimiter=self._token_delimiter,
add_beg=False) add_beg=False,
] add_end=False)
if not self._only_src:
converters.append( trg_converter = Converter(
Converter(vocab=self._trg_vocab, vocab=self._trg_vocab,
beg=self._bos_idx, beg=self._bos_idx,
end=self._eos_idx, end=self._eos_idx,
unk=self._unk_idx, unk=self._unk_idx,
delimiter=self._token_delimiter, delimiter=self._token_delimiter,
add_beg=True)) add_beg=False,
add_end=False)
converters = ComposedConverter(converters)
converters = ComposedConverter([src_converter, trg_converter])
self._src_seq_ids = [] self._src_seq_ids = []
self._trg_seq_ids = None if self._only_src else [] self._trg_seq_ids = []
self._sample_infos = [] self._sample_infos = []
for i, line in enumerate(self._load_lines(fpattern, tar_fname)): slots = [self._src_seq_ids, self._trg_seq_ids]
src_trg_ids = converters(line) for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
self._src_seq_ids.append(src_trg_ids[0]) lens = []
lens = [len(src_trg_ids[0])] for field, slot in zip(converters(line), slots):
if not self._only_src: slot.append(field)
self._trg_seq_ids.append(src_trg_ids[1]) lens.append(len(field))
lens.append(len(src_trg_ids[1])) self._sample_infos.append(SampleInfo(i, lens))
self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
def _load_lines(self, fpattern, tar_fname): def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern) fpaths = glob.glob(fpattern)
fpaths = sorted(fpaths) # TODO: Add custum sort
assert len(fpaths) > 0, "no matching file to the provided data path" assert len(fpaths) > 0, "no matching file to the provided data path"
if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]): (f_mode, f_encoding,
if tar_fname is None: endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
raise Exception("If tar file provided, please set tar_fname.") "\n")
if trg_fpattern is None:
f = tarfile.open(fpaths[0], "rb")
for line in f.extractfile(tar_fname):
fields = line.strip(b"\n").split(self._field_delimiter)
if (not self._only_src
and len(fields) == 2) or (self._only_src
and len(fields) == 1):
yield fields
else:
for fpath in fpaths: for fpath in fpaths:
if not os.path.isfile(fpath): with io.open(fpath, f_mode, encoding=f_encoding) as f:
raise IOError("Invalid file: %s" % fpath)
with open(fpath, "rb") as f:
for line in f: for line in f:
fields = line.strip(b"\n").split(self._field_delimiter) fields = line.strip(endl).split(self._field_delimiter)
if (not self._only_src and len(fields) == 2) or ( yield fields
self._only_src and len(fields) == 1): else:
# separated source and target language data files
# assume we can get aligned data by sort the two language files
# TODO: Need more rigorous check
trg_fpaths = glob.glob(trg_fpattern)
trg_fpaths = sorted(trg_fpaths)
assert len(fpaths) == len(
trg_fpaths
), "the number of source language data files must equal \
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
with io.open(fpath, f_mode, encoding=f_encoding) as f:
with io.open(
trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
fields = [field.strip(endl) for field in line]
yield fields yield fields
@staticmethod @staticmethod
def load_dict(dict_path, reverse=False): def load_dict(dict_path, reverse=False, byte_data=False):
word_dict = {} word_dict = {}
with open(dict_path, "rb") as fdict: (f_mode, f_encoding,
endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if reverse: if reverse:
word_dict[idx] = line.strip(b"\n") word_dict[idx] = line.strip(endl)
else: else:
word_dict[line.strip(b"\n")] = idx word_dict[line.strip(endl)] = idx
return word_dict return word_dict
def get_vocab_summary(self): def get_vocab_summary(self):
...@@ -328,9 +394,8 @@ class Seq2SeqDataset(Dataset): ...@@ -328,9 +394,8 @@ class Seq2SeqDataset(Dataset):
self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx
def __getitem__(self, idx): def __getitem__(self, idx):
return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1], return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
self._trg_seq_ids[idx][1:] ) if self._trg_seq_ids else self._src_seq_ids[idx]
) if not self._only_src else self._src_seq_ids[idx]
def __len__(self): def __len__(self):
return len(self._sample_infos) return len(self._sample_infos)
...@@ -348,6 +413,7 @@ class Seq2SeqBatchSampler(BatchSampler): ...@@ -348,6 +413,7 @@ class Seq2SeqBatchSampler(BatchSampler):
shuffle_batch=False, shuffle_batch=False,
use_token_batch=False, use_token_batch=False,
clip_last_batch=False, clip_last_batch=False,
distribute_mode=True,
seed=0): seed=0):
for arg, value in locals().items(): for arg, value in locals().items():
if arg != "self": if arg != "self":
...@@ -355,6 +421,7 @@ class Seq2SeqBatchSampler(BatchSampler): ...@@ -355,6 +421,7 @@ class Seq2SeqBatchSampler(BatchSampler):
self._random = np.random self._random = np.random
self._random.seed(seed) self._random.seed(seed)
# for multi-devices # for multi-devices
self._distribute_mode = distribute_mode
self._nranks = ParallelEnv().nranks self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id self._device_id = ParallelEnv().dev_id
...@@ -362,8 +429,8 @@ class Seq2SeqBatchSampler(BatchSampler): ...@@ -362,8 +429,8 @@ class Seq2SeqBatchSampler(BatchSampler):
def __iter__(self): def __iter__(self):
# global sort or global shuffle # global sort or global shuffle
if self._sort_type == SortType.GLOBAL: if self._sort_type == SortType.GLOBAL:
infos = sorted(self._dataset._sample_infos, infos = sorted(
key=lambda x: x.max_len) self._dataset._sample_infos, key=lambda x: x.max_len)
else: else:
if self._shuffle: if self._shuffle:
infos = self._dataset._sample_infos infos = self._dataset._sample_infos
...@@ -383,9 +450,9 @@ class Seq2SeqBatchSampler(BatchSampler): ...@@ -383,9 +450,9 @@ class Seq2SeqBatchSampler(BatchSampler):
batches = [] batches = []
batch_creator = TokenBatchCreator( batch_creator = TokenBatchCreator(
self._batch_size self.
) if self._use_token_batch else SentenceBatchCreator(self._batch_size * _batch_size) if self._use_token_batch else SentenceBatchCreator(
self._nranks) self._batch_size * self._nranks)
batch_creator = MinMaxFilter(self._max_length, self._min_length, batch_creator = MinMaxFilter(self._max_length, self._min_length,
batch_creator) batch_creator)
...@@ -413,11 +480,21 @@ class Seq2SeqBatchSampler(BatchSampler): ...@@ -413,11 +480,21 @@ class Seq2SeqBatchSampler(BatchSampler):
# for multi-device # for multi-device
for batch_id, batch in enumerate(batches): for batch_id, batch in enumerate(batches):
if batch_id % self._nranks == self._local_rank: if not self._distribute_mode or (
batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch] batch_indices = [info.i for info in batch]
yield batch_indices yield batch_indices
if self._local_rank > len(batches) % self._nranks: if self._distribute_mode and len(batches) % self._nranks != 0:
yield batch_indices if self._local_rank >= len(batches) % self._nranks:
# use previous data to pad
yield batch_indices
def __len__(self): def __len__(self):
return 100 if not self._use_token_batch:
batch_number = (
len(self._dataset) + self._batch_size * self._nranks - 1) // (
self._batch_size * self._nranks)
else:
# TODO(guosheng): fix the uncertain length
batch_number = 1
return batch_number
python -u train.py \
--epoch 30 \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
--validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 4096 \
--print_step 1 \
--use_cuda True \
--random_seed 1000 \
--save_step 10 \
--eager_run True
#--init_from_pretrain_model base_model_dygraph/step_100000/ \
#--init_from_checkpoint trained_models/step_200/transformer
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
exit
echo `date`
python -u predict.py \
--src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
--batch_size 64 \
--init_from_params base_model_dygraph/step_100000/ \
--beam_size 5 \
--max_out_len 255 \
--output_file predict.txt \
--eager_run True
#--max_length 500 \
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
echo `date`
\ No newline at end of file
...@@ -17,12 +17,10 @@ import os ...@@ -17,12 +17,10 @@ import os
import six import six
import sys import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
import numpy as np import numpy as np
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable
from paddle.fluid.io import DataLoader from paddle.fluid.io import DataLoader
from utils.configure import PDConfig from utils.configure import PDConfig
...@@ -30,33 +28,39 @@ from utils.check import check_gpu, check_version ...@@ -30,33 +28,39 @@ from utils.check import check_gpu, check_version
from model import Input, set_device from model import Input, set_device
from callbacks import ProgBarLogger from callbacks import ProgBarLogger
from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler from reader import create_data_loader
from transformer import Transformer, CrossEntropyCriterion, NoamDecay from transformer import Transformer, CrossEntropyCriterion
class LoggerCallback(ProgBarLogger): class TrainCallback(ProgBarLogger):
def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.): def __init__(self, args, verbose=2):
super(LoggerCallback, self).__init__(log_freq, verbose) # TODO(guosheng): save according to step
# TODO: wrap these override function to simplify super(TrainCallback, self).__init__(args.print_step, verbose)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
self.loss_normalizer = loss_normalizer self.loss_normalizer = loss_normalizer
def on_train_begin(self, logs=None): def on_train_begin(self, logs=None):
super(LoggerCallback, self).on_train_begin(logs) super(TrainCallback, self).on_train_begin(logs)
self.train_metrics += ["normalized loss", "ppl"] self.train_metrics += ["normalized loss", "ppl"]
def on_train_batch_end(self, step, logs=None): def on_train_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100)) logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_train_batch_end(step, logs) super(TrainCallback, self).on_train_batch_end(step, logs)
def on_eval_begin(self, logs=None): def on_eval_begin(self, logs=None):
super(LoggerCallback, self).on_eval_begin(logs) super(TrainCallback, self).on_eval_begin(logs)
self.eval_metrics += ["normalized loss", "ppl"] self.eval_metrics = list(
self.eval_metrics) + ["normalized loss", "ppl"]
def on_eval_batch_end(self, step, logs=None): def on_eval_batch_end(self, step, logs=None):
logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer
logs["ppl"] = np.exp(min(logs["loss"][0], 100)) logs["ppl"] = np.exp(min(logs["loss"][0], 100))
super(LoggerCallback, self).on_eval_batch_end(step, logs) super(TrainCallback, self).on_eval_batch_end(step, logs)
def do_train(args): def do_train(args):
...@@ -100,44 +104,7 @@ def do_train(args): ...@@ -100,44 +104,7 @@ def do_train(args):
] ]
# def dataloader # def dataloader
data_loaders = [None, None] train_loader, eval_loader = create_data_loader(args, device)
data_files = [args.training_file, args.validation_file
] if args.validation_file else [args.training_file]
for i, data_file in enumerate(data_files):
dataset = Seq2SeqDataset(
fpattern=data_file,
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
token_delimiter=args.token_delimiter,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2])
args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
args.unk_idx = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=args.use_token_batch,
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
max_length=args.max_length)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=None if fluid.in_dygraph_mode() else
[x.forward() for x in inputs + labels],
collate_fn=partial(
prepare_train_input,
src_pad_idx=args.eos_idx,
trg_pad_idx=args.eos_idx,
n_head=args.n_head),
num_workers=0, # TODO: use multi-process
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
# define model # define model
transformer = Transformer( transformer = Transformer(
...@@ -149,8 +116,10 @@ def do_train(args): ...@@ -149,8 +116,10 @@ def do_train(args):
transformer.prepare( transformer.prepare(
fluid.optimizer.Adam( fluid.optimizer.Adam(
learning_rate=fluid.layers.noam_decay(args.d_model, learning_rate=fluid.layers.noam_decay(
args.warmup_steps), args.d_model,
args.warmup_steps,
learning_rate=args.learning_rate),
beta1=args.beta1, beta1=args.beta1,
beta2=args.beta2, beta2=args.beta2,
epsilon=float(args.eps), epsilon=float(args.eps),
...@@ -161,32 +130,19 @@ def do_train(args): ...@@ -161,32 +130,19 @@ def do_train(args):
## init from some checkpoint, to resume the previous training ## init from some checkpoint, to resume the previous training
if args.init_from_checkpoint: if args.init_from_checkpoint:
transformer.load( transformer.load(args.init_from_checkpoint)
os.path.join(args.init_from_checkpoint, "transformer"))
## init from some pretrain models, to better solve the current task ## init from some pretrain models, to better solve the current task
if args.init_from_pretrain_model: if args.init_from_pretrain_model:
transformer.load( transformer.load(args.init_from_pretrain_model, reset_optimizer=True)
os.path.join(args.init_from_pretrain_model, "transformer"),
reset_optimizer=True)
# the best cross-entropy value with label smoothing
loss_normalizer = -(
(1. - args.label_smooth_eps) * np.log(
(1. - args.label_smooth_eps)) + args.label_smooth_eps *
np.log(args.label_smooth_eps / (args.trg_vocab_size - 1) + 1e-20))
# model train # model train
transformer.fit(train_data=train_loader, transformer.fit(train_data=train_loader,
eval_data=eval_loader, eval_data=eval_loader,
epochs=1, epochs=args.epoch,
eval_freq=1, eval_freq=1,
save_freq=1, save_freq=1,
verbose=2, save_dir=args.save_model,
callbacks=[ callbacks=[TrainCallback(args)])
LoggerCallback(
log_freq=args.print_step,
loss_normalizer=loss_normalizer)
])
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -79,7 +79,8 @@ class PrePostProcessLayer(Layer): ...@@ -79,7 +79,8 @@ class PrePostProcessLayer(Layer):
self.functors = [] self.functors = []
for cmd in self.process_cmd: for cmd in self.process_cmd:
if cmd == "a": # add residual connection if cmd == "a": # add residual connection
self.functors.append(lambda x, y: x + y if y else x) self.functors.append(
lambda x, y: x + y if y is not None else x)
elif cmd == "n": # add layer normalization elif cmd == "n": # add layer normalization
self.functors.append( self.functors.append(
self.add_sublayer( self.add_sublayer(
...@@ -169,7 +170,7 @@ class MultiHeadAttention(Layer): ...@@ -169,7 +170,7 @@ class MultiHeadAttention(Layer):
# scale dot product attention # scale dot product attention
product = layers.matmul( product = layers.matmul(
x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5) x=q, y=k, transpose_y=True, alpha=self.d_model**-0.5)
if attn_bias: if attn_bias is not None:
product += attn_bias product += attn_bias
weights = layers.softmax(product) weights = layers.softmax(product)
if self.dropout_rate: if self.dropout_rate:
......
# used for continuous evaluation # used for continuous evaluation
enable_ce: False enable_ce: False
eager_run: False eager_run: True
# The frequency to save trained models when training. # The frequency to save trained models when training.
save_step: 10000 save_step: 10000
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册