PaddlePaddle/models: commit 30ccfc67 (unverified)
Authored by liu zhengxi on Dec 15, 2020; committed via GitHub on Dec 15, 2020.
[Transformer] Simplify transformer reader and fix TranslationDataset (#5035)
* fix translation dataset and simplify transformer reader
Parent: 8c9d8f56
Showing 19 changed files with 352 additions and 1327 deletions (+352 -1327).
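The commit replaces the hand-rolled WMT data pipeline (gen_data.sh plus a file-pattern based reader) with the `WMT14ende` dataset from `paddlenlp.datasets`, which downloads a pre-processed copy of the corpus on first use. A minimal sketch of the new entry points, mirroring the calls added in reader.py below; the `root=None` default cache location is taken from the updated README, and this assumes a paddlenlp build that already contains the `WMT14ende` class added alongside this commit:

```python
# Sketch only: mirrors the calls introduced in reader.py in this commit;
# requires a paddlenlp build that ships the WMT14ende dataset.
from paddlenlp.datasets import WMT14ende

# Vocabulary shared by source and target (BPE with 33708 merge operations).
src_vocab, trg_vocab = WMT14ende.get_vocab(root=None)

# Default transform: token-to-id conversion for the pre-processed corpus.
transform_func = WMT14ende.get_default_transform_func(root=None)

# The first call downloads and extracts the data to
# ~/.paddlenlp/datasets/machine_translation/WMT14ende/ (per the README change).
train_ds = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
```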
PaddleNLP/benchmark/transformer/configs/transformer.big.yaml  +2 -16
PaddleNLP/benchmark/transformer/dygraph/predict.py  +13 -13
PaddleNLP/benchmark/transformer/dygraph/train.py  +20 -23
PaddleNLP/benchmark/transformer/gen_data.sh  +0 -239
PaddleNLP/benchmark/transformer/reader.py  +85 -357
PaddleNLP/benchmark/transformer/static/train.py  +1 -3
PaddleNLP/docs/datasets.md  +1 -1
PaddleNLP/examples/machine_translation/transformer/README.md  +19 -11
PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml  +3 -17
PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml  +3 -17
PaddleNLP/examples/machine_translation/transformer/gen_data.sh  +0 -239
PaddleNLP/examples/machine_translation/transformer/images/multi_head_attention.png  +0 -0
PaddleNLP/examples/machine_translation/transformer/images/transformer_network.png  +0 -0
PaddleNLP/examples/machine_translation/transformer/predict.py  +2 -3
PaddleNLP/examples/machine_translation/transformer/reader.py  +85 -357
PaddleNLP/examples/machine_translation/transformer/train.py  +20 -23
PaddleNLP/paddlenlp/data/sampler.py  +1 -1
PaddleNLP/paddlenlp/data/vocab.py  +7 -1
PaddleNLP/paddlenlp/datasets/translation.py  +90 -6
PaddleNLP/benchmark/transformer/configs/transformer.big.yaml

@@ -10,35 +10,21 @@ init_from_pretrain_model: ""
 init_from_params: "./trained_models/step_final/"
 # The directory for saving model
 save_model: "trained_models"
-# The directory for saving inference model.
-inference_model_dir: "infer_model"
 # Set seed for CE or debug
 random_seed: None
-# The pattern to match training data files.
-training_file: "../gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
-# The pattern to match validation data files.
-validation_file: "../gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
-# The pattern to match test data files.
-predict_file: "../gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
 # The file to output the translation results of predict_file to.
 output_file: "predict.txt"
-# The path of vocabulary file of source language.
-src_vocab_fpath: "../gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
-# The path of vocabulary file of target language.
-trg_vocab_fpath: "../gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
 # The <bos>, <eos> and <unk> tokens in the dictionary.
 special_token: ["<s>", "<e>", "<unk>"]
+# The directory to store data.
+root: None
 # Whether to use cuda
 use_gpu: True

 # Args for reader, see reader.py for details
-token_delimiter: " "
-use_token_batch: True
 pool_size: 200000
 sort_type: "global"
-shuffle: False
-shuffle_batch: False
 batch_size: 4096
 infer_batch_size: 16
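The `root: None` entry added here reaches the new reader as the literal string "None" (plain YAML parses an unquoted `None` as a string, not as null). A minimal sketch of loading such a config into the `args` object the reader expects, assuming the usual pyyaml plus AttrDict combination; the actual config-loading code is not part of this diff:

```python
# Assumption: configs are read with pyyaml into an attribute-style dict; the
# exact loader used by train.py/predict.py lies outside this diff.
import yaml
from attrdict import AttrDict

with open("configs/transformer.big.yaml") as f:
    args = AttrDict(yaml.safe_load(f))

# The new reader treats the literal string "None" as "use the default
# ~/.paddlenlp cache" for the root option added above.
root = None if args.root == "None" else args.root
print(args.batch_size, args.sort_type, root)
```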
PaddleNLP/benchmark/transformer/dygraph/predict.py

@@ -52,8 +52,7 @@ def do_predict(args):
     paddle.set_device(place)

     # Define data loader
-    (test_loader,
-     test_steps_fn), trg_idx2word = reader.create_infer_loader(args)
+    test_loader, to_tokens = reader.create_infer_loader(args)

     # Define model
     transformer = InferTransformerModel(

@@ -90,6 +89,7 @@ def do_predict(args):
     transformer.eval()

     f = open(args.output_file, "w")
+    with paddle.no_grad():
         for (src_word, ) in test_loader:
             finished_seq = transformer(src_word=src_word)
             finished_seq = finished_seq.numpy().transpose([0, 2, 1])

@@ -98,7 +98,7 @@ def do_predict(args):
                 if beam_idx >= args.n_best:
                     break
                 id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
-                word_list = [trg_idx2word[id] for id in id_list]
+                word_list = to_tokens(id_list)
                 sequence = " ".join(word_list) + "\n"
                 f.write(sequence)
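The switch from `trg_idx2word[id]` to `to_tokens(id_list)` only changes where the id-to-token mapping lives: the loader factory now hands the caller a lookup callable instead of a reversed vocabulary dict. A toy, self-contained illustration of that pattern; this is a stand-in, not the paddlenlp implementation:

```python
# Toy stand-in for the new reader contract: the loader factory returns a
# to_tokens callable alongside the data, so callers never build a reversed
# vocabulary dict themselves.
def make_to_tokens(idx2word):
    def to_tokens(ids):
        return [idx2word[i] for i in ids]
    return to_tokens

idx2word = {0: "<s>", 1: "<e>", 2: "<unk>", 3: "hello", 4: "world"}
to_tokens = make_to_tokens(idx2word)
print(" ".join(to_tokens([3, 4])))  # hello world
```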
PaddleNLP/benchmark/transformer/dygraph/train.py

@@ -51,9 +51,7 @@ def do_train(args):
         paddle.seed(random_seed)

     # Define data loader
-    (train_loader, train_steps_fn), (eval_loader,
-        eval_steps_fn) = reader.create_data_loader(args, trainer_count, rank)
+    (train_loader), (eval_loader) = reader.create_data_loader(args)

     # Define model
     transformer = TransformerModel(

@@ -176,7 +174,6 @@ def do_train(args):
             if step_idx % args.save_step == 0 and step_idx != 0:
                 # Validation
-                if args.validation_file:
                 transformer.eval()
                 total_sum_cost = 0
                 total_token_num = 0
PaddleNLP/benchmark/transformer/gen_data.sh (deleted, mode 100644 → 0)

#! /usr/bin/env bash
set -e

OUTPUT_DIR=$PWD/gen_data

###############################################################################
# change these variables for other WMT data
###############################################################################
OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_ende_data"
OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_ende_data_bpe"
LANG1="en"
LANG2="de"
# each of TRAIN_DATA: data_url data_file_lang1 data_file_lang2
TRAIN_DATA=(
    'http://statmt.org/wmt13/training-parallel-europarl-v7.tgz'
    'europarl-v7.de-en.en' 'europarl-v7.de-en.de'
    'http://statmt.org/wmt13/training-parallel-commoncrawl.tgz'
    'commoncrawl.de-en.en' 'commoncrawl.de-en.de'
    'http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz'
    'news-commentary-v12.de-en.en' 'news-commentary-v12.de-en.de'
)
# each of DEV_TEST_DATA: data_url data_file_lang1 data_file_lang2
# source & reference
DEV_TEST_DATA=(
    'http://data.statmt.org/wmt17/translation-task/dev.tgz'
    'newstest2013-ref.de.sgm' 'newstest2013-src.en.sgm'
    'http://statmt.org/wmt14/test-full.tgz'
    'newstest2014-deen-ref.en.sgm' 'newstest2014-deen-src.de.sgm'
)
###############################################################################

###############################################################################
# change these variables for other WMT data
###############################################################################
# OUTPUT_DIR_DATA="${OUTPUT_DIR}/wmt14_enfr_data"
# OUTPUT_DIR_BPE_DATA="${OUTPUT_DIR}/wmt14_enfr_data_bpe"
# LANG1="en"
# LANG2="fr"
# # each of TRAIN_DATA: ata_url data_tgz data_file
# TRAIN_DATA=(
#     'http://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz'
#     'commoncrawl.fr-en.en' 'commoncrawl.fr-en.fr'
#     'http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz'
#     'training/europarl-v7.fr-en.en' 'training/europarl-v7.fr-en.fr'
#     'http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz'
#     'training/news-commentary-v9.fr-en.en' 'training/news-commentary-v9.fr-en.fr'
#     'http://www.statmt.org/wmt10/training-giga-fren.tar'
#     'giga-fren.release2.fixed.en.*' 'giga-fren.release2.fixed.fr.*'
#     'http://www.statmt.org/wmt13/training-parallel-un.tgz'
#     'un/undoc.2000.fr-en.en' 'un/undoc.2000.fr-en.fr'
# )
# # each of DEV_TEST_DATA: data_url data_tgz data_file_lang1 data_file_lang2
# DEV_TEST_DATA=(
#     'http://data.statmt.org/wmt16/translation-task/dev.tgz'
#     '.*/newstest201[45]-fren-ref.en.sgm' '.*/newstest201[45]-fren-src.fr.sgm'
#     'http://data.statmt.org/wmt16/translation-task/test.tgz'
#     '.*/newstest2016-fren-ref.en.sgm' '.*/newstest2016-fren-src.fr.sgm'
# )
###############################################################################

mkdir -p $OUTPUT_DIR_DATA $OUTPUT_DIR_BPE_DATA

# Extract training data
for ((i=0; i<${#TRAIN_DATA[@]}; i+=3)); do
  data_url=${TRAIN_DATA[i]}
  data_tgz=${data_url##*/}  # training-parallel-commoncrawl.tgz
  data=${data_tgz%.*}       # training-parallel-commoncrawl
  data_lang1=${TRAIN_DATA[i+1]}
  data_lang2=${TRAIN_DATA[i+2]}
  if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
    echo "Download "${data_url}
    echo "Dir "${OUTPUT_DIR_DATA}/${data_tgz}
    wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
  fi

  if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
    echo "Extract "${data_tgz}
    mkdir -p ${OUTPUT_DIR_DATA}/${data}
    tar_type=${data_tgz:0-3}
    if [ ${tar_type} == "tar" ]; then
      tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    else
      tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    fi
  fi
  # concatenate all training data
  for data_lang in $data_lang1 $data_lang2; do
    for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
      data_dir=`dirname $f`
      data_file=`basename $f`
      f_base=${f%.*}
      f_ext=${f##*.}
      if [ $f_ext == "gz" ]; then
        gunzip $f
        l=${f_base##*.}
        f_base=${f_base%.*}
      else
        l=${f_ext}
      fi
      if [ $i -eq 0 ]; then
        cat ${f_base}.$l > ${OUTPUT_DIR_DATA}/train.$l
      else
        cat ${f_base}.$l >> ${OUTPUT_DIR_DATA}/train.$l
      fi
    done
  done
done

# Clone mosesdecoder
if [ ! -d ${OUTPUT_DIR}/mosesdecoder ]; then
  echo "Cloning moses for data processing"
  git clone https://github.com/moses-smt/mosesdecoder.git ${OUTPUT_DIR}/mosesdecoder
fi

# Extract develop and test data
dev_test_data=""
for ((i=0; i<${#DEV_TEST_DATA[@]}; i+=3)); do
  data_url=${DEV_TEST_DATA[i]}
  data_tgz=${data_url##*/}
  data=${data_tgz%.*}
  data_lang1=${DEV_TEST_DATA[i+1]}
  data_lang2=${DEV_TEST_DATA[i+2]}
  if [ ! -e ${OUTPUT_DIR_DATA}/${data_tgz} ]; then
    echo "Download "${data_url}
    wget -O ${OUTPUT_DIR_DATA}/${data_tgz} ${data_url}
  fi

  if [ ! -d ${OUTPUT_DIR_DATA}/${data} ]; then
    echo "Extract "${data_tgz}
    mkdir -p ${OUTPUT_DIR_DATA}/${data}
    tar_type=${data_tgz:0-3}
    if [ ${tar_type} == "tar" ]; then
      tar -xvf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    else
      tar -xvzf ${OUTPUT_DIR_DATA}/${data_tgz} -C ${OUTPUT_DIR_DATA}/${data}
    fi
  fi

  for data_lang in $data_lang1 $data_lang2; do
    for f in `find ${OUTPUT_DIR_DATA}/${data} -regex ".*/${data_lang}"`; do
      echo "input-from-sgm"
      data_dir=`dirname $f`
      data_file=`basename $f`
      data_out=`echo ${data_file} | cut -d '-' -f 1`  # newstest2016
      l=`echo ${data_file} | cut -d '.' -f 2`         # en
      dev_test_data="${dev_test_data}\|${data_out}"   # to make regexp
      if [ ! -e ${OUTPUT_DIR_DATA}/${data_out}.$l ]; then
        ${OUTPUT_DIR}/mosesdecoder/scripts/ems/support/input-from-sgm.perl \
          < $f > ${OUTPUT_DIR_DATA}/${data_out}.$l
      fi
    done
  done
done

# Tokenize data
for l in ${LANG1} ${LANG2}; do
  for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train\|newstest2013\)\.$l$"`; do
    f_base=${f%.*}  # dir/train dir/newstest2013
    f_out=$f_base.tok.$l
    f_tmp=$f_base.tmp.$l
    if [ ! -e $f_out ]; then
      echo "Tokenize "$f
      cat $f | \
        ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/normalize-punctuation.perl $l | \
        ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/remove-non-printing-char.perl | \
        tee -a $tmp/valid.raw.$l | \
        ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
      echo $f_out
    fi
  done
done

for l in ${LANG1} ${LANG2}; do
  for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(newstest2014\)\.$l$"`; do
    f_base=${f%.*}  # dir/newstest2014
    f_out=$f_base.tok.$l
    if [ ! -e $f_out ]; then
      echo "Tokenize "$f
      cat $f | \
        ${OUTPUT_DIR}/mosesdecoder/scripts/tokenizer/tokenizer.perl -a -l $l -threads 8 >> $f_out
      echo $f_out
    fi
  done
done

# Clean data
for f in ${OUTPUT_DIR_DATA}/train.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.${LANG1}; do
  f_base=${f%.*}  # dir/train dir/train.tok
  f_out=${f_base}.clean
  if [ ! -e $f_out.${LANG1} ] && [ ! -e $f_out.${LANG2} ]; then
    echo "Clean "${f_base}
    ${OUTPUT_DIR}/mosesdecoder/scripts/training/clean-corpus-n.perl $f_base ${LANG1} ${LANG2} ${f_out} 1 256
  fi
done

python -m pip install subword-nmt

# Generate BPE data and vocabulary
for num_operations in 33708; do
  if [ ! -e ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} ]; then
    echo "Learn BPE with ${num_operations} merge operations"
    cat ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG1} ${OUTPUT_DIR_DATA}/train.tok.clean.${LANG2} | \
      subword-nmt learn-bpe -s $num_operations > ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations}
  fi

  for l in ${LANG1} ${LANG2}; do
    for f in `ls ${OUTPUT_DIR_DATA}/*.$l | grep "\(train${dev_test_data}\)\.tok\(\.clean\)\?\.$l$"`; do
      f_base=${f%.*}        # dir/train.tok dir/train.tok.clean dir/newstest2016.tok
      f_base=${f_base##*/}  # train.tok train.tok.clean newstest2016.tok
      f_out=${OUTPUT_DIR_BPE_DATA}/${f_base}.bpe.${num_operations}.$l
      if [ ! -e $f_out ]; then
        echo "Apply BPE to "$f
        subword-nmt apply-bpe -c ${OUTPUT_DIR_BPE_DATA}/bpe.${num_operations} < $f > $f_out
      fi
    done
  done

  if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} ]; then
    echo "Create vocabulary for BPE data"
    cat ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG1} ${OUTPUT_DIR_BPE_DATA}/train.tok.clean.bpe.${num_operations}.${LANG2} | \
      subword-nmt get-vocab | cut -f1 -d ' ' > ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations}
  fi
done

# Adapt to the reader
for f in ${OUTPUT_DIR_BPE_DATA}/*.bpe.${num_operations}.${LANG1}; do
  f_base=${f%.*}  # dir/train.tok.clean.bpe.32000 dir/newstest2016.tok.bpe.32000
  f_out=${f_base}.${LANG1}-${LANG2}
  if [ ! -e $f_out ]; then
    paste -d '\t' $f_base.${LANG1} $f_base.${LANG2} > $f_out
  fi
done

if [ ! -e ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations} ]; then
  sed '1i\<s>\n<e>\n<unk>' ${OUTPUT_DIR_BPE_DATA}/vocab.bpe.${num_operations} > ${OUTPUT_DIR_BPE_DATA}/vocab_all.bpe.${num_operations}
fi

echo "All done."
PaddleNLP/benchmark/transformer/reader.py

@@ -22,79 +22,105 @@ from functools import partial
 import numpy as np

 from paddle.io import BatchSampler, DataLoader, Dataset
 from paddlenlp.data import Pad
+from paddlenlp.datasets import WMT14ende
+from paddlenlp.data.sampler import SamplerHelper

The former create_infer_loader(args) and create_data_loader(args, world_size=1, rank=0) built a TransformerDataset from args.predict_file / args.training_file / args.validation_file and the vocabulary files named in the config, read the vocabulary sizes and special-token indices from dataset.get_vocab_summary(), built a reversed trg_idx2word dict, and wrapped everything in a TransformerBatchSampler. They are replaced by loaders built on the WMT14ende dataset:

+def min_max_filer(data, max_len, min_len=0):
+    # 1 for special tokens.
+    data_min_len = min(len(data[0]), len(data[1])) + 1
+    data_max_len = max(len(data[0]), len(data[1])) + 1
+    return (data_min_len >= min_len) and (data_max_len <= max_len)
+
+
+def create_data_loader(args):
+    root = None if args.root == "None" else args.root
+    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
+    transform_func = WMT14ende.get_default_transform_func(root=root)
+    datasets = [
+        WMT14ende.get_datasets(
+            mode=m, transform_func=transform_func) for m in ["train", "dev"]
+    ]
+
+    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
+                      data_source):
+        return max(tokens_sofar,
+                   len(data_source[current_idx][0]) + 1,
+                   len(data_source[current_idx][1]) + 1)
+
+    def _key(size_so_far, minibatch_len):
+        return size_so_far * minibatch_len
+
+    data_loaders = [(None)] * 2
+    for i, dataset in enumerate(datasets):
+        m = dataset.mode
+        dataset = dataset.filter(
+            partial(min_max_filer, max_len=args.max_length))
+        sampler = SamplerHelper(dataset)
+
+        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
+        if args.sort_type == SortType.GLOBAL:
+            buffer_size = -1
+            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
+            # Sort twice
+            sampler = sampler.sort(
+                key=trg_key, buffer_size=buffer_size).sort(
+                    key=src_key, buffer_size=buffer_size)
+        else:
+            sampler = sampler.shuffle()
+            if args.sort_type == SortType.POOL:
+                buffer_size = args.pool_size
+                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
+
+        batch_sampler = sampler.batch(
+            batch_size=args.batch_size,
+            drop_last=False,
+            batch_size_fn=_max_token_fn,
+            key=_key)
+
+        if m == "train":
+            batch_sampler = batch_sampler.shard()
+
+        data_loader = DataLoader(
+            dataset=dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=partial(
+                prepare_train_input,
+                bos_idx=args.bos_idx,
+                eos_idx=args.eos_idx,
+                pad_idx=args.bos_idx),
+            num_workers=0,
+            return_list=True)
+        data_loaders[i] = (data_loader)
+    return data_loaders
+
+
+def create_infer_loader(args):
+    root = None if args.root == "None" else args.root
+    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
+    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
+    transform_func = WMT14ende.get_default_transform_func(root=root)
+    dataset = WMT14ende.get_datasets(
+        mode="test", transform_func=transform_func).filter(
+            partial(min_max_filer, max_len=args.max_length))
+
+    batch_sampler = SamplerHelper(dataset).batch(
+        batch_size=args.infer_batch_size, drop_last=False)
+
+    data_loader = DataLoader(
+        dataset=dataset,
+        batch_sampler=batch_sampler,
+        collate_fn=partial(
+            prepare_infer_input,
+            bos_idx=args.bos_idx,
+            eos_idx=args.eos_idx,
+            pad_idx=args.bos_idx),
+        num_workers=0,
+        return_list=True)
+    return data_loader, trg_vocab.to_tokens


 def prepare_train_input(insts, bos_idx, eos_idx, pad_idx):

@@ -126,301 +152,3 @@ class SortType(object):
     GLOBAL = 'global'
     POOL = 'pool'
     NONE = "none"

This hunk deletes the now-unused helper classes: Converter, ComposedConverter, SentenceBatchCreator, TokenBatchCreator, SampleInfo, MinMaxFilter, TransformerDataset (including load_dict, load_src_trg_ids, _load_lines and get_vocab_summary) and TransformerBatchSampler (the global/pool sorting, token-count batching and multi-device sharding logic that SamplerHelper now covers).
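The new `min_max_filer` above is the predicate passed to `dataset.filter(...)`; a quick self-contained check of its behaviour on (source_ids, target_ids) pairs:

```python
# Copied from the new reader.py above; the +1 accounts for the special token
# that will be appended to each side.
def min_max_filer(data, max_len, min_len=0):
    data_min_len = min(len(data[0]), len(data[1])) + 1
    data_max_len = max(len(data[0]), len(data[1])) + 1
    return (data_min_len >= min_len) and (data_max_len <= max_len)

print(min_max_filer(([1, 2, 3], [4, 5]), max_len=4))     # True: longest side is 3 + 1
print(min_max_filer(([1, 2, 3, 4, 5], [6]), max_len=4))  # False: source side is 5 + 1 > 4
```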
PaddleNLP/benchmark/transformer/static/train.py

@@ -63,9 +63,7 @@ def do_train(args):
         paddle.seed(random_seed)

     # Define data loader
-    # NOTE: To guarantee all data is involved, use world_size=1 and rank=0.
-    (train_loader, train_steps_fn), (eval_loader,
-        eval_steps_fn) = reader.create_data_loader(args)
+    (train_loader), (eval_loader) = reader.create_data_loader(args)

     train_program = paddle.static.Program()
     startup_program = paddle.static.Program()
PaddleNLP/docs/datasets.md

@@ -39,7 +39,7 @@
 | Dataset | Description | Usage |
 | ---- | --------- | ------ |
 | [IWSLT15](https://workshop2015.iwslt.org/) | IWSLT'15 English-Vietnamese translation dataset | `paddlenlp.datasets.IWSLT15` |
-| [WMT14](http://www.statmt.org/wmt14/translation-task.html) | WMT14 EN-DE English-German translation dataset | `paddlenlp.datasets.WMT14` |
+| [WMT14](http://www.statmt.org/wmt14/translation-task.html) | WMT14 EN-DE English-German translation dataset | `paddlenlp.datasets.WMT14ende` |

 ## Time series forecasting
PaddleNLP/examples/machine_translation/transformer/README.md

@@ -4,13 +4,12 @@
 ```text
 .
-├── images             # images used in this README
+├── images/            # images used in this README
 ├── predict.py         # prediction script
 ├── reader.py          # data reading interface
 ├── README.md          # this document
 ├── train.py           # training script
-├── transformer.py     # model definition
-└── transformer.yaml   # configuration file
+└── configs/           # configuration files
 ```

@@ -46,6 +45,15 @@
 Public dataset: the WMT shared task is the most authoritative international evaluation in machine translation. Its English-German task provides a medium-sized dataset that is used in many papers, including the original Transformer paper. We provide the [WMT'14 EN-DE dataset](http://www.statmt.org/wmt14/translation-task.html) as an example.
+
+We also ship a pre-processed copy of this dataset. With the code below, it is downloaded and extracted automatically to `~/.paddlenlp/datasets/machine_translation/WMT14ende/`.
+
+```python
+# Get the default data transformation
+transform_func = WMT14ende.get_default_transform_func(root=root)
+# Download and process the WMT14 en-de translation dataset
+dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
+```

 ### Single-machine training
 ### Single machine, single card

@@ -55,10 +63,10 @@
 ```sh
 # setting visible devices for training
 export CUDA_VISIBLE_DEVICES=0
-python train.py
+python train.py --config ./configs/transformer.base.yaml
 ```

-The relevant parameters, such as `max_iter` (the maximum number of iterations), can be set in the transformer.yaml file.
+The relevant parameters can be set in `configs/transformer.big.yaml` and `configs/transformer.base.yaml`.

 ### Single machine, multiple cards

@@ -66,7 +74,7 @@ python train.py
 ```sh
 export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py
+python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train.py --config ./configs/transformer.base.yaml
 ```

@@ -80,7 +88,7 @@ export CUDA_VISIBLE_DEVICES=0
 python predict.py
 ```

-The translations of the file given by `predict_file` are written to the file given by `output_file`. For prediction, `init_from_params` must point to the directory of the trained model; further parameters are documented in the comments of `transformer.yaml`. Note that prediction currently supports a single card only, because the subsequent model evaluation depends on the order in which results are written to the file, and that order cannot be guaranteed with multiple cards.
+The translations of the file given by `predict_file` are written to the file given by `output_file`. For prediction, `init_from_params` must point to the directory of the trained model; further parameters are documented in the comments of `configs/transformer.big.yaml` and `configs/transformer.base.yaml`. Note that prediction currently supports a single card only, because the subsequent model evaluation depends on the order in which results are written to the file, and that order cannot be guaranteed with multiple cards.

 ### Model evaluation

@@ -91,13 +99,13 @@ python predict.py
 # restore the predictions in predict.txt to tokenized text
 sed -r 's/(@@ )|(@@ ?$)//g' predict.txt > predict.tok.txt
 # if the BLEU evaluation tool is not available, download it first
-# git clone https://github.com/moses-smt/mosesdecoder.git
+git clone https://github.com/moses-smt/mosesdecoder.git
 # taking the English-German newstest2014 test data as an example
-perl gen_data/mosesdecoder/scripts/generic/multi-bleu.perl gen_data/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt
+perl mosesdecoder/scripts/generic/multi-bleu.perl ~/.paddlenlp/datasets/machine_translation/WMT14ende/WMT14.en-de/wmt14_ende_data/newstest2014.tok.de < predict.tok.txt
 ```

-The output looks like the following:
+The output looks like the following (here: the big model on newstest2014):
 ```
-BLEU = 26.35, 57.7/32.1/20.0/13.0 (BP=1.000, ratio=1.013, hyp_len=63903, ref_len=63078)
+BLEU = 27.48, 58.6/33.2/21.1/13.9 (BP=1.000, ratio=1.012, hyp_len=65312, ref_len=64506)
 ```

 ## Advanced usage
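The evaluation step in the README undoes BPE with a `sed` one-liner before scoring. For readers who post-process `predict.txt` from a script instead, the same restore can be done in Python; `restore_bpe` is a hypothetical helper whose regex simply mirrors the `sed` expression above, and the sample input is made up for illustration:

```python
# Hypothetical helper: Python equivalent of `sed -r 's/(@@ )|(@@ ?$)//g'`,
# which merges BPE sub-word pieces back into whole tokens.
import re

def restore_bpe(line: str) -> str:
    return re.sub(r"(@@ )|(@@ ?$)", "", line)

print(restore_bpe("Maschin@@ elle Übers@@ etzung"))  # Maschinelle Übersetzung
```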
PaddleNLP/examples/machine_translation/transformer/configs/transformer.base.yaml

@@ -10,37 +10,23 @@ init_from_pretrain_model: ""
 init_from_params: "./trained_models/step_final/"
 # The directory for saving model
 save_model: "trained_models"
-# The directory for saving inference model.
-inference_model_dir: "infer_model"
 # Set seed for CE or debug
 random_seed: None
-# The pattern to match training data files.
-training_file: "gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
-# The pattern to match validation data files.
-validation_file: "gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
-# The pattern to match test data files.
-predict_file: "gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
 # The file to output the translation results of predict_file to.
 output_file: "predict.txt"
-# The path of vocabulary file of source language.
-src_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
-# The path of vocabulary file of target language.
-trg_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
 # The <bos>, <eos> and <unk> tokens in the dictionary.
 special_token: ["<s>", "<e>", "<unk>"]
+# The directory to store data.
+root: None
 # Whether to use cuda
 use_gpu: True

 # Args for reader, see reader.py for details
-token_delimiter: " "
-use_token_batch: True
 pool_size: 200000
 sort_type: "pool"
-shuffle: True
-shuffle_batch: True
 batch_size: 4096
-infer_batch_size: 32
+infer_batch_size: 8

 # Hyparams for training:
 # The number of epoches for training
PaddleNLP/examples/machine_translation/transformer/configs/transformer.big.yaml

@@ -10,37 +10,23 @@ init_from_pretrain_model: ""
 init_from_params: "./trained_models/step_final/"
 # The directory for saving model
 save_model: "trained_models"
-# The directory for saving inference model.
-inference_model_dir: "infer_model"
 # Set seed for CE or debug
 random_seed: None
-# The pattern to match training data files.
-training_file: "gen_data/wmt14_ende_data_bpe/train.tok.clean.bpe.33708.en-de"
-# The pattern to match validation data files.
-validation_file: "gen_data/wmt14_ende_data_bpe/newstest2013.tok.bpe.33708.en-de"
-# The pattern to match test data files.
-predict_file: "gen_data/wmt14_ende_data_bpe/newstest2014.tok.bpe.33708.en-de"
 # The file to output the translation results of predict_file to.
 output_file: "predict.txt"
-# The path of vocabulary file of source language.
-src_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
-# The path of vocabulary file of target language.
-trg_vocab_fpath: "gen_data/wmt14_ende_data_bpe/vocab_all.bpe.33708"
 # The <bos>, <eos> and <unk> tokens in the dictionary.
 special_token: ["<s>", "<e>", "<unk>"]
+# The directory to store data.
+root: None
 # Whether to use cuda
 use_gpu: True

 # Args for reader, see reader.py for details
-token_delimiter: " "
-use_token_batch: True
 pool_size: 200000
 sort_type: "pool"
-shuffle: True
-shuffle_batch: True
 batch_size: 4096
-infer_batch_size: 16
+infer_batch_size: 8

 # Hyparams for training:
 # The number of epoches for training
PaddleNLP/examples/machine_translation/transformer/gen_data.sh (deleted, mode 100644 → 0)

The deleted script is identical to PaddleNLP/benchmark/transformer/gen_data.sh shown above.
PaddleNLP/examples/machine_translation/transformer/images/multi_head_attention.png (new file, 104.5 KB)

PaddleNLP/examples/machine_translation/transformer/images/transformer_network.png (new file, 259.1 KB)
PaddleNLP/examples/machine_translation/transformer/predict.py

@@ -48,8 +48,7 @@ def do_predict(args):
     paddle.set_device(place)

     # Define data loader
-    (test_loader,
-     test_steps_fn), trg_idx2word = reader.create_infer_loader(args)
+    test_loader, to_tokens = reader.create_infer_loader(args)

     # Define model
     transformer = InferTransformerModel(

@@ -95,7 +94,7 @@ def do_predict(args):
                 if beam_idx >= args.n_best:
                     break
                 id_list = post_process_seq(beam, args.bos_idx, args.eos_idx)
-                word_list = [trg_idx2word[id] for id in id_list]
+                word_list = to_tokens(id_list)
                 sequence = " ".join(word_list) + "\n"
                 f.write(sequence)
PaddleNLP/examples/machine_translation/transformer/reader.py
浏览文件 @
30ccfc67
...
@@ -22,79 +22,105 @@ from functools import partial
...
@@ -22,79 +22,105 @@ from functools import partial
import
numpy
as
np
import
numpy
as
np
from
paddle.io
import
BatchSampler
,
DataLoader
,
Dataset
from
paddle.io
import
BatchSampler
,
DataLoader
,
Dataset
from
paddlenlp.data
import
Pad
from
paddlenlp.data
import
Pad
from
paddlenlp.datasets
import
WMT14ende
from
paddlenlp.data.sampler
import
SamplerHelper
def
create_infer_loader
(
args
):
def
min_max_filer
(
data
,
max_len
,
min_len
=
0
):
dataset
=
TransformerDataset
(
# 1 for special tokens.
fpattern
=
args
.
predict_file
,
data_min_len
=
min
(
len
(
data
[
0
]),
len
(
data
[
1
]))
+
1
src_vocab_fpath
=
args
.
src_vocab_fpath
,
data_max_len
=
max
(
len
(
data
[
0
]),
len
(
data
[
1
]))
+
1
trg_vocab_fpath
=
args
.
trg_vocab_fpath
,
return
(
data_min_len
>=
min_len
)
and
(
data_max_len
<=
max_len
)
token_delimiter
=
args
.
token_delimiter
,
start_mark
=
args
.
special_token
[
0
],
end_mark
=
args
.
special_token
[
1
],
def
create_data_loader
(
args
):
unk_mark
=
args
.
special_token
[
2
])
root
=
None
if
args
.
root
==
"None"
else
args
.
root
args
.
src_vocab_size
,
args
.
trg_vocab_size
,
args
.
bos_idx
,
args
.
eos_idx
,
\
(
src_vocab
,
trg_vocab
)
=
WMT14ende
.
get_vocab
(
root
=
root
)
args
.
unk_idx
=
dataset
.
get_vocab_summary
()
args
.
src_vocab_size
,
args
.
(This hunk continues the rewritten reader. The old `create_infer_loader` / `create_data_loader` are deleted: they built a `TransformerDataset` from `args.src_vocab_fpath` / `args.trg_vocab_fpath` / `args.token_delimiter` / `args.special_token`, recovered `trg_idx2word` with `TransformerDataset.load_dict(dict_path=args.trg_vocab_fpath, reverse=True)`, read `args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, args.unk_idx` from `dataset.get_vocab_summary()`, batched with `TransformerBatchSampler` (`batch_size`, `pool_size`, `sort_type`, `shuffle`, `shuffle_batch`, `use_token_batch`, `max_length`, `distribute_mode`, `world_size`, `rank`; `use_token_batch=False` and `batch_size=args.infer_batch_size` for inference), and returned `(data_loader, batch_sampler.__len__)` pairs plus `trg_idx2word`. The new implementations, built on `WMT14ende` and `SamplerHelper`, continue below; `create_data_loader(args)` has just resolved `root` and loaded the vocabularies via `WMT14ende.get_vocab(root=root)`.)

    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
    transform_func = WMT14ende.get_default_transform_func(root=root)
    datasets = [
        WMT14ende.get_datasets(
            mode=m, transform_func=transform_func) for m in ["train", "dev"]
    ]

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                      data_source):
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        return size_so_far * minibatch_len

    data_loaders = [(None)] * 2
    for i, dataset in enumerate(datasets):
        m = dataset.mode
        dataset = dataset.filter(
            partial(
                min_max_filer, max_len=args.max_length))
        sampler = SamplerHelper(dataset)

        src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
        if args.sort_type == SortType.GLOBAL:
            buffer_size = -1
            trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
            # Sort twice
            sampler = sampler.sort(
                key=trg_key, buffer_size=buffer_size).sort(
                    key=src_key, buffer_size=buffer_size)
        else:
            sampler = sampler.shuffle()
            if args.sort_type == SortType.POOL:
                buffer_size = args.pool_size
                sampler = sampler.sort(key=src_key, buffer_size=buffer_size)

        batch_sampler = sampler.batch(
            batch_size=args.batch_size,
            drop_last=False,
            batch_size_fn=_max_token_fn,
            key=_key)

        if m == "train":
            batch_sampler = batch_sampler.shard()

        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            collate_fn=partial(
                prepare_train_input,
                bos_idx=args.bos_idx,
                eos_idx=args.eos_idx,
                pad_idx=args.bos_idx),
            num_workers=0,
            return_list=True)
        data_loaders[i] = (data_loader)
    return data_loaders


def create_infer_loader(args):
    root = None if args.root == "None" else args.root
    (src_vocab, trg_vocab) = WMT14ende.get_vocab(root=root)
    args.src_vocab_size, args.trg_vocab_size = len(src_vocab), len(trg_vocab)
    transform_func = WMT14ende.get_default_transform_func(root=root)
    dataset = WMT14ende.get_datasets(
        mode="test", transform_func=transform_func).filter(
            partial(
                min_max_filer, max_len=args.max_length))

    batch_sampler = SamplerHelper(dataset).batch(
        batch_size=args.infer_batch_size, drop_last=False)

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=partial(
            prepare_infer_input,
            bos_idx=args.bos_idx,
            eos_idx=args.eos_idx,
            pad_idx=args.bos_idx),
        num_workers=0,
        return_list=True)
    return data_loader, trg_vocab.to_tokens


def prepare_train_input(insts, bos_idx, eos_idx, pad_idx):
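For orientation, a minimal sketch of how the simplified loaders are meant to be consumed. It is not part of the diff: the config values are illustrative, and only the attributes that the code above actually reads are filled in.

    # Sketch only. `reader` is the module patched in this commit; run from the
    # example directory so the import resolves.
    from types import SimpleNamespace

    import reader

    args = SimpleNamespace(
        root=None,            # let WMT14ende download to the default data home
        sort_type="pool",     # SortType.POOL; 'global' and 'none' also exist
        pool_size=200000,
        batch_size=4096,      # token-level budget, see batch_size_fn/_key above
        infer_batch_size=8,
        max_length=256,
        bos_idx=0,            # illustrative special-token ids; the collate_fn
        eos_idx=1)            # pads with bos_idx, as in the code above

    # create_data_loader also fills in args.src_vocab_size / args.trg_vocab_size.
    (train_loader), (eval_loader) = reader.create_data_loader(args)
    test_loader, to_tokens = reader.create_infer_loader(args)

    for batch in train_loader:
        # prepare_train_input defines the exact tuple layout of each batch;
        # every field is already padded and converted to tensors here.
        print([t.shape for t in batch])
        break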
@@ -126,301 +152,3 @@ class SortType(object):
     GLOBAL = 'global'
     POOL = 'pool'
     NONE = "none"

Removed in this commit (the legacy reader machinery that the WMT14ende / SamplerHelper based loaders replace):

class Converter(object):
    def __init__(self, vocab, beg, end, unk, delimiter, add_beg, add_end):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter
        self._add_beg = add_beg
        self._add_end = add_end

    def __call__(self, sentence):
        return ([self._beg] if self._add_beg else []) + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + ([self._end] if self._add_end else [])


class ComposedConverter(object):
    def __init__(self, converters):
        self._converters = converters

    def __call__(self, fields):
        return [
            converter(field)
            for field, converter in zip(fields, self._converters)
        ]


class SentenceBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        self.batch.append(info)
        if len(self.batch) == self._batch_size:
            tmp = self.batch
            self.batch = []
            return tmp


class TokenBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        if max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        else:
            self.max_len = max_len
            self.batch.append(info)


class SampleInfo(object):
    def __init__(self, i, lens):
        self.i = i
        # take bos and eos into account
        self.min_len = min(lens[0] + 1, lens[1] + 1)
        self.max_len = max(lens[0] + 1, lens[1] + 1)
        self.src_len = lens[0]
        self.trg_len = lens[1]


class MinMaxFilter(object):
    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return
        else:
            return self._creator.append(info)

    @property
    def batch(self):
        return self._creator.batch


class TransformerDataset(Dataset):
    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 trg_fpattern=None):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._trg_vocab = self.load_dict(trg_vocab_fpath)
        self._bos_idx = self._src_vocab[start_mark]
        self._eos_idx = self._src_vocab[end_mark]
        self._unk_idx = self._src_vocab[unk_mark]
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(fpattern, trg_fpattern)

    def load_src_trg_ids(self, fpattern, trg_fpattern=None):
        src_converter = Converter(
            vocab=self._src_vocab,
            beg=self._bos_idx,
            end=self._eos_idx,
            unk=self._unk_idx,
            delimiter=self._token_delimiter,
            add_beg=False,
            add_end=False)

        trg_converter = Converter(
            vocab=self._trg_vocab,
            beg=self._bos_idx,
            end=self._eos_idx,
            unk=self._unk_idx,
            delimiter=self._token_delimiter,
            add_beg=False,
            add_end=False)

        converters = ComposedConverter([src_converter, trg_converter])

        self._src_seq_ids = []
        self._trg_seq_ids = []
        self._sample_infos = []

        slots = [self._src_seq_ids, self._trg_seq_ids]
        for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
            lens = []
            for field, slot in zip(converters(line), slots):
                slot.append(field)
                lens.append(len(field))
            self._sample_infos.append(SampleInfo(i, lens))

    def _load_lines(self, fpattern, trg_fpattern=None):
        fpaths = glob.glob(fpattern)
        fpaths = sorted(fpaths)  # TODO: Add custum sort
        assert len(fpaths) > 0, "no matching file to the provided data path"

        (f_mode, f_encoding, endl) = ("r", "utf8", "\n")
        if trg_fpattern is None:
            for fpath in fpaths:
                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                    for line in f:
                        fields = line.strip(endl).split(self._field_delimiter)
                        yield fields
        else:
            # separated source and target language data files
            # assume we can get aligned data by sort the two language files
            trg_fpaths = glob.glob(trg_fpattern)
            trg_fpaths = sorted(trg_fpaths)
            assert len(fpaths) == len(
                trg_fpaths
            ), "the number of source language data files must equal \
                with that of source language"
            for fpath, trg_fpath in zip(fpaths, trg_fpaths):
                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                    with io.open(
                            trg_fpath, f_mode, encoding=f_encoding) as trg_f:
                        for line in zip(f, trg_f):
                            fields = [field.strip(endl) for field in line]
                            yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False):
        word_dict = {}
        (f_mode, f_encoding, endl) = ("r", "utf8", "\n")
        with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip(endl)
                else:
                    word_dict[line.strip(endl)] = idx
        return word_dict

    def get_vocab_summary(self):
        return len(self._src_vocab), len(
            self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx

    def __getitem__(self, idx):
        return (self._src_seq_ids[idx], self._trg_seq_ids[idx]
                ) if self._trg_seq_ids else self._src_seq_ids[idx]

    def __len__(self):
        return len(self._sample_infos)


class TransformerBatchSampler(BatchSampler):
    def __init__(self,
                 dataset,
                 batch_size,
                 pool_size=10000,
                 sort_type=SortType.NONE,
                 min_length=0,
                 max_length=100,
                 shuffle=False,
                 shuffle_batch=False,
                 use_token_batch=False,
                 clip_last_batch=False,
                 distribute_mode=True,
                 seed=0,
                 world_size=1,
                 rank=0):
        for arg, value in locals().items():
            if arg != "self":
                setattr(self, "_" + arg, value)
        self._random = np.random
        self._random.seed(seed)
        # for multi-devices
        self._distribute_mode = distribute_mode
        self._nranks = world_size
        self._local_rank = rank

    def __iter__(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(
                self._dataset._sample_infos, key=lambda x: x.trg_len)
            infos = sorted(infos, key=lambda x: x.src_len)
        else:
            if self._shuffle:
                infos = self._dataset._sample_infos
                self._random.shuffle(infos)
            else:
                infos = self._dataset._sample_infos

            if self._sort_type == SortType.POOL:
                reverse = True
                for i in range(0, len(infos), self._pool_size):
                    # to avoid placing short next to long sentences
                    reverse = not reverse
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size],
                        key=lambda x: x.max_len,
                        reverse=reverse)

        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(
            self._batch_size * self._nranks)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        if not self._use_token_batch:
            # when producing batches according to sequence number, to confirm
            # neighbor batches which would be feed and run parallel have similar
            # length (thus similar computational cost) after shuffle, we as take
            # them as a whole when shuffling and split here
            batches = [[
                batch[self._batch_size * i:self._batch_size * (i + 1)]
                for i in range(self._nranks)
            ] for batch in batches]
            batches = list(itertools.chain.from_iterable(batches))
        self.batch_number = (len(batches) + self._nranks - 1) // self._nranks

        # for multi-device
        for batch_id, batch in enumerate(batches):
            if not self._distribute_mode or (
                    batch_id % self._nranks == self._local_rank):
                batch_indices = [info.i for info in batch]
                yield batch_indices
        if self._distribute_mode and len(batches) % self._nranks != 0:
            if self._local_rank >= len(batches) % self._nranks:
                # use previous data to pad
                yield batch_indices

    def __len__(self):
        if hasattr(self, "batch_number"):  #
            return self.batch_number
        if not self._use_token_batch:
            batch_number = (
                len(self._dataset) + self._batch_size * self._nranks - 1) // (
                    self._batch_size * self._nranks)
        else:
            # for uncertain batch number, the actual value is self.batch_number
            batch_number = sys.maxsize
        return batch_number
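The removed TokenBatchCreator capped a batch by requiring max_len * (len(batch) + 1) <= batch_size; in the new reader the same budget is expressed through SamplerHelper.batch(batch_size_fn=_max_token_fn, key=_key). The following is a plain-Python restatement of that contract, not the library implementation; the sample lengths are made up, and over-long samples are assumed to have been dropped already (min_max_filer does that in the real reader).

    # Toy (src_ids, trg_ids) pairs; only their lengths matter here.
    data_source = [([1] * 5, [2] * 7), ([1] * 12, [2] * 9), ([1] * 30, [2] * 28)]

    def _max_token_fn(current_idx, current_batch_size, tokens_sofar, data_source):
        # Longest sequence seen so far in the growing batch, counting bos/eos.
        return max(tokens_sofar,
                   len(data_source[current_idx][0]) + 1,
                   len(data_source[current_idx][1]) + 1)

    def _key(size_so_far, minibatch_len):
        # Effective cost of the padded batch: longest length * sentence count.
        return size_so_far * minibatch_len

    def token_batches(indices, data_source, batch_size):
        """Pure-Python restatement of the token-bucket batching contract."""
        batch, size_so_far = [], 0
        for idx in indices:
            batch.append(idx)
            size_so_far = _max_token_fn(idx, len(batch), size_so_far, data_source)
            if _key(size_so_far, len(batch)) > batch_size:
                # Budget exceeded: emit what fit, restart from the new sample.
                yield batch[:-1]
                batch = [idx]
                size_so_far = _max_token_fn(idx, 1, 0, data_source)
        if batch:
            yield batch

    print(list(token_batches(range(len(data_source)), data_source, batch_size=40)))
    # -> [[0, 1], [2]]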
PaddleNLP/examples/machine_translation/transformer/train.py  View file @ 30ccfc67

@@ -43,9 +43,7 @@ def do_train(args):
         paddle.seed(random_seed)

     # Define data loader
-    (train_loader, train_steps_fn), (eval_loader,
-                                     eval_steps_fn) = reader.create_data_loader(
-                                         args, trainer_count, rank)
+    (train_loader), (eval_loader) = reader.create_data_loader(args)

     # Define model
     transformer = TransformerModel(
@@ -150,7 +148,6 @@ def do_train(args):
         if step_idx % args.save_step == 0 and step_idx != 0:
             # Validation
-            if args.validation_file:
-                transformer.eval()
-                total_sum_cost = 0
-                total_token_num = 0
+            transformer.eval()
+            total_sum_cost = 0
+            total_token_num = 0
 ...
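Note that the old reader also handed back step-count callables (train_steps_fn / eval_steps_fn); after this change the call site only receives the loaders. If a step count is still wanted, one blunt but dependency-free way to get it is to iterate the dev loader once; a sketch, not part of the commit:

    (train_loader), (eval_loader) = reader.create_data_loader(args)
    # Counting by iteration avoids assuming anything about the batch
    # sampler's __len__ support; the dev split is small enough for this.
    eval_steps = sum(1 for _ in eval_loader)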
PaddleNLP/paddlenlp/data/sampler.py  View file @ 30ccfc67

@@ -137,7 +137,7 @@ class SamplerHelper(object):
         """
         Sort samples according to given callable cmp or key.

         Args:
-            cmp (callable): The funcation of comparison. Default: None.
+            cmp (callable): The function of comparison. Default: None.
             key (callable): Return element to be compared. Default: None.
             reverse (bool): If True, it means in descending order, and False means in ascending order. Default: False.
             buffer_size (int): Buffer size for sort. If buffer_size < 0 or buffer_size is more than the length of the data,
 ...
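The corrected docstring belongs to SamplerHelper.sort, which the new reader chains as sort -> batch -> shard. A hedged sketch of that chaining on a toy in-memory dataset follows; it assumes only the behaviour the reader above already relies on (key receives (index, data_source), buffer_size=-1 sorts the whole data, and iterating the batched sampler yields lists of dataset indices).

    import paddle
    from paddlenlp.data.sampler import SamplerHelper

    class ToyPairs(paddle.io.Dataset):
        """Five (src_ids, trg_ids) pairs of varying length."""
        def __init__(self):
            self.data = [([1] * n, [2] * (n + 1)) for n in (9, 3, 6, 12, 1)]
        def __getitem__(self, idx):
            return self.data[idx]
        def __len__(self):
            return len(self.data)

    dataset = ToyPairs()
    # Same key convention as reader.py: source length plus one for bos/eos.
    src_key = lambda x, data_source: len(data_source[x][0]) + 1

    batch_sampler = SamplerHelper(dataset).sort(
        key=src_key, buffer_size=-1).batch(
            batch_size=2, drop_last=False)

    for batch_indices in batch_sampler:
        print(batch_indices)   # indices grouped after length-sorting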
PaddleNLP/paddlenlp/data/vocab.py  View file @ 30ccfc67

@@ -16,6 +16,7 @@ import collections
 import io
 import json
 import os
+import warnings


 class Vocab(object):
@@ -179,7 +180,12 @@ class Vocab(object):
         tokens = []
         for idx in indices:
-            if not isinstance(idx, int) or idx > max_idx:
+            if not isinstance(idx, int):
+                warnings.warn(
+                    "The type of `to_tokens()`'s input `indices` is not `int` which will be forcibly transfered to `int`. "
+                )
+                idx = int(idx)
+            if idx > max_idx:
                 raise ValueError(
                     'Token index {} in the provided `indices` is invalid.'.
                     format(idx))
 ...
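The effect of the change is that Vocab.to_tokens now tolerates non-int indices (for example numpy integers coming out of a decoder) by warning and casting instead of hitting the max_idx check and raising. A sketch of the new behaviour on a toy vocabulary; it assumes Vocab.build_vocab for constructing it, and the token list is made up.

    import numpy as np
    from paddlenlp.data import Vocab

    vocab = Vocab.build_vocab([["hello", "world"]], unk_token="<unk>")

    print(vocab.to_tokens([0, 1, 2]))            # plain ints: unchanged behaviour
    print(vocab.to_tokens(np.array([0, 1, 2])))  # numpy ints: warn, cast to int, then look up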
PaddleNLP/paddlenlp/datasets/translation.py  View file @ 30ccfc67

@@ -13,7 +13,7 @@ from paddlenlp.data.sampler import SamplerHelper
 from paddlenlp.utils.env import DATA_HOME
 from paddle.dataset.common import md5file

-__all__ = ['TranslationDataset', 'IWSLT15']
+__all__ = ['TranslationDataset', 'IWSLT15', 'WMT14ende']


 def sequential_transforms(*transforms):
@@ -29,8 +29,8 @@ def get_default_tokenizer():
     """Only support split tokenizer
     """

-    def _split_tokenizer(x):
-        return x.split()
+    def _split_tokenizer(x, delimiter=None):
+        return x.split(delimiter)

     return _split_tokenizer
@@ -50,9 +50,9 @@ class TranslationDataset(paddle.io.Dataset):
     MD5 = None
     VOCAB_INFO = None
     UNK_TOKEN = None
-    PAD_TOKEN = None
     BOS_TOKEN = None
     EOS_TOKEN = None
+    PAD_TOKEN = None

     def __init__(self, data):
         self.data = data
@@ -143,14 +143,14 @@ class TranslationDataset(paddle.io.Dataset):
         tgt_file_path = os.path.join(root, tgt_vocab_filename)
         src_vocab = Vocab.load_vocabulary(
-            src_file_path,
+            filepath=src_file_path,
             unk_token=cls.UNK_TOKEN,
             pad_token=cls.PAD_TOKEN,
             bos_token=cls.BOS_TOKEN,
             eos_token=cls.EOS_TOKEN)
         tgt_vocab = Vocab.load_vocabulary(
-            tgt_file_path,
+            filepath=tgt_file_path,
             unk_token=cls.UNK_TOKEN,
             pad_token=cls.PAD_TOKEN,
             bos_token=cls.BOS_TOKEN,
 ...
@@ -273,6 +273,90 @@ class IWSLT15(TranslationDataset):
                          transform_func[1](data[1])) for data in self.data]

The following WMT14ende class is added:

class WMT14ende(TranslationDataset):
    """
    WMT14 English to German translation dataset.

    Args:
        mode(str, optional): It could be 'train', 'dev' or 'test'. Default: 'train'.
        root(str, optional): If None, dataset will be downloaded in
            `/root/.paddlenlp/datasets/machine_translation/WMT14ende/`. Default: None.
        transform_func(callable, optional): If not None, it transforms raw data
            to index data. Default: None.

    Examples:
        .. code-block:: python

            from paddlenlp.datasets import WMT14ende
            transform_func = WMT14ende.get_default_transform_func(root=root)
            train_dataset = WMT14ende.get_datasets(mode="train", transform_func=transform_func)
    """
    URL = "https://paddlenlp.bj.bcebos.com/datasets/WMT14.en-de.tar.gz"
    SPLITS = {
        'train': TranslationDataset.META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "train.tok.clean.bpe.33708.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "train.tok.clean.bpe.33708.de"),
            "c7c0b77e672fc69f20be182ae37ff62c",
            "1865ece46948fda1209d3b7794770a0a"),
        'dev': TranslationDataset.META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2013.tok.bpe.33708.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2013.tok.bpe.33708.de"),
            "aa4228a4bedb6c45d67525fbfbcee75e",
            "9b1eeaff43a6d5e78a381a9b03170501"),
        'test': TranslationDataset.META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2014.tok.bpe.33708.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                         "newstest2014.tok.bpe.33708.de"),
            "c9403eacf623c6e2d9e5a1155bdff0b5",
            "0058855b55e37c4acfcb8cffecba1050"),
        'dev-eval': TranslationDataset.META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2013.tok.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2013.tok.de"),
            "d74712eb35578aec022265c439831b0e",
            "6ff76ced35b70e63a61ecec77a1c418f"),
        'test-eval': TranslationDataset.META_INFO(
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2014.tok.en"),
            os.path.join("WMT14.en-de", "wmt14_ende_data",
                         "newstest2014.tok.de"),
            "8cce2028e4ca3d4cc039dfd33adbfb43",
            "a1b1f4c47f487253e1ac88947b68b3b8")
    }
    VOCAB_INFO = (os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                               "vocab_all.bpe.33708"),
                  os.path.join("WMT14.en-de", "wmt14_ende_data_bpe",
                               "vocab_all.bpe.33708"),
                  "2fc775b7df37368e936a8e1f63846bb0",
                  "2fc775b7df37368e936a8e1f63846bb0")
    UNK_TOKEN = "<unk>"
    BOS_TOKEN = "<s>"
    EOS_TOKEN = "<e>"

    MD5 = "5506d213dba4124121c682368257bae4"

    def __init__(self, mode="train", root=None, transform_func=None):
        if mode not in ("train", "dev", "test", "dev-eval", "test-eval"):
            raise TypeError(
                '`train`, `dev`, `test`, `dev-eval` or `test-eval` is supported but `{}` is passed in'.
                format(mode))
        if transform_func is not None and len(transform_func) != 2:
            if len(transform_func) != 2:
                raise ValueError("`transform_func` must have length of two for"
                                 "source and target.")

        self.data = WMT14ende.get_data(mode=mode, root=root)
        self.mode = mode
        if transform_func is not None:
            self.data = [(transform_func[0](data[0]),
                          transform_func[1](data[1])) for data in self.data]
        super(WMT14ende, self).__init__(self.data)


 # For test, not API
 def prepare_train_input(insts, pad_id):
     src, src_length = Pad(pad_val=pad_id, ret_length=True)(
 ...
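Putting the new dataset class together with the simplified reader, a short end-to-end sketch (not part of the diff): it uses only the class methods exercised above, and the download plus md5 check happen on first use when root is None.

    from paddlenlp.datasets import WMT14ende

    # BPE-tokenised splits listed in SPLITS above.
    transform_func = WMT14ende.get_default_transform_func(root=None)
    test_dataset = WMT14ende.get_datasets(
        mode="test", transform_func=transform_func)

    # Shared source/target BPE vocabulary (vocab_all.bpe.33708).
    src_vocab, trg_vocab = WMT14ende.get_vocab(root=None)
    print(len(src_vocab), len(trg_vocab))

    # Each sample is a (source ids, target ids) pair after the transform.
    src_ids, trg_ids = test_dataset[0]
    print(src_ids[:10], trg_vocab.to_tokens(trg_ids[:10]))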