PaddlePaddle / models, commit 67323767

submit code for joint embedding paper (#2896)

Authored Jul 22, 2019 by AndyELiu; committed by Yibing Liu, Jul 23, 2019.
Parent: af1197ed

Showing 7 changed files with 2,889 additions and 0 deletions (+2889, -0).
Changed files:
- PaddleNLP/Research/ACL2019-JEMT/README.md (+128, -0)
- PaddleNLP/Research/ACL2019-JEMT/config.py (+117, -0)
- PaddleNLP/Research/ACL2019-JEMT/desc.py (+107, -0)
- PaddleNLP/Research/ACL2019-JEMT/infer.py (+342, -0)
- PaddleNLP/Research/ACL2019-JEMT/model.py (+984, -0)
- PaddleNLP/Research/ACL2019-JEMT/reader.py (+385, -0)
- PaddleNLP/Research/ACL2019-JEMT/train.py (+826, -0)
PaddleNLP/Research/ACL2019-JEMT/README.md
0 → 100644 (new file)
## Introduction
### Task description
The input to machine translation is normally a sentence in the source language. In many real systems, however, such as the output of a speech recognition system or pinyin-based text input, the source sentence often contains homophone errors, which cause unexpected translation mistakes. Since the pronunciation is available at the same time, we propose a translation method that adds pronunciation information on the input side and fuses the textual and pronunciation information in the model's embedding layer, which greatly improves the translation model's robustness to homophone errors.

Paper: https://arxiv.org/abs/1810.06729
### Results
We train on the LDC Chinese-to-English dataset, using [DaCiDian](https://github.com/aishell-foundation/DaCiDian) as the Chinese pronunciation lexicon. Evaluated on newstest2006, the results are as follows:

| beta=0 | beta=0.50 | beta=0.85 | beta=0.95 |
|-|-|-|-|
| 47.96 | 48.71 | 48.85 | 48.46 |

beta is the weight given to the pronunciation information. This shows that even when most of the weight is placed on pronunciation, translation quality remains strong, while the system's resistance to homophone errors increases substantially.
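To make the role of beta concrete, the fusion in the embedding layer can be pictured as a convex combination of a token's word embedding and the average of its phoneme embeddings. This is only a minimal sketch of the idea described in the paper, with illustrative names; the actual implementation lives in model.py below:

```python
import numpy as np

def fused_embedding(word_emb, phone_embs, beta):
    """Sketch: mix one token's word vector with its pronunciation vectors.

    word_emb:   [d_model] word embedding of the token
    phone_embs: [n_phones, d_model] embeddings of its pronunciation units
    beta:       weight of the pronunciation information (0.0 .. 1.0)
    """
    return (1.0 - beta) * word_emb + beta * phone_embs.mean(axis=0)
```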
## Installation
1. Installing PaddlePaddle

   This project requires PaddlePaddle Fluid 1.3.1 or later; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.

2. Environment dependencies

   Please refer to the corresponding section of the PaddlePaddle [installation instructions](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html).
## How to train
1. Data format

   The data format is the same as in the [Paddle machine translation example](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer). To obtain the pronunciation of the input sentences, two additional source-language files are required: the basic pronunciation units and the pronunciation lexicon.

   A) Pronunciation unit file

   For Chinese, the basic pronunciation units are pinyin syllables; put all of them in one file, like:

   <unk>
   bo
   li
   ...

   B) Pronunciation lexicon

   Based on DaCiDian, assign one or more pronunciations to each token of the BPE-processed source text, like:

   ▁玻利维亚 bo li wei ya
   ▁举行 ju xing
   ▁总统 zong tong
   ▁与 yu
   巴斯 ba si
   ▁这个 zhei ge|zhe ge
   ...

   A short parsing sketch for one lexicon line follows.
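For illustration, this is how one lexicon line is split into its alternative pronunciations (mirroring what `reader.py` below does; alternatives are separated by `|`, and one is sampled at random per occurrence):

```python
line = "▁这个 zhei ge|zhe ge"
tokens = line.strip().split()
word, pron = tokens[0], " ".join(tokens[1:])
# Alternative pronunciations are separated by "|".
alternatives = [p.split() for p in pron.split("|")]
# -> [['zhei', 'ge'], ['zhe', 'ge']]
```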
2. Training the model

   Once the data is ready, you can train with the `train.py` script. For example:

```sh
python train.py \
    --src_vocab_fpath nist_data/vocab_all.28000 \
    --trg_vocab_fpath nist_data/vocab_all.28000 \
    --train_file_pattern nist_data/nist_train.txt \
    --phoneme_vocab_fpath nist_data/zh_pinyins.txt \
    --lexicon_fpath nist_data/zh_lexicon.txt \
    --batch_size 2048 \
    --use_token_batch True \
    --sort_type pool \
    --pool_size 200000 \
    --use_py_reader False \
    --use_mem_opt False \
    --enable_ce False \
    --fetch_steps 1 \
    pass_num 100 \
    learning_rate 2.0 \
    warmup_steps 8000 \
    beta2 0.997 \
    d_model 512 \
    d_inner_hid 2048 \
    n_head 8 \
    weight_sharing True \
    max_length 256 \
    save_freq 10000 \
    beta 0.85 \
    model_dir pinyin_models_beta085 \
    ckpt_dir pinyin_ckpts_beta085
```
The command above sets the data-related arguments: the source vocabulary path (`src_vocab_fpath`), the target vocabulary path (`trg_vocab_fpath`), the training data files (`train_file_pattern`, wildcards supported), the pronunciation unit file path (`phoneme_vocab_fpath`) and the pronunciation lexicon path (`lexicon_fpath`); and reader-related arguments such as how batches are built (`use_token_batch` selects whether batches are assembled by token count or by sequence count). More details on these arguments are available via:

```sh
python train.py --help
```

Additional model-training parameters are defined in `ModelHyperParams` and `TrainTaskConfig` in `config.py`: `ModelHyperParams` holds model hyperparameters such as the embedding dimension, and `TrainTaskConfig` holds training parameters such as the number of warmup steps. The defaults follow the base model configuration of the Transformer paper; they can be modified in that script, or set on the training command line, where the passed values are merged into and override the ones in `config.py`.

Note that if the model configuration is changed for training, the same configuration must be used with `infer.py` for prediction. Also, training uses all GPUs by default; set the `CUDA_VISIBLE_DEVICES` environment variable to use specific GPUs, for example `CUDA_VISIBLE_DEVICES=0,1 python train.py ...` to train on the first two GPUs only.
## How to predict
With the data and model above, prediction can be run as follows; translations are printed to standard output:

```sh
python infer.py \
    --src_vocab_fpath nist_data/vocab_all.28000 \
    --trg_vocab_fpath nist_data/vocab_all.28000 \
    --test_file_pattern nist_data/nist_test.txt \
    --phoneme_vocab_fpath nist_data/zh_pinyins.txt \
    --lexicon_fpath nist_data/zh_lexicon.txt \
    --batch_size 32 \
    model_path pinyin_models_beta085/iter_200000.infer.model \
    beam_size 5 \
    max_out_len 255 \
    beta 0.85
```
PaddleNLP/Research/ACL2019-JEMT/config.py
0 → 100644 (new file)
class TrainTaskConfig(object):
    # support both CPU and GPU now.
    use_gpu = True
    # the epoch number to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This static learning_rate will be multiplied by the LearningRateScheduler
    # derived learning rate to get the final learning rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 8000
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this as zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the directory for loading checkpoint.
    # If provided, continue training from the checkpoint.
    ckpt_path = None
    # the parameter to initialize the learning rate scheduler.
    # It should be provided if using checkpoints, since the checkpoint doesn't
    # include the training step counter currently.
    start_step = 0
    # the frequency to save trained models.
    save_freq = 10000


class InferTaskConfig(object):
    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    max_out_len = 256
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = True
    # the directory for loading the trained model.
    model_path = "trained_models/pass_1.infer.model"


class ModelHyperParams(object):
    # The following five vocabulary-related configurations will be set
    # automatically according to the passed vocabulary path and special tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # size of phone dictionary.
    phone_vocab_size = 1000
    # ratio of phoneme embeddings.
    beta = 0.0
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # index for <unk> in phonemes
    phone_pad_idx = 0
    # max length of sequences deciding the size of position encoding table.
    max_length = 256
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rates of different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    # to process before each sub-layer
    preprocess_cmd = "n"  # layer normalization
    # to process after each sub-layer
    postprocess_cmd = "da"  # dropout + residual connection
    # random seed used in dropout for CE.
    dropout_seed = None
    # the flag indicating whether to share embedding and softmax weights.
    # vocabularies in source and target should be same for weight sharing.
    weight_sharing = True


def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set the above global configurations using the cfg_list.
    """
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    value = eval(value)
                except Exception:  # for file path
                    pass
                setattr(g_cfg, key, value)
                break
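A minimal usage sketch of `merge_cfg_from_list` (values are illustrative): positional `opts` from the command line arrive as a flat `[key, value, ...]` list and are applied to the first config class that defines the attribute:

```python
merge_cfg_from_list(["beta", "0.85", "max_out_len", "255"],
                    [InferTaskConfig, ModelHyperParams])
assert ModelHyperParams.beta == 0.85      # eval("0.85") yields a float
assert InferTaskConfig.max_out_len == 255  # set on the first class owning it
```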
PaddleNLP/Research/ACL2019-JEMT/desc.py
0 → 100644 (new file)
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = 256
# The placeholder for phoneme sequence length in compile time.
phone_len = 16
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
d_model = 512

# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholders and are set to pass the infer-shape in
# compile time.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of src_pos is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    "src_phone": [(batch_size, seq_len, phone_len, 1), "int64"],
    "src_phone_mask": [(batch_size, seq_len, phone_len), "int64"],
    # The actual data shape of trg_word is:
    # [batch_size, max_trg_len_in_batch, 1]
    # lod_level is only used in fast decoder.
    "trg_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of trg_pos is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    # This input is used in independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(batch_size, seq_len, d_model), "float32"],
    # The actual data shape of label_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(batch_size * seq_len, 1), "int64"],
    # This input is used to mask out the loss of padding tokens.
    # The actual data shape of label_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(batch_size * seq_len, 1), "float32"],
    # This input is used in beam-search decoder.
    "init_score": [(batch_size, 1), "float32", 2],
    # This input is used in beam-search decoder for the first gather
    # (cell states update).
    "init_idx": [(batch_size, ), "int32"],
}

# Names of word embedding tables which might be reused for weight sharing.
word_emb_param_names = (
    "src_word_emb_table",
    "trg_word_emb_table",
)
phone_emb_param_name = "phone_emb_table"
# Names of position encoding tables which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table",
)
# separated inputs for different usages.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias",
    "src_phone",
    "src_phone_mask",
)
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output",
)
label_data_input_fields = (
    "lbl_word",
    "lbl_weight",
)
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
    "trg_word",
    "init_score",
    "init_idx",
    "trg_src_attn_bias",
)

# Set seed for CE
dropout_seed = None
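These descriptors are consumed by the model code when declaring its inputs. Since model.py is collapsed below, the following is only a sketch of the typical Fluid pattern (the helper name `make_inputs` is illustrative):

```python
import paddle.fluid as fluid

def make_inputs(field_names):
    # One fluid data layer per field, built from its (shape, dtype[, lod_level])
    # descriptor; append_batch_size=False since batch_size is already the -1
    # placeholder in the shapes above.
    layers = []
    for name in field_names:
        desc = input_descs[name]
        layers.append(
            fluid.layers.data(
                name=name,
                shape=list(desc[0]),
                dtype=desc[1],
                lod_level=desc[2] if len(desc) > 2 else 0,
                append_batch_size=False))
    return layers

enc_inputs = make_inputs(encoder_data_input_fields)
```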
PaddleNLP/Research/ACL2019-JEMT/infer.py
0 → 100644 (new file)
import argparse
import ast
import multiprocessing
import numpy as np
import os
import sys
from functools import partial

import paddle
import paddle.fluid as fluid

import reader
from config import *
from desc import *
from model import fast_decode as fast_decoder
from train import pad_batch_data, pad_phoneme_data, prepare_data_generator


def parse_args():
    parser = argparse.ArgumentParser("Training for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--phoneme_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of phonemes.")
    parser.add_argument(
        "--lexicon_fpath",
        type=str,
        required=True,
        help="The path of lexicon of source language.")
    parser.add_argument(
        "--test_file_pattern",
        type=str,
        required=True,
        help="The pattern to match test data files.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=50,
        help="The number of examples in one run for sequence generation.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target sentences. "
        "For EN-DE BPE data we provided, use spaces as token delimiter. ")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--use_parallel_exe",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use ParallelExecutor.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # Append args related to dict
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "phone_vocab_size", str(len(phone_dict)),
        "bos_idx", str(src_dict[args.special_token[0]]), "eos_idx",
        str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [InferTaskConfig, ModelHyperParams])
    return args


def post_process_seq(seq,
                     bos_idx=ModelHyperParams.bos_idx,
                     eos_idx=ModelHyperParams.eos_idx,
                     output_bos=InferTaskConfig.output_bos,
                     output_eos=InferTaskConfig.output_eos):
    """
    Post-process the beam-search decoded sequence. Truncate from the first
    <eos> and remove the <bos> and <eos> tokens currently.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq


def prepare_batch_input(insts, data_input_names, src_pad_idx, phone_pad_idx,
                        bos_idx, n_head, d_model, place):
    """
    Put all padded data needed by beam search decoder into a dict.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    src_phone, src_phone_mask, max_phone_len = pad_phoneme_data(
        [inst[1] for inst in insts], phone_pad_idx, src_max_len)
    # start tokens
    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, 1, 1]).astype("float32")
    trg_word = trg_word.reshape(-1, 1, 1)

    def to_lodtensor(data, place, lod=None):
        data_tensor = fluid.LoDTensor()
        data_tensor.set(data, place)
        if lod is not None:
            data_tensor.set_lod(lod)
        return data_tensor

    # beamsearch_op must use tensors with lod
    init_score = to_lodtensor(
        np.zeros_like(trg_word, dtype="float32").reshape(-1, 1), place,
        [range(trg_word.shape[0] + 1)] * 2)
    trg_word = to_lodtensor(trg_word, place,
                            [range(trg_word.shape[0] + 1)] * 2)
    init_idx = np.asarray(range(len(insts)), dtype="int32")

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, src_phone, src_phone_mask,
            trg_word, init_score, init_idx, trg_src_attn_bias
        ]))
    return data_input_dict


def prepare_feed_dict_list(data_generator, count, place):
    """
    Prepare the list of feed dict for multi-devices.
    """
    feed_dict_list = []
    if data_generator is not None:  # use_py_reader == False
        data_input_names = encoder_data_input_fields + \
            fast_decoder_data_input_fields
        data = next(data_generator)
        for idx, data_buffer in enumerate(data):
            data_input_dict = prepare_batch_input(
                data_buffer, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx,
                ModelHyperParams.n_head, ModelHyperParams.d_model, place)
            feed_dict_list.append(data_input_dict)
    return feed_dict_list if len(feed_dict_list) == count else None


def py_reader_provider_wrapper(data_reader, place):
    """
    Data provider needed by fluid.layers.py_reader.
    """

    def py_reader_provider():
        data_input_names = encoder_data_input_fields + \
            fast_decoder_data_input_fields
        for batch_id, data in enumerate(data_reader()):
            data_input_dict = prepare_batch_input(
                data, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx,
                ModelHyperParams.n_head, ModelHyperParams.d_model, place)
            yield [data_input_dict[item] for item in data_input_names]

    return py_reader_provider


def fast_infer(args):
    """
    Inference by beam search decoder based solely on Fluid operators.
    """
    out_ids, out_scores, pyreader = fast_decoder(
        ModelHyperParams.src_vocab_size,
        ModelHyperParams.trg_vocab_size,
        ModelHyperParams.phone_vocab_size,
        ModelHyperParams.max_length + 1,
        ModelHyperParams.n_layer,
        ModelHyperParams.n_head,
        ModelHyperParams.d_key,
        ModelHyperParams.d_value,
        ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid,
        ModelHyperParams.prepostprocess_dropout,
        ModelHyperParams.attention_dropout,
        ModelHyperParams.relu_dropout,
        ModelHyperParams.preprocess_cmd,
        ModelHyperParams.postprocess_cmd,
        ModelHyperParams.weight_sharing,
        InferTaskConfig.beam_size,
        InferTaskConfig.max_out_len,
        ModelHyperParams.bos_idx,
        ModelHyperParams.eos_idx,
        beta=ModelHyperParams.beta,
        use_py_reader=args.use_py_reader)

    # This is used here to set dropout to the test mode.
    infer_program = fluid.default_main_program().clone(for_test=True)

    if args.use_mem_opt:
        fluid.memory_optimize(infer_program)

    if InferTaskConfig.use_gpu:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    fluid.io.load_vars(
        exe,
        InferTaskConfig.model_path,
        vars=[
            var for var in infer_program.list_vars()
            if isinstance(var, fluid.framework.Parameter)
        ])

    exec_strategy = fluid.ExecutionStrategy()
    # For faster executor
    exec_strategy.use_experimental_executor = True
    exec_strategy.num_threads = 1
    build_strategy = fluid.BuildStrategy()
    infer_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=infer_program,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # data reader settings for inference
    args.train_file_pattern = args.test_file_pattern
    args.use_token_batch = False
    args.sort_type = reader.SortType.NONE
    args.shuffle = False
    args.shuffle_batch = False
    test_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper,
        place=place)
    if args.use_py_reader:
        pyreader.start()
        data_generator = None
    else:
        data_generator = test_data()
    trg_idx2word = reader.DataReader.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)

    while True:
        try:
            feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
                                                    place)
            if args.use_parallel_exe:
                seq_ids, seq_scores = infer_exe.run(
                    fetch_list=[out_ids.name, out_scores.name],
                    feed=feed_dict_list,
                    return_numpy=False)
            else:
                seq_ids, seq_scores = exe.run(
                    program=infer_program,
                    fetch_list=[out_ids.name, out_scores.name],
                    feed=feed_dict_list[0]
                    if feed_dict_list is not None else None,
                    return_numpy=False,
                    use_program_cache=True)
            seq_ids_list, seq_scores_list = [seq_ids], [
                seq_scores
            ] if isinstance(seq_ids,
                            paddle.fluid.LoDTensor) else (seq_ids, seq_scores)
            for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list):
                # How to parse the results:
                #   Suppose the lod of seq_ids is:
                #     [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
                #   then from lod[0]:
                #     there are 2 source sentences, beam width is 3.
                #   from lod[1]:
                #     the first source sentence has 3 hyps; the lengths are 12, 12, 16
                #     the second source sentence has 3 hyps; the lengths are 14, 13, 15
                hyps = [[] for i in range(len(seq_ids.lod()[0]) - 1)]
                scores = [[] for i in range(len(seq_scores.lod()[0]) - 1)]
                for i in range(len(seq_ids.lod()[0]) -
                               1):  # for each source sentence
                    start = seq_ids.lod()[0][i]
                    end = seq_ids.lod()[0][i + 1]
                    for j in range(end - start):  # for each candidate
                        sub_start = seq_ids.lod()[1][start + j]
                        sub_end = seq_ids.lod()[1][start + j + 1]
                        hyps[i].append(" ".join([
                            trg_idx2word[idx]
                            for idx in post_process_seq(
                                np.array(seq_ids)[sub_start:sub_end])
                        ]))
                        scores[i].append(np.array(seq_scores)[sub_end - 1])
                        print(hyps[i][-1])
                        if len(hyps[i]) >= InferTaskConfig.n_best:
                            break
        except (StopIteration, fluid.core.EOFException):
            # The data pass is over.
            if args.use_py_reader:
                pyreader.reset()
            break


if __name__ == "__main__":
    args = parse_args()
    fast_infer(args)
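As a quick sanity check of the `post_process_seq` helper above (toy indices; bos_idx=0 and eos_idx=1 are the defaults from config.py):

```python
# [0, 5, 7, 1, 9] is truncated at the first <eos> (index 1), then the
# <bos>/<eos> markers are dropped, leaving [5, 7].
assert post_process_seq([0, 5, 7, 1, 9], bos_idx=0, eos_idx=1,
                        output_bos=False, output_eos=False) == [5, 7]
```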
PaddleNLP/Research/ACL2019-JEMT/model.py
0 → 100644 (new file)
This diff is collapsed (+984 lines not shown).
PaddleNLP/Research/ACL2019-JEMT/reader.py
0 → 100644 (new file)
import glob
import six
import os
import random
import tarfile

import numpy as np


class SortType(object):
    GLOBAL = 'global'
    POOL = 'pool'
    NONE = "none"


class SrcConverter(object):
    def __init__(self, vocab, end, unk, delimiter, lexicon):
        self._vocab = vocab
        self._end = end
        self._unk = unk
        self._delimiter = delimiter
        self._lexicon = lexicon

    def __call__(self, sentence):
        src_seqs = []
        src_ph_seqs = []
        unk_phs = self._lexicon['<unk>']
        for w in sentence.split(self._delimiter):
            src_seqs.append(self._vocab.get(w, self._unk))
            ph_groups = self._lexicon.get(w, unk_phs)
            src_ph_seqs.append(random.choice(ph_groups))
        src_seqs.append(self._end)
        src_ph_seqs.append(unk_phs[0])
        return src_seqs, src_ph_seqs


class TgtConverter(object):
    def __init__(self, vocab, beg, end, unk, delimiter):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter

    def __call__(self, sentence):
        return [self._beg] + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + [self._end]


class ComposedConverter(object):
    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        return [
            self._converters[i](parallel_sentence[i])
            for i in range(len(self._converters))
        ]


class SentenceBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        self.batch.append(info)
        if len(self.batch) == self._batch_size:
            tmp = self.batch
            self.batch = []
            return tmp


class TokenBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        if max_len * (len(self.batch) + 1) > self._batch_size:
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        else:
            self.max_len = max_len
            self.batch.append(info)


class SampleInfo(object):
    def __init__(self, i, max_len, min_len):
        self.i = i
        self.min_len = min_len
        self.max_len = max_len


class MinMaxFilter(object):
    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return
        else:
            return self._creator.append(info)

    @property
    def batch(self):
        return self._creator.batch


class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language.
    :type trg_vocab_fpath: basestring
    :param fpattern: The pattern to match data files.
    :type fpattern: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or the maximum number of tokens (include paddings) contained in a
        mini-batch.
    :type batch_size: int
    :param pool_size: The size of pool buffer.
    :type pool_size: int
    :param sort_type: The grain to sort by length: 'global' for all
        instances; 'pool' for instances in pool; 'none' for no sort.
    :type sort_type: basestring
    :param clip_last_batch: Whether to clip the last uncompleted batch.
    :type clip_last_batch: bool
    :param tar_fname: The data file in tar if fpattern matches a tar file.
    :type tar_fname: basestring
    :param min_length: The minimum length used to filter sequences.
    :type min_length: int
    :param max_length: The maximum length used to filter sequences.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances.
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing for the beginning of
        sentences in dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing for the end of sentences
        in dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing for unknown word in dictionary.
    :type unk_mark: basestring
    :param seed: The seed for random.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 phoneme_vocab_fpath,
                 lexicon_fpath,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._phoneme_vocab = self.load_dict(phoneme_vocab_fpath)
        self._lexicon = self.load_lexicon(lexicon_fpath, self._phoneme_vocab)
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        self._random = np.random
        self._random.seed(seed)

    def load_lexicon(self, lexicon_path, phoneme_vocab):
        lexicon = {}
        with open(lexicon_path) as fp:
            for line in fp:
                tokens = line.strip().split()
                word = tokens[0]
                all_phone_str = ' '.join(tokens[1:])
                phone_strs = all_phone_str.split('|')
                phone_groups = []
                for phone_str in phone_strs:
                    cur_phone_seq = [
                        phoneme_vocab[x] for x in phone_str.split()
                    ]
                    phone_groups.append(cur_phone_seq)
                lexicon[word] = phone_groups
        lexicon['<unk>'] = [[phoneme_vocab['<unk>']]]
        return lexicon

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        converters = [
            SrcConverter(
                vocab=self._src_vocab,
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter,
                lexicon=self._lexicon)
        ]
        if not self._only_src:
            converters.append(
                TgtConverter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._src_phone_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0][0])
            self._src_phone_ids.append(src_trg_ids[0][1])
            lens = [len(src_trg_ids[0][0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        fpaths = glob.glob(fpattern)

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            f = tarfile.open(fpaths[0], "r")
            for line in f.extractfile(tar_fname):
                fields = line.strip("\n").split(self._field_delimiter)
                if (not self._only_src and len(fields) == 2) or (
                        self._only_src and len(fields) == 1):
                    yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "rb") as f:
                    for line in f:
                        if six.PY3:
                            line = line.decode()
                        fields = line.strip("\n").split(self._field_delimiter)
                        if (not self._only_src and len(fields) == 2) or (
                                self._only_src and len(fields) == 1):
                            yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False):
        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if six.PY3:
                    line = line.decode()
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict

    def batch_generator(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(self._sample_infos, key=lambda x: x.max_len)
        else:
            if self._shuffle:
                infos = self._sample_infos
                self._random.shuffle(infos)
            else:
                infos = self._sample_infos

            if self._sort_type == SortType.POOL:
                reverse = True
                for i in range(0, len(infos), self._pool_size):
                    # to avoid placing short next to long sentences
                    reverse = not reverse
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size],
                        key=lambda x: x.max_len,
                        reverse=reverse)

        # concat batch
        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

            if self._only_src:
                yield [[(self._src_seq_ids[idx], self._src_phone_ids[idx])]
                       for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._src_phone_ids[idx],
                        self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
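A usage sketch matching the constructor above (the file paths reuse the illustrative NIST names from the README):

```python
train_data = DataReader(
    src_vocab_fpath='nist_data/vocab_all.28000',
    trg_vocab_fpath='nist_data/vocab_all.28000',
    fpattern='nist_data/nist_train.txt',
    phoneme_vocab_fpath='nist_data/zh_pinyins.txt',
    lexicon_fpath='nist_data/zh_lexicon.txt',
    batch_size=2048,
    pool_size=200000,
    sort_type=SortType.POOL,
    use_token_batch=True)

for batch in train_data.batch_generator():
    # Each element is (src_ids, src_phone_ids, trg_ids[:-1], trg_ids[1:]).
    pass
```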
PaddleNLP/Research/ACL2019-JEMT/train.py
0 → 100644 (new file)
This diff is collapsed (+826 lines not shown).