未验证 提交 1080d1b0 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #1306 from guoshengCS/add-transformer-gendata

Add gen_data.sh in Transformer
......@@ -16,7 +16,7 @@
├── reader.py # 数据读取接口
├── README.md # 文档
├── train.py # 训练脚本
└── util.py # wordpiece 数据解码工具
└── gen_data.sh # 数据生成脚本
```
### 简介
......@@ -59,36 +59,29 @@ Decoder 具有和 Encoder 类似的结构,只是相比于组成 Encoder 的 la
### 数据准备
WMT 数据集是机器翻译领域公认的主流数据集;WMT 英德和英法数据集也是 Transformer 论文中所用数据集,其中英德数据集使用了 BPE(byte-pair encoding)[4]编码的数据,英法数据集使用了 wordpiece [5]的数据。我们这里也将使用 WMT 英德和英法翻译数据,并和论文保持一致使用 BPE 和 wordpiece 的数据,下面给出了使用的方法。对于其他自定义数据,参照下文遵循或转换为类似的数据格式即可。
WMT 数据集是机器翻译领域公认的主流数据集[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)是其中一个中等规模的数据集,也是 Transformer 论文中用到的一个数据集,这里将其作为示例,可以直接运行 `gen_data.sh` 脚本进行 WMT'16 EN-DE 数据集的下载和预处理。数据处理过程主要包括 Tokenize 和 BPE 编码(byte-pair encoding);BPE 编码的数据能够较好的解决未登录词(out-of-vocabulary,OOV)的问题[4],其在 Transformer 论文中也被使用。运行成功后,将会生成文件夹 `gen_data`,其目录结构如下(可在 `gen_data.sh` 中修改):
#### WMT 英德翻译数据
[WMT'16 EN-DE 数据集](http://www.statmt.org/wmt16/translation-task.html)是一个中等规模的数据集。参照论文,英德数据集我们使用 BPE 编码的数据,这能够更好的解决未登录词(out-of-vocabulary,OOV)的问题[4]。用到的 BPE 数据可以参照[这里](https://github.com/google/seq2seq/blob/master/docs/data.md)进行下载(如果希望在自定义数据中使用 BPE 编码,可以参照[这里](https://github.com/rsennrich/subword-nmt)进行预处理),下载后解压,其中 `train.tok.clean.bpe.32000.en``train.tok.clean.bpe.32000.de` 为使用 BPE 的训练数据(平行语料,分别对应了英语和德语,经过了 tokenize 和 BPE 的处理),`newstest2016.tok.bpe.32000.en``newstest2016.tok.bpe.32000.de` 等为测试数据(`newstest2016.tok.en``newstest2016.tok.de` 等则为对应的未使用 BPE 的测试数据),`vocab.bpe.32000` 为相应的词典文件(源语言和目标语言共享该词典文件)。
由于本示例中的数据读取脚本 `reader.py` 默认使用的样本数据的格式为 `\t` 分隔的的源语言和目标语言句子对(默认句子中的词之间使用空格分隔),因此需要将源语言到目标语言的平行语料库文件合并为一个文件,可以执行以下命令进行合并:
```sh
paste -d '\t' train.tok.clean.bpe.32000.en train.tok.clean.bpe.32000.de > train.tok.clean.bpe.32000.en-de
```
此外,下载的词典文件 `vocab.bpe.32000` 中未包含表示序列开始、序列结束和未登录词的特殊符号,可以使用如下命令在词典中加入 `<s>``<e>``<unk>` 作为这三个特殊符号(用 BPE 表示数据已有效避免了未登录词的问题,这里加入只是做通用处理)。
```sh
sed -i '1i\<s>\n<e>\n<unk>' vocab.bpe.32000
```text
.
├── wmt16_ende_data # WMT16 英德翻译数据
├── wmt16_ende_data_bpe # BPE 编码的 WMT16 英德翻译数据
├── mosesdecoder # Moses 机器翻译工具集,包含了 Tokenize、BLEU 评估等脚本
└── subword-nmt # BPE 编码的代码
```
#### WMT 英法翻译数据
[WMT'14 EN-FR 数据集](http://www.statmt.org/wmt14/translation-task.html)是一个较大规模的数据集。参照论文,英法数据我们使用 wordpiece 表示的数据,wordpiece 和 BPE 类似同为采用 sub-word units 来解决 OOV 问题的方法[5]。我们提供了已完成预处理的 wordpiece 数据的下载,可以从[这里](http://transformer-data.bj.bcebos.com/wmt14_enfr.tar)下载,其中 `train.wordpiece.en-fr` 为使用 wordpiece 的训练数据,`newstest2014.wordpiece.en-fr` 为测试数据(`newstest2014.tok.en``newstest2014.tok.fr` 为对应的未经 wordpiece 处理过的测试数据,使用[脚本](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl)进行了 tokenize 的处理),`vocab.wordpiece.en-fr` 为相应的词典文件(源语言和目标语言共享该词典文件)。
`gen_data/wmt16_ende_data_bpe` 中是我们最终使用的英德翻译数据,其中 `train.tok.clean.bpe.32000.en-de` 为训练数据,`newstest2016.tok.bpe.32000.en-de` 等为验证和测试数据,。`vocab_all.bpe.32000` 为相应的词典文件(已加入 `<s>``<e>``<unk>` 这三个特殊符号,源语言和目标语言共享该词典文件)。
提供的英法翻译数据无需进行额外的处理,可以直接使用;需要注意的是,这些用 wordpiece 表示的数据中句子内的 token 之间使用 `\x01` 而非空格进行分隔(因部分 token 内包含空格),这需要在训练时进行指定
对于其他自定义数据,转换为类似 `train.tok.clean.bpe.32000.en-de` 的数据格式(`\t` 分隔的源语言和目标语言句子对,句子中的 token 之间使用空格分隔)即可;如需使用 BPE 编码,可参考,亦可以使用类似 WMT,使用 `gen_data.sh` 进行处理
### 模型训练
`train.py` 是模型训练脚本。以英德翻译数据为例,可以执行以下命令进行模型训练:
```sh
python -u train.py \
--src_vocab_fpath data/vocab.bpe.32000 \
--trg_vocab_fpath data/vocab.bpe.32000 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--train_file_pattern data/train.tok.clean.bpe.32000.en-de \
--train_file_pattern gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--token_delimiter ' ' \
--use_token_batch True \
--batch_size 4096 \
......@@ -104,10 +97,10 @@ python train.py --help
```sh
python -u train.py \
--src_vocab_fpath data/vocab.bpe.32000 \
--trg_vocab_fpath data/vocab.bpe.32000 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--train_file_pattern data/train.tok.clean.bpe.32000.en-de \
--train_file_pattern gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
--token_delimiter ' ' \
--use_token_batch True \
--batch_size 3200 \
......@@ -120,7 +113,7 @@ python -u train.py \
n_head 16 \
prepostprocess_dropout 0.3
```
有关这些参数更详细信息的请参考 `config.py` 中的注释说明。对于英法翻译数据,执行训练和英德翻译训练类似,修改命令中的词典和数据文件为英法数据相应文件的路径,另外要注意的是由于英法翻译数据 token 间不是使用空格进行分隔,需要修改 `token_delimiter` 参数的设置为 `--token_delimiter '\x01'`
有关这些参数更详细信息的请参考 `config.py` 中的注释说明。
训练时默认使用所有 GPU,可以通过 `CUDA_VISIBLE_DEVICES` 环境变量来设置使用的 GPU 数目。也可以只使用 CPU 训练(通过参数 `--divice CPU` 设置),训练速度相对较慢。在训练过程中,每隔一定 iteration 后(通过参数 `save_freq` 设置,默认为10000)保存模型到参数 `model_dir` 指定的目录,每个 epoch 结束后也会保存 checkpiont 到 `ckpt_dir` 指定的目录,每个 iteration 将打印如下的日志到标准输出:
```txt
......@@ -141,10 +134,10 @@ step_idx: 9, epoch: 0, batch: 9, avg loss: 10.993434, normalized loss: 9.616467,
`infer.py` 是模型预测脚本。以英德翻译数据为例,模型训练完成后可以执行以下命令对指定文件中的文本进行翻译:
```sh
python -u infer.py \
--src_vocab_fpath data/vocab.bpe.32000 \
--trg_vocab_fpath data/vocab.bpe.32000 \
--src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
--special_token '<s>' '<e>' '<unk>' \
--test_file_pattern data/newstest2016.tok.bpe.32000.en-de \
--test_file_pattern gen_data/wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de \
--use_wordpiece False \
--token_delimiter ' ' \
--batch_size 32 \
......@@ -159,14 +152,9 @@ python -u infer.py \
sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
```
对于英法翻译的 wordpiece 数据,执行预测和英德翻译预测类似,修改命令中的词典和数据文件为英法数据相应文件的路径,另外需要注意修改 `token_delimiter` 参数的设置为 `--token_delimiter '\x01'`;同时要修改 `use_wordpiece` 参数的设置为 `--use_wordpiece True`,这会在预测时将翻译得到的 wordpiece 数据还原为原始数据输出。为了使用 tokenize 的数据进行评估,还需要对翻译结果进行 tokenize 的处理,[Moses](https://github.com/moses-smt/mosesdecoder) 提供了一系列机器翻译相关的脚本。执行 `git clone https://github.com/moses-smt/mosesdecoder.git` 克隆 mosesdecoder 仓库后,可以使用其中的 `tokenizer.perl` 脚本对 `predict.txt` 内的翻译结果进行 tokenize 处理并输出到 `predict.tok.txt` 中,如下:
```sh
perl mosesdecoder/scripts/tokenizer/tokenizer.perl -l fr < predict.txt > predict.tok.txt
```
接下来就可以使用参考翻译对翻译结果进行 BLEU 指标的评估了。计算 BLEU 值的脚本也在 Moses 中包含,以英德翻译 `newstest2016.tok.de` 数据为例,执行如下命令:
接下来就可以使用参考翻译对翻译结果进行 BLEU 指标的评估了。以英德翻译 `newstest2016.tok.de` 数据为例,执行如下命令:
```sh
perl mosesdecoder/scripts/generic/multi-bleu.perl data/newstest2016.tok.de < predict.tok.txt
perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt16_ende_data/newstest2016.tok.de < predict.tok.txt
```
可以看到类似如下的结果(为单机两卡训练 200K 个 iteration 后模型的预测结果)。
```
......@@ -174,11 +162,10 @@ BLEU = 33.08, 64.2/39.2/26.4/18.5 (BP=0.994, ratio=0.994, hyp_len=61971, ref_len
```
目前在未使用 model average 的情况下,英德翻译 base model 八卡训练 100K 个 iteration 后测试 BLEU 值如下:
| 测试集 | newstest2013 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|-|
| BLEU | 25.27 | 26.05 | 28.75 | 33.27 |
| 测试集 | newstest2014 | newstest2015 | newstest2016 |
|-|-|-|-|
| BLEU | 26.05 | 28.75 | 33.27 |
英法翻译 base model 八卡训练 100K 个 iteration 后在 `newstest2014` 上测试 BLEU 值为36.。
### 分布式训练
......@@ -260,4 +247,3 @@ export PADDLE_PORT=6177
2. He K, Zhang X, Ren S, et al. [Deep residual learning for image recognition](http://openaccess.thecvf.com/content_cvpr_2016/papers/He_Deep_Residual_Learning_CVPR_2016_paper.pdf)[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
3. Ba J L, Kiros J R, Hinton G E. [Layer normalization](https://arxiv.org/pdf/1607.06450.pdf)[J]. arXiv preprint arXiv:1607.06450, 2016.
4. Sennrich R, Haddow B, Birch A. [Neural machine translation of rare words with subword units](https://arxiv.org/pdf/1508.07909)[J]. arXiv preprint arXiv:1508.07909, 2015.
5. Wu Y, Schuster M, Chen Z, et al. [Google's neural machine translation system: Bridging the gap between human and machine translation](https://arxiv.org/pdf/1609.08144.pdf)[J]. arXiv preprint arXiv:1609.08144, 2016.
#! /usr/bin/env bash
set -e
OUTPUT_DIR=$PWD/gen_data
###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt16_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt16_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://www.statmt.org/europarl/v7/de-en.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz'
'news-commentary-v11.de-en.en' 'news-commentary-v11.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
DEV_TEST_DATA=(
'http://data.statmt.org/wmt16/translation-task/dev.tgz'
'newstest201[45]-deen-ref.en.sgm' 'newstest201[45]-deen-src.de.sgm'
'http://data.statmt.org/wmt16/translation-task/test.tgz'
'newstest2016-deen-ref.en.sgm' 'newstest2016-deen-src.de.sgm'
)
###############################################################################
###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: ata_url data_tgz data_file
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################
mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA
# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
data_url=${TRAIN_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${TRAIN_DATA[i+1]}
data_lang2=${TRAIN_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
# concatenate all training data
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
f_base=${f%.*}
f_ext=${f##*.}
if [ $f_ext == "gz" ]; then
gunzip $f
l=${f_base##*.}
f_base=${f_base%.*}
else
l=${f_ext}
fi
if [ $i -eq 0 ]; then
cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
else
cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
fi
done
done
done
# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
echo "Cloning moses for data processing"
git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi
# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
data_url=${DEV_TEST_DATA[i]}
data_tgz=${data_url##*/} # training-parallel-commoncrawl.tgz
data=${data_tgz%.*} # training-parallel-commoncrawl
data_lang1=${DEV_TEST_DATA[i+1]}
data_lang2=${DEV_TEST_DATA[i+2]}
if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
echo "Download "${data_url}
wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
fi
if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
echo "Extract "${data_tgz}
mkdir -p ${OUTPUT_DIR_DATA}/${data}
tar_type=${data_tgz:0-3}
if [ ${tar_type} == "tar" ]; then
tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
else
tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
fi
fi
for data_lang in $data_lang1 $data_lang2; do
for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
data_dir=`dirname $f`
data_file=`basename $f`
data_out=`echo ${data_file} | cut -d '-' -f 1` # newstest2016
l=`echo ${data_file} | cut -d '.' -f 2` # en
dev_test_data="${dev_test_data}\|${data_out}" # to make regexp
if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
< $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
fi
done
done
done
# Tokenize data
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.$l$"`; do
f_base=${f%.*} # dir/train dir/newstest2016
f_out=$f_base.tok.$l
if [ ! -e $f_out ]; then
echo "Tokenize "$f
${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l $l -threads 8 < $f > $f_out
fi
done
done
# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
f_base=${f%.*} # dir/train dir/train.tok
f_out=${f_base}.clean
if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
echo "Clean "${f_base}
${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 80
fi
done
# Clone subword-nmt and generate BPE data
if [ ! -d ${OUTPUT_DIR}/subword-nmt ]; then
git clone https://github.com/rsennrich/subword-nmt.git ${OUTPUT_DIR}/subword-nmt
fi
# Generate BPE data and vocabulary
for num_operations in 32000; do
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
echo "Learn BPE with ${num_operations} merge operations"
cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
fi
for l in ${LANG1} ${LANG2}; do
for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
f_base=${f%.*} # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
f_base=${f_base##*/} # train.tok train.tok.clean newstest2016.tok
f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
if [ ! -e $f_out ]; then
echo "Apply BPE to "$f
${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
fi
done
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
echo "Create vocabulary for BPE data"
cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
fi
done
# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
f_base=${f%.*} # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
f_out=${f_base}.${LANG1}-${LANG2}
if [ ! -e $f_out ]; then
paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi
echo "All done."
......@@ -13,7 +13,6 @@ from model import fast_decode as fast_decoder
from config import *
from train import pad_batch_data
import reader
import util
def parse_args():
......@@ -49,21 +48,12 @@ def parse_args():
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
"--use_wordpiece",
type=ast.literal_eval,
default=False,
help="The flag indicating if the data in wordpiece. The EN-FR data "
"we provided is wordpiece data. For wordpiece data, converting ids to "
"original words is a little different and some special codes are "
"provided in util.py to do this.")
parser.add_argument(
"--token_delimiter",
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter.; "
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
'opts',
help='See config.py for all options',
......@@ -144,7 +134,7 @@ def prepare_batch_input(insts, data_input_names, src_pad_idx, bos_idx, n_head,
return input_dict
def fast_infer(test_data, trg_idx2word, use_wordpiece):
def fast_infer(test_data, trg_idx2word):
"""
Inference by beam search decoder based solely on Fluid operators.
"""
......@@ -202,9 +192,7 @@ def fast_infer(test_data, trg_idx2word, use_wordpiece):
trg_idx2word[idx]
for idx in post_process_seq(
np.array(seq_ids)[sub_start:sub_end])
]) if not use_wordpiece else util.subtoken_ids_to_str(
post_process_seq(np.array(seq_ids)[sub_start:sub_end]),
trg_idx2word))
]))
scores[i].append(np.array(seq_scores)[sub_end - 1])
print(hyps[i][-1])
if len(hyps[i]) >= InferTaskConfig.n_best:
......@@ -232,7 +220,7 @@ def infer(args, inferencer=fast_infer):
clip_last_batch=False)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)
inferencer(test_data, trg_idx2word, args.use_wordpiece)
inferencer(test_data, trg_idx2word)
if __name__ == "__main__":
......
......@@ -459,7 +459,7 @@ def transformer(src_vocab_size,
use_py_reader=False,
is_test=False):
if weight_sharing:
assert src_vocab_size == src_vocab_size, (
assert src_vocab_size == trg_vocab_size, (
"Vocabularies in source and target should be same for weight sharing."
)
......
......@@ -80,8 +80,7 @@ def parse_args():
type=lambda x: str(x.encode().decode("unicode-escape")),
default=" ",
help="The delimiter used to split tokens in source or target sentences. "
"For EN-DE BPE data we provided, use spaces as token delimiter. "
"For EN-FR wordpiece data we provided, use '\x01' as token delimiter.")
"For EN-DE BPE data we provided, use spaces as token delimiter. ")
parser.add_argument(
'opts',
help='See config.py for all options',
......@@ -345,8 +344,11 @@ def py_reader_provider_wrapper(data_reader):
def test_context(exe, train_exe, dev_count):
# Context to do validation.
startup_prog = fluid.Program()
test_prog = fluid.Program()
startup_prog = fluid.Program()
if args.enable_ce:
test_prog.random_seed = 1000
startup_prog.random_seed = 1000
with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard():
sum_cost, avg_cost, predict, token_num, pyreader = transformer(
......@@ -510,10 +512,12 @@ def train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
time_consumed))
else:
print("epoch: %d, consumed %fs" % (pass_id, time_consumed))
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir,
"pass_" + str(pass_id) + ".checkpoint"), train_prog)
if not args.enable_ce:
fluid.io.save_persistables(
exe,
os.path.join(TrainTaskConfig.ckpt_dir,
"pass_" + str(pass_id) + ".checkpoint"),
train_prog)
if args.enable_ce: # For CE
print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost))
......
import sys
import re
import six
import unicodedata
# Regular expression for unescaping token strings.
# '\u' is converted to '_'
# '\\' is converted to '\'
# '\213;' is converted to unichr(213)
# Inverse of escaping.
_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
# This set contains all letter and number characters.
_ALPHANUMERIC_CHAR_SET = set(
six.unichr(i) for i in range(sys.maxunicode)
if (unicodedata.category(six.unichr(i)).startswith("L") or
unicodedata.category(six.unichr(i)).startswith("N")))
# Unicode utility functions that work with Python 2 and 3
def native_to_unicode(s):
return s if is_unicode(s) else to_unicode(s)
def unicode_to_native(s):
if six.PY2:
return s.encode("utf-8") if is_unicode(s) else s
else:
return s
def is_unicode(s):
if six.PY2:
if isinstance(s, unicode):
return True
else:
if isinstance(s, str):
return True
return False
def to_unicode(s, ignore_errors=False):
if is_unicode(s):
return s
error_mode = "ignore" if ignore_errors else "strict"
return s.decode("utf-8", errors=error_mode)
def unescape_token(escaped_token):
"""
Inverse of encoding escaping.
"""
def match(m):
if m.group(1) is None:
return u"_" if m.group(0) == u"\\u" else u"\\"
try:
return six.unichr(int(m.group(1)))
except (ValueError, OverflowError) as _:
return u"\u3013" # Unicode for undefined character.
trimmed = escaped_token[:-1] if escaped_token.endswith(
"_") else escaped_token
return _UNESCAPE_REGEX.sub(match, trimmed)
def subtoken_ids_to_str(subtoken_ids, vocabs):
"""
Convert a list of subtoken(word piece) ids to a native string.
Refer to SubwordTextEncoder in Tensor2Tensor.
"""
subtokens = [vocabs.get(subtoken_id, u"") for subtoken_id in subtoken_ids]
# Convert a list of subtokens to a list of tokens.
concatenated = "".join([native_to_unicode(t) for t in subtokens])
split = concatenated.split("_")
tokens = []
for t in split:
if t:
unescaped = unescape_token(t + "_")
if unescaped:
tokens.append(unescaped)
# Convert a list of tokens to a unicode string (by inserting spaces bewteen
# word tokens).
token_is_alnum = [t[0] in _ALPHANUMERIC_CHAR_SET for t in tokens]
ret = []
for i, token in enumerate(tokens):
if i > 0 and token_is_alnum[i - 1] and token_is_alnum[i]:
ret.append(u" ")
ret.append(token)
seq = "".join(ret)
return unicode_to_native(seq)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册