PaddlePaddle
hapi
Commit 2292264a
Authored Apr 08, 2020 by guosheng

Update transformer predict for new reader
Update README

Parent: ee442428
Showing 7 changed files with 300 additions and 84 deletions
transformer/README.md         +56  -33
transformer/gen_data.sh       +220 -0
transformer/predict.py        +9   -4
transformer/reader.py         +13  -4
transformer/run.sh            +0   -41
transformer/train.py          +1   -1
transformer/transformer.yaml  +1   -1
transformer/README.md @ 2292264a
...
@@ -34,8 +34,8 @@
 Clone the repository locally:
 ```shell
-git clone https://github.com/PaddlePaddle/models.git
-cd models/dygraph/transformer
+git clone https://github.com/PaddlePaddle/hapi
+cd hapi/transformer
 ```
 3. Environment dependencies
...
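@@ -62,7 +62,7 @@
 ### Single-machine training
-### Single machine, single GPU
+#### Single machine, single GPU
 Using the provided English-German translation data as an example, run the following command to train the model:
...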
@@ -100,28 +100,31 @@ python -u train.py \
   --prepostprocess_dropout 0.3
 ```
-In addition, if `save_model` is provided when running training (default: trained_models), the current training state is saved to that directory at a fixed iteration interval (set with `save_step`, default 10000). Two files are written, `transformer.pdparams` and `transformer.pdopt`, recording the model parameters and the optimizer state respectively. Every fixed number of iterations (set with `print_step`, default 100), a log like the following is printed to standard output:
+In addition, if `save_model` is provided when running training (default: trained_models), the current training state is saved to that directory after every epoch. Two files are written, `epoch_id.pdparams` and `epoch_id.pdopt`, recording the model parameters and the optimizer state respectively. Every fixed number of iterations (set with `print_step`, default 100), a log like the following is printed to standard output:
 ```txt
-[2019-08-02 15:30:51,656 INFO train.py:262] step_idx: 150100, epoch: 32, batch: 1364, avg loss: 2.880427, normalized loss: 1.504687, ppl: 17.821888, speed: 3.34 step/s
-[2019-08-02 15:31:19,824 INFO train.py:262] step_idx: 150200, epoch: 32, batch: 1464, avg loss: 2.955965, normalized loss: 1.580225, ppl: 19.220257, speed: 3.55 step/s
-[2019-08-02 15:31:48,151 INFO train.py:262] step_idx: 150300, epoch: 32, batch: 1564, avg loss: 2.951180, normalized loss: 1.575439, ppl: 19.128502, speed: 3.53 step/s
-[2019-08-02 15:32:16,401 INFO train.py:262] step_idx: 150400, epoch: 32, batch: 1664, avg loss: 3.027281, normalized loss: 1.651540, ppl: 20.641024, speed: 3.54 step/s
-[2019-08-02 15:32:44,764 INFO train.py:262] step_idx: 150500, epoch: 32, batch: 1764, avg loss: 3.069125, normalized loss: 1.693385, ppl: 21.523066, speed: 3.53 step/s
-[2019-08-02 15:33:13,199 INFO train.py:262] step_idx: 150600, epoch: 32, batch: 1864, avg loss: 2.869379, normalized loss: 1.493639, ppl: 17.626074, speed: 3.52 step/s
-[2019-08-02 15:33:41,601 INFO train.py:262] step_idx: 150700, epoch: 32, batch: 1964, avg loss: 2.980905, normalized loss: 1.605164, ppl: 19.705633, speed: 3.52 step/s
-[2019-08-02 15:34:10,079 INFO train.py:262] step_idx: 150800, epoch: 32, batch: 2064, avg loss: 3.047716, normalized loss: 1.671976, ppl: 21.067181, speed: 3.51 step/s
-[2019-08-02 15:34:38,598 INFO train.py:262] step_idx: 150900, epoch: 32, batch: 2164, avg loss: 2.956475, normalized loss: 1.580735, ppl: 19.230072, speed: 3.51 step/s
+step 500/1 - loss: 7.345725 - normalized loss: 5.969984 - ppl: 1549.557373 - 216ms/step
+step 501/1 - loss: 7.019722 - normalized loss: 5.643982 - ppl: 1118.476196 - 216ms/step
+step 502/1 - loss: 7.271389 - normalized loss: 5.895649 - ppl: 1438.547241 - 216ms/step
+step 503/1 - loss: 7.241495 - normalized loss: 5.865755 - ppl: 1396.179932 - 216ms/step
+step 504/1 - loss: 7.335604 - normalized loss: 5.959863 - ppl: 1533.953613 - 216ms/step
+step 505/1 - loss: 7.388950 - normalized loss: 6.013210 - ppl: 1618.006104 - 216ms/step
+step 506/1 - loss: 7.217984 - normalized loss: 5.842244 - ppl: 1363.737305 - 216ms/step
+step 507/1 - loss: 7.018966 - normalized loss: 5.643226 - ppl: 1117.630615 - 216ms/step
+step 508/1 - loss: 6.923923 - normalized loss: 5.548183 - ppl: 1016.299133 - 216ms/step
+step 509/1 - loss: 7.472060 - normalized loss: 6.096320 - ppl: 1758.225220 - 216ms/step
+step 510/1 - loss: 7.173721 - normalized loss: 5.797981 - ppl: 1304.690063 - 216ms/step
 ```
 Training on CPU is also possible (set with `--use_cuda False`), but it is slower.
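The sample log lines can be sanity-checked with a short script. The sketch below assumes the newer `step N/M - loss: ...` layout seen in the sample; the regular expression is inferred from that sample rather than from the logger's source, and `parse_log_line` is a hypothetical helper. On the sample data, `ppl` matches `exp(loss)` and the gap between `loss` and `normalized loss` stays near 1.3757.

```python
import math
import re

# Pattern inferred from the sample output above (an assumption, not taken
# from the logger's implementation).
LOG_PATTERN = re.compile(
    r"step (?P<step>\d+)/\d+ - loss: (?P<loss>[\d.]+)"
    r" - normalized loss: (?P<norm>[\d.]+) - ppl: (?P<ppl>[\d.]+)")


def parse_log_line(line):
    """Return (step, loss, normalized_loss, ppl), or None if the line does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    return (int(match.group("step")), float(match.group("loss")),
            float(match.group("norm")), float(match.group("ppl")))


line = ("step 500/1 - loss: 7.345725 - normalized loss: 5.969984 "
        "- ppl: 1549.557373 - 216ms/step")
step, loss, norm, ppl = parse_log_line(line)

# Perplexity is the exponential of the loss; the offset is loss minus normalized loss.
print(step, loss - norm, math.isclose(math.exp(loss), ppl, rel_tol=1e-3))
```

A check like this is handy when comparing runs: if the offset changes between two logs, the runs probably used different label smoothing or vocabulary settings.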
 #### Single machine, multiple GPUs
-Paddle dynamic graph mode supports multi-process, multi-GPU model training. Training is launched as follows:
+Multi-process, multi-GPU model training is supported. Training is launched as follows:
 ```sh
-python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 --log_dir ./mylog train.py \
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 train.py \
   --epoch 30 \
   --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
   --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
...
@@ -129,25 +132,27 @@ python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,
   --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
   --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
   --batch_size 4096 \
-  --print_step 100 \
-  --use_cuda True \
-  --save_step 10000
+  --print_step 100
 ```
-With this setup, the log output of each process is written under `./mylog`, and only the first worker process saves the model.
+#### Static graph training
+Training runs in dynamic graph mode by default. To train in static graph mode, set the `eager_run` parameter to False, as follows:
+```sh
+export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+python -m paddle.distributed.launch --started_port 8999 --selected_gpus=0,1,2,3,4,5,6,7 train.py \
+  --epoch 30 \
+  --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+  --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+  --special_token '<s>' '<e>' '<unk>' \
+  --training_file gen_data/wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de \
+  --validation_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+  --batch_size 4096 \
+  --print_step 100 \
+  --eager_run False
 ```
-.
-├── mylog
-│   ├── workerlog.0
-│   ├── workerlog.1
-│   ├── workerlog.2
-│   ├── workerlog.3
-│   ├── workerlog.4
-│   ├── workerlog.5
-│   ├── workerlog.6
-│   └── workerlog.7
-```
 ### Model inference
...
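@@ -163,13 +168,13 @@ python -u predict.py \
   --special_token '<s>' '<e>' '<unk>' \
   --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
   --batch_size 32 \
-  --init_from_params trained_params/step_100000 \
+  --init_from_params base_model_dygraph/step_100000/transformer \
   --beam_size 5 \
   --max_out_len 255 \
   --output_file predict.txt
 ```
-Translations of the text in the file given by `predict_file` are written to the file given by `output_file`. Prediction requires `init_from_params` to point to the directory containing the model. For other parameters, see the comments in `transformer.yaml` and change the settings there. Note that any model hyperparameters set at prediction time must match those used in training. For example, if training used the big model settings, the prediction command looks like this:
+Translations of the text in the file given by `predict_file` are written to the file given by `output_file`. Prediction requires `init_from_params` to give the model file path (without the extension). For other parameters, see the comments in `transformer.yaml` and change the settings there. Note that any model hyperparameters set at prediction time must match those used in training. For example, if training used the big model settings, the prediction command looks like this:
 ```sh
 # setting visible devices for prediction
...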
@@ -181,7 +186,7 @@ python -u predict.py \
   --special_token '<s>' '<e>' '<unk>' \
   --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
   --batch_size 32 \
-  --init_from_params trained_params/step_100000 \
+  --init_from_params base_model_dygraph/step_100000/transformer \
   --beam_size 5 \
   --max_out_len 255 \
   --output_file predict.txt \
...
@@ -191,6 +196,24 @@ python -u predict.py \
   --prepostprocess_dropout 0.3
 ```
+As with training, prediction can also be run in static graph mode, as follows:
+```sh
+# setting visible devices for prediction
+export CUDA_VISIBLE_DEVICES=0
+python -u predict.py \
+  --src_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+  --trg_vocab_fpath gen_data/wmt16_ende_data_bpe/vocab_all.bpe.32000 \
+  --special_token '<s>' '<e>' '<unk>' \
+  --predict_file gen_data/wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
+  --batch_size 32 \
+  --init_from_params base_model_dygraph/step_100000/transformer \
+  --beam_size 5 \
+  --max_out_len 255 \
+  --output_file predict.txt \
+  --eager_run False
+```
 ### Model evaluation
...
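The README change states that `init_from_params` is now a model file path without the extension, and that training writes a `.pdparams` file (model parameters) and a `.pdopt` file (optimizer state). A minimal sketch of what that convention implies is shown below; `checkpoint_files` and `missing_files` are hypothetical helpers for illustration and are not part of the repository.

```python
import os


def checkpoint_files(prefix):
    """Map a checkpoint prefix (no extension) to the two files the README names."""
    return {
        "params": prefix + ".pdparams",    # model parameters
        "optimizer": prefix + ".pdopt",    # optimizer state
    }


def missing_files(prefix):
    """Return the expected checkpoint files that do not exist on disk."""
    return [path for path in checkpoint_files(prefix).values()
            if not os.path.exists(path)]


files = checkpoint_files("base_model_dygraph/step_100000/transformer")
print(files["params"])
print(missing_files("base_model_dygraph/step_100000/transformer"))
```

Running a check like `missing_files` before prediction gives a clearer error than a failed load when the prefix accidentally includes an extension or points to a directory.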
transformer/gen_data.sh (new file, mode 0 → 100644) @ 2292264a
#! /usr/bin/env bash

set -e

OUTPUT_DIR=$PWD/gen_data

###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt16_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt16_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
'http://www.statmt.org/europarl/v7/de-en.tgz'
'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
'http://data.statmt.org/wmt16/translation-task/training-parallel-nc-v11.tgz'
'news-commentary-v11.de-en.en' 'news-commentary-v11.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
DEV_TEST_DATA=(
'http://data.statmt.org/wmt16/translation-task/dev.tgz'
'newstest201[45]-deen-ref.en.sgm' 'newstest201[45]-deen-src.de.sgm'
'http://data.statmt.org/wmt16/translation-task/test.tgz'
'newstest2016-deen-ref.en.sgm' 'newstest2016-deen-src.de.sgm'
)
###############################################################################

###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: ata_url data_tgz data_file
# TRAIN_DATA=(
# 'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
# 'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
# 'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
# 'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
# 'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
# 'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
# 'http://www.statmt.org/wmt10/training-giga-fren.tar'
# 'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
# 'http://www.statmt.org/wmt13/training-parallel-un.tgz'
# 'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
# 'http://data.statmt.org/wmt16/translation-task/dev.tgz'
# '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
# 'http://data.statmt.org/wmt16/translation-task/test.tgz'
# '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################

mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA

# Extract training data
for ((i=0;i<${#TRAIN_DATA[@]};i+=3)); do
  data_url=${TRAIN_DATA[i]}
  data_tgz=${data_url##*/}  # training-parallel-commoncrawl.tgz
  data=${data_tgz%.*}  # training-parallel-commoncrawl
  data_lang1=${TRAIN_DATA[i+1]}
  data_lang2=${TRAIN_DATA[i+2]}
  if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
    echo "Download "${data_url}
    wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
  fi

  if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
    echo "Extract "${data_tgz}
    mkdir -p ${OUTPUT_DIR_DATA}/${data}
    tar_type=${data_tgz:0-3}
    if [ ${tar_type} == "tar" ]; then
      tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    else
      tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    fi
  fi
  # concatenate all training data
  for data_lang in $data_lang1 $data_lang2; do
    for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
      data_dir=`dirname $f`
      data_file=`basename $f`
      f_base=${f%.*}
      f_ext=${f##*.}
      if [ $f_ext == "gz" ]; then
        gunzip $f
        l=${f_base##*.}
        f_base=${f_base%.*}
      else
        l=${f_ext}
      fi

      if [ $i -eq 0 ]; then
        cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
      else
        cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
      fi
    done
  done
done

# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
  echo "Cloning moses for data processing"
  git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi

# Extract develop and test data
dev_test_data=""
for ((i=0;i<${#DEV_TEST_DATA[@]};i+=3)); do
  data_url=${DEV_TEST_DATA[i]}
  data_tgz=${data_url##*/}  # training-parallel-commoncrawl.tgz
  data=${data_tgz%.*}  # training-parallel-commoncrawl
  data_lang1=${DEV_TEST_DATA[i+1]}
  data_lang2=${DEV_TEST_DATA[i+2]}
  if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
    echo "Download "${data_url}
    wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
  fi

  if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
    echo "Extract "${data_tgz}
    mkdir -p ${OUTPUT_DIR_DATA}/${data}
    tar_type=${data_tgz:0-3}
    if [ ${tar_type} == "tar" ]; then
      tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    else
      tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    fi
  fi

  for data_lang in $data_lang1 $data_lang2; do
    for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
      data_dir=`dirname $f`
      data_file=`basename $f`
      data_out=`echo ${data_file} | cut -d '-' -f 1`  # newstest2016
      l=`echo ${data_file} | cut -d '.' -f 2`  # en
      dev_test_data="${dev_test_data}\|${data_out}"  # to make regexp
      if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
        ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
          < $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
      fi
    done
  done
done

# Tokenize data
for l in ${LANG1} ${LANG2}; do
  for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.$l$"`; do
    f_base=${f%.*}  # dir/train dir/newstest2016
    f_out=$f_base.tok.$l
    if [ ! -e $f_out ]; then
      echo "Tokenize "$f
      ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -q -l $l -threads 8 < $f > $f_out
    fi
  done
done

# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
  f_base=${f%.*}  # dir/train dir/train.tok
  f_out=${f_base}.clean
  if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
    echo "Clean "${f_base}
    ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 80
  fi
done

# Clone subword-nmt and generate BPE data
if [ ! -d ${OUTPUT_DIR}/subword-nmt ]; then
  git clone https://github.com/rsennrich/subword-nmt.git ${OUTPUT_DIR}/subword-nmt
fi

# Generate BPE data and vocabulary
for num_operations in 32000; do
  if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
    echo "Learn BPE with ${num_operations} merge operations"
    cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
      ${OUTPUT_DIR}/subword-nmt/learn_bpe.py -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
  fi

  for l in ${LANG1} ${LANG2}; do
    for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
      f_base=${f%.*}  # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
      f_base=${f_base##*/}  # train.tok train.tok.clean newstest2016.tok
      f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
      if [ ! -e $f_out ]; then
        echo "Apply BPE to "$f
        ${OUTPUT_DIR}/subword-nmt/apply_bpe.py -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
      fi
    done
  done

  if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
    echo "Create vocabulary for BPE data"
    cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
      ${OUTPUT_DIR}/subword-nmt/get_vocab.py | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
  fi
done

# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
  f_base=${f%.*}  # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
  f_out=${f_base}.${LANG1}-${LANG2}
  if [ ! -e $f_out ]; then
    paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
  fi
done
if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
  sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi

echo "All done."
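The last step of the script ("Adapt to the reader") joins each source line with its target line using a tab and prepends the three special tokens to the vocabulary. A rough Python equivalent of those two operations is sketched below, working on in-memory lists rather than files. `paste_parallel` and `add_special_tokens` are hypothetical names, and unlike `paste` the sketch rejects inputs of unequal length.

```python
SPECIAL_TOKENS = ["<s>", "<e>", "<unk>"]


def paste_parallel(src_lines, trg_lines):
    """Join aligned source/target sentences with a tab (the paste step)."""
    if len(src_lines) != len(trg_lines):
        # Stricter than paste, which would pad the shorter file with empty fields.
        raise ValueError("source and target must have the same number of lines")
    return [src + "\t" + trg for src, trg in zip(src_lines, trg_lines)]


def add_special_tokens(vocab):
    """Prepend the special tokens to the vocabulary (the sed insertion step)."""
    return SPECIAL_TOKENS + list(vocab)


pairs = paste_parallel(["a b", "c"], ["x", "y z"])
vocab_all = add_special_tokens(["the", "der"])
print(pairs)
print(vocab_all)
```

Placing the special tokens first means they occupy the lowest vocabulary indices, in the same order in which they are passed to `--special_token` in the README commands.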
transformer/predict.py @ 2292264a
...
@@ -77,11 +77,12 @@ def do_predict(args):
         token_delimiter=args.token_delimiter,
         start_mark=args.special_token[0],
         end_mark=args.special_token[1],
-        unk_mark=args.special_token[2])
+        unk_mark=args.special_token[2],
+        byte_data=True)
     args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
         args.unk_idx = dataset.get_vocab_summary()
     trg_idx2word = Seq2SeqDataset.load_dict(
-        dict_path=args.trg_vocab_fpath, reverse=True)
+        dict_path=args.trg_vocab_fpath, reverse=True, byte_data=True)
     batch_sampler = Seq2SeqBatchSampler(
         dataset=dataset,
         use_token_batch=False,
...
@@ -92,7 +93,11 @@ def do_predict(args):
         batch_sampler=batch_sampler,
         places=device,
         collate_fn=partial(
-            prepare_infer_input, src_pad_idx=args.eos_idx, n_head=args.n_head),
+            prepare_infer_input,
+            bos_idx=args.bos_idx,
+            eos_idx=args.eos_idx,
+            src_pad_idx=args.eos_idx,
+            n_head=args.n_head),
         num_workers=0,
         return_list=True)
...
@@ -122,7 +127,7 @@ def do_predict(args):
     # load the trained model
     assert args.init_from_params, (
         "Please set init_from_params to load the infer model.")
-    transformer.load(os.path.join(args.init_from_params, "transformer"))
+    transformer.load(args.init_from_params)

     # TODO: use model.predict when support variant length
     f = open(args.output_file, "wb")
...
transformer/reader.py @ 2292264a
...
@@ -112,12 +112,15 @@ def prepare_train_input(insts, bos_idx, eos_idx, src_pad_idx, trg_pad_idx,
     return data_inputs


-def prepare_infer_input(insts, src_pad_idx, n_head):
+def prepare_infer_input(insts, bos_idx, eos_idx, src_pad_idx, n_head):
     """
     Put all padded data needed by beam search decoder into a list.
     """
     src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
-        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
+        [inst[0] + [eos_idx] for inst in insts],
+        src_pad_idx,
+        n_head,
+        is_target=False)
     trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                 [1, 1, 1, 1]).astype("float32")
     src_word = src_word.reshape(-1, src_max_len)
...
...
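The two changes fit together: `predict.py` now binds `bos_idx` and `eos_idx` into the collate function with `functools.partial`, and `prepare_infer_input` appends `eos_idx` to every source sequence before padding. The sketch below illustrates that pattern with a simplified stand-in. `collate_sources` is hypothetical, it only pads token ids (no position ids or attention bias), and the index values are illustrative.

```python
from functools import partial


def collate_sources(insts, eos_idx, src_pad_idx):
    """Append eos_idx to each source sequence, then pad to the batch max length."""
    seqs = [list(inst[0]) + [eos_idx] for inst in insts]
    max_len = max(len(seq) for seq in seqs)
    padded = [seq + [src_pad_idx] * (max_len - len(seq)) for seq in seqs]
    return padded, max_len


# As in predict.py, the padding index is set to the same value as eos_idx.
collate_fn = partial(collate_sources, eos_idx=1, src_pad_idx=1)

# Each instance is a tuple whose first element is the source token id list.
batch, max_len = collate_fn([([5, 6, 7],), ([8],)])
print(batch, max_len)
```

Binding the indices with `partial` leaves a callable that takes only the list of instances, which is the shape a data loader's `collate_fn` argument expects.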
@@ -487,5 +490,11 @@ class Seq2SeqBatchSampler(BatchSampler):
             yield batch_indices

     def __len__(self):
+        if not self._use_token_batch:
+            batch_number = (
+                len(self._dataset) + self._batch_size * self._nranks - 1) // (
+                    self._batch_size * self._nranks)
+        else:
             # TODO(guosheng): fix the uncertain length
-        return 0
+            batch_number = 1
+        return batch_number
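The new `__len__` uses the integer ceiling-division idiom: adding `divisor - 1` before floor division rounds up whenever there is a remainder. A standalone sketch of the same arithmetic follows, with a hypothetical function name and plain integers in place of the sampler's attributes.

```python
def num_batches(dataset_size, batch_size, nranks):
    """Ceiling of dataset_size / (batch_size * nranks) using integer arithmetic."""
    per_step = batch_size * nranks  # samples consumed per step across all ranks
    return (dataset_size + per_step - 1) // per_step


print(num_batches(100, 32, 1))
print(num_batches(100, 8, 4))
```

Note that in the token-batch branch the diff returns a placeholder of 1, which is likely why the sample training log prints steps as `500/1`.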
transformer/run.sh (deleted, mode 100644 → 0) @ ee442428
python -u train.py \
  --epoch 30 \
  --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
  --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
  --special_token '<s>' '<e>' '<unk>' \
  --training_file wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de.tiny \
  --validation_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
  --batch_size 4096 \
  --print_step 1 \
  --use_cuda True \
  --random_seed 1000 \
  --save_step 10 \
  --eager_run True
#--init_from_pretrain_model base_model_dygraph/step_100000/ \
#--init_from_checkpoint trained_models/step_200/transformer
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3
exit

echo `date`

python -u predict.py \
  --src_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
  --trg_vocab_fpath wmt16_ende_data_bpe/vocab_all.bpe.32000 \
  --special_token '<s>' '<e>' '<unk>' \
  --predict_file wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de \
  --batch_size 64 \
  --init_from_params base_model_dygraph/step_100000/ \
  --beam_size 5 \
  --max_out_len 255 \
  --output_file predict.txt \
  --eager_run True
#--max_length 500 \
#--n_head 16 \
#--d_model 1024 \
#--d_inner_hid 4096 \
#--prepostprocess_dropout 0.3

echo `date`
\ No newline at end of file
transformer/train.py @ 2292264a
...
@@ -34,6 +34,7 @@ from transformer import Transformer, CrossEntropyCriterion
 class TrainCallback(ProgBarLogger):
     def __init__(self, args, verbose=2):
+        # TODO(guosheng): save according to step
         super(TrainCallback, self).__init__(args.print_step, verbose)
         # the best cross-entropy value with label smoothing
         loss_normalizer = -(
...
@@ -141,7 +142,6 @@ def do_train(args):
               eval_freq=1,
               save_freq=1,
               save_dir=args.save_model,
-              verbose=2,
               callbacks=[TrainCallback(args)])
...
transformer/transformer.yaml @ 2292264a
 # used for continuous evaluation
 enable_ce: False

-eager_run: False
+eager_run: True

 # The frequency to save trained models when training.
 save_step: 10000
...
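The README commands pass booleans as text (`--eager_run False`, `--use_cuda True`), while the YAML default for `eager_run` is now `True`. How the repository's scripts convert that text is not shown in this diff. The sketch below is one common way to do it with `argparse`, under the assumption that an explicit converter is needed because `bool("False")` is `True` in Python; `str2bool` is a hypothetical helper.

```python
import argparse


def str2bool(value):
    """Convert command-line text such as 'True' or 'False' to a bool."""
    text = value.strip().lower()
    if text in ("true", "1", "yes"):
        return True
    if text in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got %r" % value)


parser = argparse.ArgumentParser()
# The default mirrors the new transformer.yaml value; the flag name comes from the README.
parser.add_argument("--eager_run", type=str2bool, default=True)

args = parser.parse_args(["--eager_run", "False"])
print(args.eager_run)
```

With a converter like this, omitting the flag keeps dynamic graph mode, and `--eager_run False` selects static graph mode as described in the README.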