Commit 67323767 authored by AndyELiu, committed by Yibing Liu

submit code for joint embedding paper (#2896)

Parent af1197ed
## Introduction
### Task Description
The input to machine translation is usually a sentence in the source language. In many real systems, however, such as the output of a speech recognition system or pinyin-based text input, source sentences often contain homophone errors, which lead to many unexpected translation mistakes. Since pronunciation information is available alongside the text, we propose a translation method that takes pronunciation as an additional input and fuses textual and phonetic information in the model's embedding layer, greatly improving the translation model's robustness to homophone errors.

Paper: https://arxiv.org/abs/1810.06729
### Results
We train on the LDC Chinese-to-English dataset, using [DaCiDian](https://github.com/aishell-foundation/DaCiDian) as the Chinese pronunciation lexicon. Evaluated on newstest2006, the results are as follows:

| beta=0 | beta=0.50 | beta=0.85 | beta=0.95 |
|-|-|-|-|
| 47.96 | 48.71 | 48.85 | 48.46 |

Here beta is the weight placed on pronunciation information. Translation quality holds up even when most of the weight is put on pronunciation, while the system becomes far more robust to homophone errors.
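As a rough sketch of the fusion in the embedding layer (illustrative only: the linear interpolation and the mean-pooling of a token's phoneme embeddings are assumptions of this sketch, and all names are made up):

```python
import numpy as np

def fused_embedding(word_emb, phone_embs, beta):
    """Fuse a token's word embedding with its phoneme embeddings.

    word_emb:   [d_model] array, the token's word embedding
    phone_embs: [n_phones, d_model] array, embeddings of its phonemes
    beta:       weight placed on pronunciation information (0.0 disables it)
    """
    phone_pooled = phone_embs.mean(axis=0)  # pool phonemes into one vector
    return (1.0 - beta) * word_emb + beta * phone_pooled

# with beta=0.85, most of the embedding signal comes from pronunciation
emb = fused_embedding(np.ones(512), np.full((3, 512), 0.5), beta=0.85)
```

Under this picture, two homophones that share a pronunciation map to nearly identical fused embeddings at high beta, which is what makes the model tolerant of homophone substitutions.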
## Installation
1. Install PaddlePaddle
This project requires PaddlePaddle Fluid 1.3.1 or later; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start).
2. Environment dependencies
Please refer to the PaddlePaddle [installation notes](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html).
## Training
1. Data format
The data format is the same as in [Paddle machine translation](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer). To obtain pronunciation information for the input sentences, you must additionally provide the source language's basic pronunciation units and a pronunciation lexicon.
A) Basic pronunciation unit file
The basic pronunciation units for Chinese are pinyin syllables. Put all pinyin syllables in one file, for example:
```
<unk>
bo
li
...
```
B) Pronunciation lexicon
Based on DaCiDian, each token of the BPE-processed source text is assigned one or more pronunciations, for example (a parsing sketch follows the sample):
```
▁玻利维亚 bo li wei ya
▁举行 ju xing
▁总统 zong tong
▁与 yu
巴斯 ba si
▁这个 zhei ge|zhe ge
...
```
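Alternative pronunciations are separated by `|`. As `load_lexicon` and `SrcConverter` in the reader (shown further below) do, each alternative becomes a sequence of phoneme ids, and one alternative is sampled at random per occurrence. A minimal sketch of that parsing (the tiny `phoneme_vocab` is made up for illustration):

```python
import random

phoneme_vocab = {"<unk>": 0, "zhei": 1, "zhe": 2, "ge": 3}  # toy vocab

line = "▁这个 zhei ge|zhe ge"
tokens = line.strip().split()
word, alternatives = tokens[0], " ".join(tokens[1:]).split("|")
# each alternative pronunciation becomes one phoneme-id sequence
phone_groups = [[phoneme_vocab[p] for p in alt.split()] for alt in alternatives]
assert phone_groups == [[1, 3], [2, 3]]

# one pronunciation is picked at random for each occurrence of the word
chosen = random.choice(phone_groups)
```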
2. Train the model
Once the data is ready, train with the `train.py` script, for example:
```sh
python train.py \
--src_vocab_fpath nist_data/vocab_all.28000 \
--trg_vocab_fpath nist_data/vocab_all.28000 \
--train_file_pattern nist_data/nist_train.txt \
--phoneme_vocab_fpath nist_data/zh_pinyins.txt \
--lexicon_fpath nist_data/zh_lexicon.txt \
--batch_size 2048 \
--use_token_batch True \
--sort_type pool \
--pool_size 200000 \
--use_py_reader False \
--use_mem_opt False \
--enable_ce False \
--fetch_steps 1 \
pass_num 100 \
learning_rate 2.0 \
warmup_steps 8000 \
beta2 0.997 \
d_model 512 \
d_inner_hid 2048 \
n_head 8 \
weight_sharing True \
max_length 256 \
save_freq 10000 \
beta 0.85 \
model_dir pinyin_models_beta085 \
ckpt_dir pinyin_ckpts_beta085
```
The command above sets data-related arguments such as the source vocabulary path (`src_vocab_fpath`), the target vocabulary path (`trg_vocab_fpath`), the training data files (`train_file_pattern`, wildcards supported), the phoneme unit file (`phoneme_vocab_fpath`) and the pronunciation lexicon (`lexicon_fpath`), along with reader-related arguments such as how batches are formed (`use_token_batch` chooses whether batches are assembled by token count or by sequence count). For more details on these arguments, run:
```sh
python train.py --help
```
Further model and training parameters are defined in `ModelHyperParams` and `TrainTaskConfig` in `config.py`: `ModelHyperParams` holds model hyperparameters such as the embedding dimension, and `TrainTaskConfig` holds training parameters such as the number of warmup steps. They default to the base-model configuration from the Transformer paper and can be modified in that script. The same parameters can also be set on the training command line; values passed there are merged in and override the configuration in `config.py`.
Note that if the model configuration is changed for training, the same configuration must be used when predicting with `infer.py`. Training uses all visible GPUs by default; set the `CUDA_VISIBLE_DEVICES` environment variable to select specific GPUs.
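For example, to run the training command above on GPUs 0 and 1 only:
```sh
CUDA_VISIBLE_DEVICES=0,1 python train.py ...   # same arguments as above
```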
## Inference
With the data and a trained model as above, run prediction as follows; the translations are printed to standard output:
```sh
python infer.py \
--src_vocab_fpath nist_data/vocab_all.28000 \
--trg_vocab_fpath nist_data/vocab_all.28000 \
--test_file_pattern nist_data/nist_test.txt \
--phoneme_vocab_fpath nist_data/zh_pinyins.txt \
--lexicon_fpath nist_data/zh_lexicon.txt \
--batch_size 32 \
model_path pinyin_models_beta085/iter_200000.infer.model \
beam_size 5 \
max_out_len 255 \
beta 0.85
```
The main configuration, `config.py`:

```python
class TrainTaskConfig(object):
    # support both CPU and GPU now.
    use_gpu = True
    # the epoch number to train.
    pass_num = 30
    # the number of sequences contained in a mini-batch.
    # deprecated, set batch_size in args.
    batch_size = 32
    # the hyper parameters for Adam optimizer.
    # This static learning_rate will be multiplied by the rate derived from
    # the LearningRateScheduler to get the final learning rate.
    learning_rate = 2.0
    beta1 = 0.9
    beta2 = 0.997
    eps = 1e-9
    # the parameters for learning rate scheduling.
    warmup_steps = 8000
    # the weight used to mix up the ground-truth distribution and the fixed
    # uniform distribution in label smoothing when training.
    # Set this as zero if label smoothing is not wanted.
    label_smooth_eps = 0.1
    # the directory for saving trained models.
    model_dir = "trained_models"
    # the directory for saving checkpoints.
    ckpt_dir = "trained_ckpts"
    # the directory for loading a checkpoint.
    # If provided, continue training from the checkpoint.
    ckpt_path = None
    # the parameter to initialize the learning rate scheduler.
    # It should be provided when using checkpoints, since the checkpoint
    # doesn't include the training step counter currently.
    start_step = 0
    # the frequency to save trained models.
    save_freq = 10000


class InferTaskConfig(object):
    use_gpu = True
    # the number of examples in one run for sequence generation.
    batch_size = 10
    # the parameters for beam search.
    beam_size = 5
    max_out_len = 256
    # the number of decoded sentences to output.
    n_best = 1
    # the flags indicating whether to output the special tokens.
    output_bos = False
    output_eos = False
    output_unk = True
    # the directory for loading the trained model.
    model_path = "trained_models/pass_1.infer.model"


class ModelHyperParams(object):
    # The following vocabulary-related configurations will be set
    # automatically according to the passed vocabulary path and special
    # tokens.
    # size of source word dictionary.
    src_vocab_size = 10000
    # size of target word dictionary.
    trg_vocab_size = 10000
    # size of phone dictionary.
    phone_vocab_size = 1000
    # ratio of phoneme embeddings.
    beta = 0.0
    # index for <bos> token
    bos_idx = 0
    # index for <eos> token
    eos_idx = 1
    # index for <unk> token
    unk_idx = 2
    # index for <unk> in phonemes
    phone_pad_idx = 0
    # max length of sequences, deciding the size of the position encoding
    # table.
    max_length = 256
    # the dimension for word embeddings, which is also the last dimension of
    # the input and output of multi-head attention, position-wise feed-forward
    # networks, encoder and decoder.
    d_model = 512
    # size of the hidden layer in position-wise feed-forward networks.
    d_inner_hid = 2048
    # the dimension that keys are projected to for dot-product attention.
    d_key = 64
    # the dimension that values are projected to for dot-product attention.
    d_value = 64
    # number of heads used in multi-head attention.
    n_head = 8
    # number of sub-layers to be stacked in the encoder and decoder.
    n_layer = 6
    # dropout rates of different modules.
    prepostprocess_dropout = 0.1
    attention_dropout = 0.1
    relu_dropout = 0.1
    # to process before each sub-layer
    preprocess_cmd = "n"  # layer normalization
    # to process after each sub-layer
    postprocess_cmd = "da"  # dropout + residual connection
    # random seed used in dropout for CE.
    dropout_seed = None
    # the flag indicating whether to share embedding and softmax weights.
    # vocabularies in source and target should be the same for weight sharing.
    weight_sharing = True


def merge_cfg_from_list(cfg_list, g_cfgs):
    """
    Set the above global configurations using the cfg_list.
    """
    assert len(cfg_list) % 2 == 0
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    value = eval(value)
                except Exception:  # for file path
                    pass
                setattr(g_cfg, key, value)
                break
```
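This is how the trailing positional `opts` on the command line (e.g. `beta 0.85 d_model 512`) reach the configuration classes. A quick illustration:

```python
# key/value strings, exactly as argparse.REMAINDER delivers them
merge_cfg_from_list(["beta", "0.85", "d_model", "512"],
                    [TrainTaskConfig, ModelHyperParams])
assert ModelHyperParams.beta == 0.85
assert ModelHyperParams.d_model == 512
```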
Compile-time placeholders and input descriptors, `desc.py`:

```python
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in the beam-search decoder.
batch_size = -1
# The placeholder for sequence length in compile time.
seq_len = 256
# The placeholder for phoneme sequence length in compile time.
phone_len = 16
# The placeholder for head number in compile time.
n_head = 8
# The placeholder for model dim in compile time.
d_model = 512

# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholders and are set to pass the infer-shape in
# compile time.
input_descs = {
    # The actual data shape of src_word is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_word": [(batch_size, seq_len, 1), "int64", 2],
    # The actual data shape of src_pos is:
    # [batch_size, max_src_len_in_batch, 1]
    "src_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings in the
    # encoder.
    # The actual data shape of src_slf_attn_bias is:
    # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
    "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    "src_phone": [(batch_size, seq_len, phone_len, 1), "int64"],
    "src_phone_mask": [(batch_size, seq_len, phone_len), "int64"],
    # The actual data shape of trg_word is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_word": [(batch_size, seq_len, 1), "int64",
                 2],  # lod_level is only used in fast decoder.
    # The actual data shape of trg_pos is:
    # [batch_size, max_trg_len_in_batch, 1]
    "trg_pos": [(batch_size, seq_len, 1), "int64"],
    # This input is used to remove attention weights on paddings and
    # subsequent words in the decoder.
    # The actual data shape of trg_slf_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
    "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    # This input is used to remove attention weights on paddings of the source
    # input in the encoder-decoder attention.
    # The actual data shape of trg_src_attn_bias is:
    # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
    "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"],
    # This input is used in the independent decoder program for inference.
    # The actual data shape of enc_output is:
    # [batch_size, max_src_len_in_batch, d_model]
    "enc_output": [(batch_size, seq_len, d_model), "float32"],
    # The actual data shape of lbl_word is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_word": [(batch_size * seq_len, 1), "int64"],
    # This input is used to mask out the loss of padding tokens.
    # The actual data shape of lbl_weight is:
    # [batch_size * max_trg_len_in_batch, 1]
    "lbl_weight": [(batch_size * seq_len, 1), "float32"],
    # This input is used in the beam-search decoder.
    "init_score": [(batch_size, 1), "float32", 2],
    # This input is used in the beam-search decoder for the first gather
    # (cell state update).
    "init_idx": [(batch_size, ), "int32"],
}

# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
    "src_word_emb_table",
    "trg_word_emb_table",
)
phone_emb_param_name = "phone_emb_table"
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table",
)
# Separated inputs for different usages.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias",
    "src_phone",
    "src_phone_mask",
)
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output",
)
label_data_input_fields = (
    "lbl_word",
    "lbl_weight",
)
# In the fast decoder, trg_pos (containing only the current time step) is
# generated by ops, and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
    "trg_word",
    "init_score",
    "init_idx",
    "trg_src_attn_bias",
)
# Set seed for CE
dropout_seed = None
```
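Each entry of `input_descs` is a `(shape, dtype[, lod_level])` triple. A minimal sketch of how such a desc would typically be turned into a Fluid data layer (assumed usage; the model code that actually consumes these descs is not shown in this commit):

```python
import paddle.fluid as fluid
from desc import input_descs

name = "src_word"
shape, dtype = input_descs[name][0], input_descs[name][1]
lod_level = input_descs[name][2] if len(input_descs[name]) > 2 else 0
src_word = fluid.layers.data(
    name=name,
    shape=list(shape),  # the -1 batch dimension is already included
    dtype=dtype,
    lod_level=lod_level,
    append_batch_size=False)
```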
The inference script, `infer.py`:

```python
import argparse
import ast
import multiprocessing
import os
import sys
from functools import partial

import numpy as np
import paddle
import paddle.fluid as fluid

import reader
from config import *
from desc import *
from model import fast_decode as fast_decoder
from train import pad_batch_data, pad_phoneme_data, prepare_data_generator


def parse_args():
    parser = argparse.ArgumentParser("Inference for Transformer.")
    parser.add_argument(
        "--src_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of source language.")
    parser.add_argument(
        "--trg_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of target language.")
    parser.add_argument(
        "--phoneme_vocab_fpath",
        type=str,
        required=True,
        help="The path of vocabulary file of phonemes.")
    parser.add_argument(
        "--lexicon_fpath",
        type=str,
        required=True,
        help="The path of lexicon of source language.")
    parser.add_argument(
        "--test_file_pattern",
        type=str,
        required=True,
        help="The pattern to match test data files.")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=50,
        help="The number of examples in one run for sequence generation.")
    parser.add_argument(
        "--pool_size",
        type=int,
        default=10000,
        help="The buffer size to pool data.")
    parser.add_argument(
        "--special_token",
        type=str,
        default=["<s>", "<e>", "<unk>"],
        nargs=3,
        help="The <bos>, <eos> and <unk> tokens in the dictionary.")
    parser.add_argument(
        "--token_delimiter",
        type=lambda x: str(x.encode().decode("unicode-escape")),
        default=" ",
        help="The delimiter used to split tokens in source or target "
        "sentences. For the provided EN-DE BPE data, use spaces as the "
        "token delimiter.")
    parser.add_argument(
        "--use_mem_opt",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use memory optimization.")
    parser.add_argument(
        "--use_py_reader",
        type=ast.literal_eval,
        default=True,
        help="The flag indicating whether to use py_reader.")
    parser.add_argument(
        "--use_parallel_exe",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use ParallelExecutor.")
    parser.add_argument(
        'opts',
        help='See config.py for all options',
        default=None,
        nargs=argparse.REMAINDER)
    args = parser.parse_args()
    # Append args related to the dictionaries.
    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath)
    dict_args = [
        "src_vocab_size",
        str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "phone_vocab_size",
        str(len(phone_dict)), "bos_idx",
        str(src_dict[args.special_token[0]]), "eos_idx",
        str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
    ]
    merge_cfg_from_list(args.opts + dict_args,
                        [InferTaskConfig, ModelHyperParams])
    return args


def post_process_seq(seq,
                     bos_idx=ModelHyperParams.bos_idx,
                     eos_idx=ModelHyperParams.eos_idx,
                     output_bos=InferTaskConfig.output_bos,
                     output_eos=InferTaskConfig.output_eos):
    """
    Post-process the beam-search decoded sequence. Truncate from the first
    <eos> and remove the <bos> and <eos> tokens currently.
    """
    eos_pos = len(seq) - 1
    for i, idx in enumerate(seq):
        if idx == eos_idx:
            eos_pos = i
            break
    seq = [
        idx for idx in seq[:eos_pos + 1]
        if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx)
    ]
    return seq
```
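With the default special-token ids (`bos_idx=0`, `eos_idx=1`) and special tokens suppressed, a decoded candidate is truncated at the first `<eos>` and the `<bos>`/`<eos>` ids are dropped:

```python
# 0 = <bos>, 1 = <eos>; everything after the first <eos> is discarded
assert post_process_seq([0, 5, 7, 1, 9, 1]) == [5, 7]
```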
Continuing `infer.py`, the batch-preparation helpers:

```python
def prepare_batch_input(insts, data_input_names, src_pad_idx, phone_pad_idx,
                        bos_idx, n_head, d_model, place):
    """
    Put all padded data needed by the beam-search decoder into a dict.
    """
    src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data(
        [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False)
    src_word = src_word.reshape(-1, src_max_len, 1)
    src_pos = src_pos.reshape(-1, src_max_len, 1)
    src_phone, src_phone_mask, max_phone_len = pad_phoneme_data(
        [inst[1] for inst in insts], phone_pad_idx, src_max_len)
    # start tokens
    trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64")
    trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :],
                                [1, 1, 1, 1]).astype("float32")
    trg_word = trg_word.reshape(-1, 1, 1)

    def to_lodtensor(data, place, lod=None):
        data_tensor = fluid.LoDTensor()
        data_tensor.set(data, place)
        if lod is not None:
            data_tensor.set_lod(lod)
        return data_tensor

    # beamsearch_op must use tensors with lod
    init_score = to_lodtensor(
        np.zeros_like(trg_word, dtype="float32").reshape(-1, 1), place,
        [range(trg_word.shape[0] + 1)] * 2)
    trg_word = to_lodtensor(trg_word, place,
                            [range(trg_word.shape[0] + 1)] * 2)
    init_idx = np.asarray(range(len(insts)), dtype="int32")

    data_input_dict = dict(
        zip(data_input_names, [
            src_word, src_pos, src_slf_attn_bias, src_phone, src_phone_mask,
            trg_word, init_score, init_idx, trg_src_attn_bias
        ]))
    return data_input_dict


def prepare_feed_dict_list(data_generator, count, place):
    """
    Prepare the list of feed dicts for multi-device execution.
    """
    feed_dict_list = []
    if data_generator is not None:  # use_py_reader == False
        data_input_names = encoder_data_input_fields + \
            fast_decoder_data_input_fields
        data = next(data_generator)
        for idx, data_buffer in enumerate(data):
            data_input_dict = prepare_batch_input(
                data_buffer, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx,
                ModelHyperParams.n_head, ModelHyperParams.d_model, place)
            feed_dict_list.append(data_input_dict)
    return feed_dict_list if len(feed_dict_list) == count else None


def py_reader_provider_wrapper(data_reader, place):
    """
    Data provider needed by fluid.layers.py_reader.
    """

    def py_reader_provider():
        data_input_names = encoder_data_input_fields + \
            fast_decoder_data_input_fields
        for batch_id, data in enumerate(data_reader()):
            data_input_dict = prepare_batch_input(
                data, data_input_names, ModelHyperParams.eos_idx,
                ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx,
                ModelHyperParams.n_head, ModelHyperParams.d_model, place)
            yield [data_input_dict[item] for item in data_input_names]

    return py_reader_provider
```
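When `--use_py_reader True` is set, this wrapper is handed to `prepare_data_generator`, which binds it to the model's py_reader. A hedged sketch of the wiring (the actual call site lives in `train.py`, which is not shown here; `pyreader`, `place` and the batch generator come from the surrounding code):

```python
# hypothetical wiring, assuming Fluid's py_reader API
provider = py_reader_provider_wrapper(test_data.batch_generator, place)
pyreader.decorate_tensor_provider(provider)
pyreader.start()  # exe.run() then pulls batches until EOFException
```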
The main inference routine, still in `infer.py`:

```python
def fast_infer(args):
    """
    Inference by beam-search decoder based solely on Fluid operators.
    """
    out_ids, out_scores, pyreader = fast_decoder(
        ModelHyperParams.src_vocab_size,
        ModelHyperParams.trg_vocab_size,
        ModelHyperParams.phone_vocab_size,
        ModelHyperParams.max_length + 1,
        ModelHyperParams.n_layer,
        ModelHyperParams.n_head,
        ModelHyperParams.d_key,
        ModelHyperParams.d_value,
        ModelHyperParams.d_model,
        ModelHyperParams.d_inner_hid,
        ModelHyperParams.prepostprocess_dropout,
        ModelHyperParams.attention_dropout,
        ModelHyperParams.relu_dropout,
        ModelHyperParams.preprocess_cmd,
        ModelHyperParams.postprocess_cmd,
        ModelHyperParams.weight_sharing,
        InferTaskConfig.beam_size,
        InferTaskConfig.max_out_len,
        ModelHyperParams.bos_idx,
        ModelHyperParams.eos_idx,
        beta=ModelHyperParams.beta,
        use_py_reader=args.use_py_reader)

    # This is used here to set dropout to the test mode.
    infer_program = fluid.default_main_program().clone(for_test=True)

    if args.use_mem_opt:
        fluid.memory_optimize(infer_program)

    if InferTaskConfig.use_gpu:
        place = fluid.CUDAPlace(0)
        dev_count = fluid.core.get_cuda_device_count()
    else:
        place = fluid.CPUPlace()
        dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    fluid.io.load_vars(
        exe,
        InferTaskConfig.model_path,
        vars=[
            var for var in infer_program.list_vars()
            if isinstance(var, fluid.framework.Parameter)
        ])

    exec_strategy = fluid.ExecutionStrategy()
    # For faster executor
    exec_strategy.use_experimental_executor = True
    exec_strategy.num_threads = 1
    build_strategy = fluid.BuildStrategy()
    infer_exe = fluid.ParallelExecutor(
        use_cuda=TrainTaskConfig.use_gpu,
        main_program=infer_program,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)

    # data reader settings for inference
    args.train_file_pattern = args.test_file_pattern
    args.use_token_batch = False
    args.sort_type = reader.SortType.NONE
    args.shuffle = False
    args.shuffle_batch = False
    test_data = prepare_data_generator(
        args,
        is_test=False,
        count=dev_count,
        pyreader=pyreader,
        py_reader_provider_wrapper=py_reader_provider_wrapper,
        place=place)
    if args.use_py_reader:
        pyreader.start()
        data_generator = None
    else:
        data_generator = test_data()
    trg_idx2word = reader.DataReader.load_dict(
        dict_path=args.trg_vocab_fpath, reverse=True)

    while True:
        try:
            feed_dict_list = prepare_feed_dict_list(data_generator, dev_count,
                                                    place)
            if args.use_parallel_exe:
                seq_ids, seq_scores = infer_exe.run(
                    fetch_list=[out_ids.name, out_scores.name],
                    feed=feed_dict_list,
                    return_numpy=False)
            else:
                seq_ids, seq_scores = exe.run(
                    program=infer_program,
                    fetch_list=[out_ids.name, out_scores.name],
                    feed=feed_dict_list[0]
                    if feed_dict_list is not None else None,
                    return_numpy=False,
                    use_program_cache=True)
            # A single fetched LoDTensor is wrapped into a one-element list
            # so that single- and multi-device results are handled uniformly.
            if isinstance(seq_ids, paddle.fluid.LoDTensor):
                seq_ids_list, seq_scores_list = [seq_ids], [seq_scores]
            else:
                seq_ids_list, seq_scores_list = seq_ids, seq_scores
            for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list):
                # How to parse the results:
                # Suppose the lod of seq_ids is:
                # [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
                # then from lod[0]:
                # there are 2 source sentences, beam width is 3.
                # from lod[1]:
                # the first source sentence has 3 hyps of lengths 12, 12, 16;
                # the second source sentence has 3 hyps of lengths 14, 13, 15.
                hyps = [[] for i in range(len(seq_ids.lod()[0]) - 1)]
                scores = [[] for i in range(len(seq_scores.lod()[0]) - 1)]
                for i in range(len(seq_ids.lod()[0]) -
                               1):  # for each source sentence
                    start = seq_ids.lod()[0][i]
                    end = seq_ids.lod()[0][i + 1]
                    for j in range(end - start):  # for each candidate
                        sub_start = seq_ids.lod()[1][start + j]
                        sub_end = seq_ids.lod()[1][start + j + 1]
                        hyps[i].append(" ".join([
                            trg_idx2word[idx] for idx in post_process_seq(
                                np.array(seq_ids)[sub_start:sub_end])
                        ]))
                        scores[i].append(np.array(seq_scores)[sub_end - 1])
                        print(hyps[i][-1])
                        if len(hyps[i]) >= InferTaskConfig.n_best:
                            break
        except (StopIteration, fluid.core.EOFException):
            # The data pass is over.
            if args.use_py_reader:
                pyreader.reset()
            break


if __name__ == "__main__":
    args = parse_args()
    fast_infer(args)
```
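The lod parsing in the loop above can be checked with plain Python; for the example lod from the comment:

```python
# lod from the comment: 2 source sentences, beam width 3
lod = [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]]
num_sentences = len(lod[0]) - 1          # -> 2
hyp_lengths = [[lod[1][lod[0][i] + j + 1] - lod[1][lod[0][i] + j]
                for j in range(lod[0][i + 1] - lod[0][i])]
               for i in range(num_sentences)]
assert hyp_lengths == [[12, 12, 16], [14, 13, 15]]
```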
(This diff is collapsed.)
The data reader, `reader.py`:

```python
import glob
import os
import random
import tarfile

import numpy as np
import six


class SortType(object):
    GLOBAL = 'global'
    POOL = 'pool'
    NONE = "none"


class SrcConverter(object):
    def __init__(self, vocab, end, unk, delimiter, lexicon):
        self._vocab = vocab
        self._end = end
        self._unk = unk
        self._delimiter = delimiter
        self._lexicon = lexicon

    def __call__(self, sentence):
        src_seqs = []
        src_ph_seqs = []
        unk_phs = self._lexicon['<unk>']
        for w in sentence.split(self._delimiter):
            src_seqs.append(self._vocab.get(w, self._unk))
            ph_groups = self._lexicon.get(w, unk_phs)
            # a word with several pronunciations gets one sampled at random
            src_ph_seqs.append(random.choice(ph_groups))
        src_seqs.append(self._end)
        src_ph_seqs.append(unk_phs[0])
        return src_seqs, src_ph_seqs


class TgtConverter(object):
    def __init__(self, vocab, beg, end, unk, delimiter):
        self._vocab = vocab
        self._beg = beg
        self._end = end
        self._unk = unk
        self._delimiter = delimiter

    def __call__(self, sentence):
        return [self._beg] + [
            self._vocab.get(w, self._unk)
            for w in sentence.split(self._delimiter)
        ] + [self._end]
```
Continuing `reader.py`, the batching helpers:

```python
class ComposedConverter(object):
    def __init__(self, converters):
        self._converters = converters

    def __call__(self, parallel_sentence):
        return [
            self._converters[i](parallel_sentence[i])
            for i in range(len(self._converters))
        ]


class SentenceBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        self.batch.append(info)
        if len(self.batch) == self._batch_size:
            tmp = self.batch
            self.batch = []
            return tmp


class TokenBatchCreator(object):
    def __init__(self, batch_size):
        self.batch = []
        self.max_len = -1
        self._batch_size = batch_size

    def append(self, info):
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        if max_len * (len(self.batch) + 1) > self._batch_size:
            # adding this sample would exceed the token budget (padding
            # included), so emit the current batch and start a new one
            result = self.batch
            self.batch = [info]
            self.max_len = cur_len
            return result
        else:
            self.max_len = max_len
            self.batch.append(info)


class SampleInfo(object):
    def __init__(self, i, max_len, min_len):
        self.i = i
        self.min_len = min_len
        self.max_len = max_len


class MinMaxFilter(object):
    def __init__(self, max_len, min_len, underlying_creator):
        self._min_len = min_len
        self._max_len = max_len
        self._creator = underlying_creator

    def append(self, info):
        if info.max_len > self._max_len or info.min_len < self._min_len:
            return
        else:
            return self._creator.append(info)

    @property
    def batch(self):
        return self._creator.batch
```
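For instance, with `--use_token_batch True` a batch closes as soon as `max_len * (len(batch) + 1)` would exceed the token budget, padding included. A toy run (the lengths and budget are made up):

```python
creator = MinMaxFilter(256, 0, TokenBatchCreator(100))
batch = None
# samples of lengths 30, 40, 35: appending the third one would need
# 40 * 3 = 120 > 100 tokens, so the first two are flushed as one batch
for i, length in enumerate([30, 40, 35]):
    batch = creator.append(SampleInfo(i, max_len=length, min_len=length))
assert [info.i for info in batch] == [0, 1]
```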
Finally, the `DataReader` class in `reader.py`:

```python
class DataReader(object):
    """
    The data reader loads all data from files and produces batches of data
    in the way corresponding to settings.

    An example of returning a generator producing data batches whose data
    is shuffled in each pass and sorted in each pool:

    ```
    train_data = DataReader(
        src_vocab_fpath='data/src_vocab_file',
        trg_vocab_fpath='data/trg_vocab_file',
        fpattern='data/part-*',
        use_token_batch=True,
        batch_size=2000,
        pool_size=10000,
        sort_type=SortType.POOL,
        shuffle=True,
        shuffle_batch=True,
        start_mark='<s>',
        end_mark='<e>',
        unk_mark='<unk>',
        clip_last_batch=False).batch_generator
    ```

    :param src_vocab_fpath: The path of vocabulary file of source language.
    :type src_vocab_fpath: basestring
    :param trg_vocab_fpath: The path of vocabulary file of target language.
    :type trg_vocab_fpath: basestring
    :param fpattern: The pattern to match data files.
    :type fpattern: basestring
    :param phoneme_vocab_fpath: The path of vocabulary file of phonemes.
    :type phoneme_vocab_fpath: basestring
    :param lexicon_fpath: The path of the pronunciation lexicon of the
        source language.
    :type lexicon_fpath: basestring
    :param batch_size: The number of sequences contained in a mini-batch,
        or the maximum number of tokens (including paddings) contained in a
        mini-batch.
    :type batch_size: int
    :param pool_size: The size of pool buffer.
    :type pool_size: int
    :param sort_type: The grain to sort by length: 'global' for all
        instances; 'pool' for instances in pool; 'none' for no sort.
    :type sort_type: basestring
    :param clip_last_batch: Whether to clip the last incomplete batch.
    :type clip_last_batch: bool
    :param tar_fname: The data file in tar if fpattern matches a tar file.
    :type tar_fname: basestring
    :param min_length: The minimum length used to filter sequences.
    :type min_length: int
    :param max_length: The maximum length used to filter sequences.
    :type max_length: int
    :param shuffle: Whether to shuffle all instances.
    :type shuffle: bool
    :param shuffle_batch: Whether to shuffle the generated batches.
    :type shuffle_batch: bool
    :param use_token_batch: Whether to produce batch data according to
        token number.
    :type use_token_batch: bool
    :param field_delimiter: The delimiter used to split source and target in
        each line of data file.
    :type field_delimiter: basestring
    :param token_delimiter: The delimiter used to split tokens in source or
        target sentences.
    :type token_delimiter: basestring
    :param start_mark: The token representing the beginning of sentences in
        the dictionary.
    :type start_mark: basestring
    :param end_mark: The token representing the end of sentences in the
        dictionary.
    :type end_mark: basestring
    :param unk_mark: The token representing unknown words in the dictionary.
    :type unk_mark: basestring
    :param seed: The seed for random.
    :type seed: int
    """

    def __init__(self,
                 src_vocab_fpath,
                 trg_vocab_fpath,
                 fpattern,
                 phoneme_vocab_fpath,
                 lexicon_fpath,
                 batch_size,
                 pool_size,
                 sort_type=SortType.GLOBAL,
                 clip_last_batch=True,
                 tar_fname=None,
                 min_length=0,
                 max_length=100,
                 shuffle=True,
                 shuffle_batch=False,
                 use_token_batch=False,
                 field_delimiter="\t",
                 token_delimiter=" ",
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 seed=0):
        self._src_vocab = self.load_dict(src_vocab_fpath)
        self._only_src = True
        if trg_vocab_fpath is not None:
            self._trg_vocab = self.load_dict(trg_vocab_fpath)
            self._only_src = False
        self._phoneme_vocab = self.load_dict(phoneme_vocab_fpath)
        self._lexicon = self.load_lexicon(lexicon_fpath, self._phoneme_vocab)
        self._pool_size = pool_size
        self._batch_size = batch_size
        self._use_token_batch = use_token_batch
        self._sort_type = sort_type
        self._clip_last_batch = clip_last_batch
        self._shuffle = shuffle
        self._shuffle_batch = shuffle_batch
        self._min_length = min_length
        self._max_length = max_length
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname,
                              unk_mark)
        self._random = np.random
        self._random.seed(seed)

    def load_lexicon(self, lexicon_path, phoneme_vocab):
        lexicon = {}
        with open(lexicon_path) as fp:
            for line in fp:
                tokens = line.strip().split()
                word = tokens[0]
                all_phone_str = ' '.join(tokens[1:])
                phone_strs = all_phone_str.split('|')
                phone_groups = []
                for phone_str in phone_strs:
                    cur_phone_seq = [
                        phoneme_vocab[x] for x in phone_str.split()
                    ]
                    phone_groups.append(cur_phone_seq)
                lexicon[word] = phone_groups
        lexicon['<unk>'] = [[phoneme_vocab['<unk>']]]
        return lexicon

    def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname,
                         unk_mark):
        converters = [
            SrcConverter(
                vocab=self._src_vocab,
                end=self._src_vocab[end_mark],
                unk=self._src_vocab[unk_mark],
                delimiter=self._token_delimiter,
                lexicon=self._lexicon)
        ]
        if not self._only_src:
            converters.append(
                TgtConverter(
                    vocab=self._trg_vocab,
                    beg=self._trg_vocab[start_mark],
                    end=self._trg_vocab[end_mark],
                    unk=self._trg_vocab[unk_mark],
                    delimiter=self._token_delimiter))
        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._src_phone_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []
        for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
            src_trg_ids = converters(line)
            self._src_seq_ids.append(src_trg_ids[0][0])
            self._src_phone_ids.append(src_trg_ids[0][1])
            lens = [len(src_trg_ids[0][0])]
            if not self._only_src:
                self._trg_seq_ids.append(src_trg_ids[1])
                lens.append(len(src_trg_ids[1]))
            self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))

    def _load_lines(self, fpattern, tar_fname):
        fpaths = glob.glob(fpattern)

        if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]):
            if tar_fname is None:
                raise Exception("If tar file provided, please set tar_fname.")

            f = tarfile.open(fpaths[0], "r")
            for line in f.extractfile(tar_fname):
                if six.PY3:
                    line = line.decode()  # keep in line with the branch below
                fields = line.strip("\n").split(self._field_delimiter)
                if (not self._only_src
                        and len(fields) == 2) or (self._only_src
                                                  and len(fields) == 1):
                    yield fields
        else:
            for fpath in fpaths:
                if not os.path.isfile(fpath):
                    raise IOError("Invalid file: %s" % fpath)

                with open(fpath, "rb") as f:
                    for line in f:
                        if six.PY3:
                            line = line.decode()
                        fields = line.strip("\n").split(self._field_delimiter)
                        if (not self._only_src and len(fields) == 2) or (
                                self._only_src and len(fields) == 1):
                            yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False):
        word_dict = {}
        with open(dict_path, "rb") as fdict:
            for idx, line in enumerate(fdict):
                if six.PY3:
                    line = line.decode()
                if reverse:
                    word_dict[idx] = line.strip("\n")
                else:
                    word_dict[line.strip("\n")] = idx
        return word_dict

    def batch_generator(self):
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(self._sample_infos, key=lambda x: x.max_len)
        else:
            if self._shuffle:
                infos = self._sample_infos
                self._random.shuffle(infos)
            else:
                infos = self._sample_infos

            if self._sort_type == SortType.POOL:
                reverse = True
                for i in range(0, len(infos), self._pool_size):
                    # to avoid placing short sentences next to long ones
                    reverse = not reverse
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size],
                        key=lambda x: x.max_len,
                        reverse=reverse)

        # concat batch
        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size
        ) if self._use_token_batch else SentenceBatchCreator(self._batch_size)
        batch_creator = MinMaxFilter(self._max_length, self._min_length,
                                     batch_creator)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        for batch in batches:
            batch_ids = [info.i for info in batch]

            if self._only_src:
                yield [[(self._src_seq_ids[idx], self._src_phone_ids[idx])]
                       for idx in batch_ids]
            else:
                yield [(self._src_seq_ids[idx], self._src_phone_ids[idx],
                        self._trg_seq_ids[idx][:-1],
                        self._trg_seq_ids[idx][1:]) for idx in batch_ids]
```
(This diff is collapsed.)