From 6732376745b1736f212df6ce52021ef66f5ad8db Mon Sep 17 00:00:00 2001
From: AndyELiu
Date: Mon, 22 Jul 2019 20:42:31 -0700
Subject: [PATCH] submit code for joint embedding paper (#2896)

---
 PaddleNLP/Research/ACL2019-JEMT/README.md | 128 +++
 PaddleNLP/Research/ACL2019-JEMT/config.py | 117 +++
 PaddleNLP/Research/ACL2019-JEMT/desc.py   | 107 +++
 PaddleNLP/Research/ACL2019-JEMT/infer.py  | 342 ++++++++
 PaddleNLP/Research/ACL2019-JEMT/model.py  | 984 ++++++++++++++++++++++
 PaddleNLP/Research/ACL2019-JEMT/reader.py | 385 +++++++++
 PaddleNLP/Research/ACL2019-JEMT/train.py  | 826 ++++++++++++++++++
 7 files changed, 2889 insertions(+)
 create mode 100644 PaddleNLP/Research/ACL2019-JEMT/README.md
 create mode 100644 PaddleNLP/Research/ACL2019-JEMT/config.py
 create mode 100644 PaddleNLP/Research/ACL2019-JEMT/desc.py
 create mode 100644 PaddleNLP/Research/ACL2019-JEMT/infer.py
 create mode 100644 PaddleNLP/Research/ACL2019-JEMT/model.py
 create mode 100644 PaddleNLP/Research/ACL2019-JEMT/reader.py
 create mode 100644 PaddleNLP/Research/ACL2019-JEMT/train.py

diff --git a/PaddleNLP/Research/ACL2019-JEMT/README.md b/PaddleNLP/Research/ACL2019-JEMT/README.md
new file mode 100644
index 00000000..b003b9a2
--- /dev/null
+++ b/PaddleNLP/Research/ACL2019-JEMT/README.md
@@ -0,0 +1,128 @@
+## Introduction
+
+### Task description
+ The input to machine translation is normally a sentence in the source language. In many practical systems, however, such as the output of a speech recognition system or pinyin-based text input, the source sentence usually contains many homophone errors, which lead to unexpected translation mistakes. Since the pronunciation is available at the same time, we propose a translation method that adds pronunciation information to the input and fuses the textual and phonetic information in the embedding layer of the model, which greatly improves the robustness of the translation model against homophone errors.
+
+ Paper: https://arxiv.org/abs/1810.06729
+
+### Results
+
+ We train on an LDC Chinese-to-English dataset and use [DaCiDian](https://github.com/aishell-foundation/DaCiDian) as the Chinese pronunciation lexicon. Evaluated on newstest2006, the results are as follows:
+
+| beta=0 | beta=0.50 | beta=0.85 | beta=0.95 |
+|-|-|-|-|
+| 47.96 | 48.71 | 48.85 | 48.46 |
+
+beta is the weight given to the pronunciation information. The results show that translation quality stays high even when most of the weight is put on pronunciation, while at the same time the system becomes far more robust to homophone errors.
+
+
+## Installation
+
+1. Installing Paddle
+
+ This project requires PaddlePaddle Fluid 1.3.1 or later; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start).
+
+2. Environment dependencies
+
+ Please refer to the environment section of the PaddlePaddle [installation instructions](http://paddlepaddle.org/documentation/docs/zh/1.3/beginners_guide/install/index_cn.html).
+
+
+
+## Training
+
+1. Data format
+
+ The data format is the same as that of [Paddle neural machine translation](https://github.com/PaddlePaddle/models/tree/develop/PaddleNLP/neural_machine_translation/transformer). To obtain the pronunciation of the input sentences, two extra source-language files are required: the basic pronunciation units and the pronunciation lexicon.
+
+ A) Pronunciation unit file
+
+ For Chinese the basic pronunciation units are pinyin syllables; put all pinyin syllables into one file, e.g.:
+
+ bo
+
+ li
+
+ ...
+
+ B) Pronunciation lexicon
+
+ Based on DaCiDian, every token of the BPE-processed source text is assigned one or more pronunciations (alternatives are separated by `|`), e.g. (a short parsing sketch follows the examples):
+
+ ▁玻利维亚 bo li wei ya
+
+ ▁举行 ju xing
+
+ ▁总统 zong tong
+
+ ▁与 yu
+
+ 巴斯 ba si
+
+ ▁这个 zhei ge|zhe ge
+
+ ...
+
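+ To make the lexicon format concrete, here is a rough stand-alone sketch (a toy example, not one of the provided scripts) of how a lexicon line with several pronunciations can be turned into groups of phoneme ids, mirroring what `reader.DataReader.load_lexicon` in this directory does:
+
+```python
+def parse_lexicon_line(line, phoneme_vocab):
+    # "▁这个 zhei ge|zhe ge" -> word "▁这个", pronunciations ["zhei ge", "zhe ge"]
+    tokens = line.strip().split()
+    word = tokens[0]
+    phone_strs = ' '.join(tokens[1:]).split('|')
+    # every alternative pronunciation becomes one list of phoneme ids
+    return word, [[phoneme_vocab[p] for p in s.split()] for s in phone_strs]
+
+toy_vocab = {'zhei': 0, 'zhe': 1, 'ge': 2}
+print(parse_lexicon_line('▁这个 zhei ge|zhe ge', toy_vocab))
+# ('▁这个', [[0, 2], [1, 2]])
+```
+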
+2. Training the model
+
+ After data preparation, training can be started with the `train.py` script, for example:
+
+```sh
+python train.py \
+  --src_vocab_fpath nist_data/vocab_all.28000 \
+  --trg_vocab_fpath nist_data/vocab_all.28000 \
+  --train_file_pattern nist_data/nist_train.txt \
+  --phoneme_vocab_fpath nist_data/zh_pinyins.txt \
+  --lexicon_fpath nist_data/zh_lexicon.txt \
+  --batch_size 2048 \
+  --use_token_batch True \
+  --sort_type pool \
+  --pool_size 200000 \
+  --use_py_reader False \
+  --use_mem_opt False \
+  --enable_ce False \
+  --fetch_steps 1 \
+  pass_num 100 \
+  learning_rate 2.0 \
+  warmup_steps 8000 \
+  beta2 0.997 \
+  d_model 512 \
+  d_inner_hid 2048 \
+  n_head 8 \
+  weight_sharing True \
+  max_length 256 \
+  save_freq 10000 \
+  beta 0.85 \
+  model_dir pinyin_models_beta085 \
+  ckpt_dir pinyin_ckpts_beta085
+```
+
+The command above sets data-related arguments such as the source vocabulary path (`src_vocab_fpath`), the target vocabulary path (`trg_vocab_fpath`), the training data files (`train_file_pattern`, wildcards supported), the pronunciation-unit vocabulary path (`phoneme_vocab_fpath`) and the pronunciation lexicon path (`lexicon_fpath`), as well as reader-related arguments such as how batches are built (`use_token_batch` chooses between batching by token count and by sequence count). More detailed information on these arguments can be obtained with:
+
+```sh
+python train.py --help
+```
+
+ Further training parameters are defined in `ModelHyperParams` and `TrainTaskConfig` in `config.py`: `ModelHyperParams` holds model hyper-parameters such as the embedding size, and `TrainTaskConfig` holds training settings such as the number of warm-up steps. The defaults follow the base model configuration of the Transformer paper and can be modified in that script. They can also be passed on the command line when launching training; the passed configuration is merged with and overrides the one in `config.py`.
+
+ Note that if the model configuration is changed for training, the same configuration has to be used when predicting with `infer.py`. Also, training uses all visible GPUs by default; use the `CUDA_VISIBLE_DEVICES` environment variable to restrict it to specific GPUs.
+
+## Inference
+
+With the data and models provided above, prediction can be run as follows; the translations are printed to standard output:
+
+```sh
+python infer.py \
+--src_vocab_fpath nist_data/vocab_all.28000 \
+--trg_vocab_fpath nist_data/vocab_all.28000 \
+--test_file_pattern nist_data/nist_test.txt \
+--phoneme_vocab_fpath nist_data/zh_pinyins.txt \
+--lexicon_fpath nist_data/zh_lexicon.txt \
+--batch_size 32 \
+model_path pinyin_models_beta085/iter_200000.infer.model \
+beam_size 5 \
+max_out_len 255 \
+beta 0.85
+```
diff --git a/PaddleNLP/Research/ACL2019-JEMT/config.py b/PaddleNLP/Research/ACL2019-JEMT/config.py
new file mode 100644
index 00000000..d56fe2f8
--- /dev/null
+++ b/PaddleNLP/Research/ACL2019-JEMT/config.py
@@ -0,0 +1,117 @@
+class TrainTaskConfig(object):
+    # support both CPU and GPU now.
+    use_gpu = True
+    # the epoch number to train.
+    pass_num = 30
+    # the number of sequences contained in a mini-batch.
+    # deprecated, set batch_size in args.
+    batch_size = 32
+    # the hyper parameters for Adam optimizer.
+    # This static learning_rate will be multiplied to the LearningRateScheduler
+    # derived learning rate to get the final learning rate.
+    learning_rate = 2.0
+    beta1 = 0.9
+    beta2 = 0.997
+    eps = 1e-9
+    # the parameters for learning rate scheduling.
+    warmup_steps = 8000
+    # the weight used to mix up the ground-truth distribution and the fixed
+    # uniform distribution in label smoothing when training.
+    # Set this as zero if label smoothing is not wanted.
+    label_smooth_eps = 0.1
+    # the directory for saving trained models.
+    model_dir = "trained_models"
+    # the directory for saving checkpoints.
+    ckpt_dir = "trained_ckpts"
+    # the directory for loading a checkpoint.
+    # If provided, continue training from the checkpoint.
+    ckpt_path = None
+    # the parameter to initialize the learning rate scheduler.
+    # It should be provided if checkpoints are used, since the checkpoint
+    # doesn't include the training step counter currently.
+    start_step = 0
+    # the frequency to save trained models.
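+    # For reference, the effective learning rate is assumed to follow the
+    # inverse-square-root warm-up schedule of the original Transformer,
+    #   lr(step) = learning_rate * d_model**-0.5
+    #              * min(step**-0.5, step * warmup_steps**-1.5),
+    # i.e. it warms up for warmup_steps steps and decays afterwards.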
+ save_freq = 10000 + + +class InferTaskConfig(object): + use_gpu = True + # the number of examples in one run for sequence generation. + batch_size = 10 + # the parameters for beam search. + beam_size = 5 + max_out_len = 256 + # the number of decoded sentences to output. + n_best = 1 + # the flags indicating whether to output the special tokens. + output_bos = False + output_eos = False + output_unk = True + # the directory for loading the trained model. + model_path = "trained_models/pass_1.infer.model" + + +class ModelHyperParams(object): + # These following five vocabularies related configurations will be set + # automatically according to the passed vocabulary path and special tokens. + # size of source word dictionary. + src_vocab_size = 10000 + # size of target word dictionay + trg_vocab_size = 10000 + # size of phone dictionary + phone_vocab_size = 1000 + # ratio of phoneme embeddings + beta = 0.0 + # index for token + bos_idx = 0 + # index for token + eos_idx = 1 + # index for token + unk_idx = 2 + # index for in phonemes + phone_pad_idx = 0 + # max length of sequences deciding the size of position encoding table. + max_length = 256 + # the dimension for word embeddings, which is also the last dimension of + # the input and output of multi-head attention, position-wise feed-forward + # networks, encoder and decoder. + d_model = 512 + # size of the hidden layer in position-wise feed-forward networks. + d_inner_hid = 2048 + # the dimension that keys are projected to for dot-product attention. + d_key = 64 + # the dimension that values are projected to for dot-product attention. + d_value = 64 + # number of head used in multi-head attention. + n_head = 8 + # number of sub-layers to be stacked in the encoder and decoder. + n_layer = 6 + # dropout rates of different modules. + prepostprocess_dropout = 0.1 + attention_dropout = 0.1 + relu_dropout = 0.1 + # to process before each sub-layer + preprocess_cmd = "n" # layer normalization + # to process after each sub-layer + postprocess_cmd = "da" # dropout + residual connection + # random seed used in dropout for CE. + dropout_seed = None + # the flag indicating whether to share embedding and softmax weights. + # vocabularies in source and target should be same for weight sharing. + weight_sharing = True + + +def merge_cfg_from_list(cfg_list, g_cfgs): + """ + Set the above global configurations using the cfg_list. + """ + assert len(cfg_list) % 2 == 0 + for key, value in zip(cfg_list[0::2], cfg_list[1::2]): + for g_cfg in g_cfgs: + if hasattr(g_cfg, key): + try: + value = eval(value) + except Exception: # for file path + pass + setattr(g_cfg, key, value) + break diff --git a/PaddleNLP/Research/ACL2019-JEMT/desc.py b/PaddleNLP/Research/ACL2019-JEMT/desc.py new file mode 100644 index 00000000..857ef02a --- /dev/null +++ b/PaddleNLP/Research/ACL2019-JEMT/desc.py @@ -0,0 +1,107 @@ +# The placeholder for batch_size in compile time. Must be -1 currently to be +# consistent with some ops' infer-shape output in compile time, such as the +# sequence_expand op used in beamsearch decoder. +batch_size = -1 +# The placeholder for squence length in compile time. +seq_len = 256 +# The placeholder for phoneme sequence length in comiple time. +phone_len = 16 +# The placeholder for head number in compile time. +n_head = 8 +# The placeholder for model dim in compile time. +d_model = 512 +# Here list the data shapes and data types of all inputs. +# The shapes here act as placeholder and are set to pass the infer-shape in +# compile time. 
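+# Each entry below is [shape, dtype] plus an optional third element giving the
+# LoD level of the input (make_all_inputs / make_all_py_reader_inputs in
+# model.py treat a missing third element as lod_level 0). For example,
+#   "src_word": [(batch_size, seq_len, 1), "int64", 2]
+# declares an int64 LoDTensor with LoD level 2.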
+input_descs = { + # The actual data shape of src_word is: + # [batch_size, max_src_len_in_batch, 1] + "src_word": [(batch_size, seq_len, 1), "int64", 2], + # The actual data shape of src_pos is: + # [batch_size, max_src_len_in_batch, 1] + "src_pos": [(batch_size, seq_len, 1), "int64"], + # This input is used to remove attention weights on paddings in the + # encoder. + # The actual data shape of src_slf_attn_bias is: + # [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch] + "src_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], + "src_phone": [(batch_size, seq_len, phone_len, 1), "int64"], + "src_phone_mask": [(batch_size, seq_len, phone_len), "int64"], + # The actual data shape of trg_word is: + # [batch_size, max_trg_len_in_batch, 1] + "trg_word": [(batch_size, seq_len, 1), "int64", + 2], # lod_level is only used in fast decoder. + # The actual data shape of trg_pos is: + # [batch_size, max_trg_len_in_batch, 1] + "trg_pos": [(batch_size, seq_len, 1), "int64"], + # This input is used to remove attention weights on paddings and + # subsequent words in the decoder. + # The actual data shape of trg_slf_attn_bias is: + # [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch] + "trg_slf_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], + # This input is used to remove attention weights on paddings of the source + # input in the encoder-decoder attention. + # The actual data shape of trg_src_attn_bias is: + # [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch] + "trg_src_attn_bias": [(batch_size, n_head, seq_len, seq_len), "float32"], + # This input is used in independent decoder program for inference. + # The actual data shape of enc_output is: + # [batch_size, max_src_len_in_batch, d_model] + "enc_output": [(batch_size, seq_len, d_model), "float32"], + # The actual data shape of label_word is: + # [batch_size * max_trg_len_in_batch, 1] + "lbl_word": [(batch_size * seq_len, 1), "int64"], + # This input is used to mask out the loss of paddding tokens. + # The actual data shape of label_weight is: + # [batch_size * max_trg_len_in_batch, 1] + "lbl_weight": [(batch_size * seq_len, 1), "float32"], + # This input is used in beam-search decoder. + "init_score": [(batch_size, 1), "float32", 2], + # This input is used in beam-search decoder for the first gather + # (cell states updation) + "init_idx": [(batch_size, ), "int32"], +} + +# Names of word embedding table which might be reused for weight sharing. +word_emb_param_names = ( + "src_word_emb_table", + "trg_word_emb_table", +) + +phone_emb_param_name = "phone_emb_table" + +# Names of position encoding table which will be initialized externally. +pos_enc_param_names = ( + "src_pos_enc_table", + "trg_pos_enc_table", +) +# separated inputs for different usages. +encoder_data_input_fields = ( + "src_word", + "src_pos", + "src_slf_attn_bias", + "src_phone", + "src_phone_mask", +) +decoder_data_input_fields = ( + "trg_word", + "trg_pos", + "trg_slf_attn_bias", + "trg_src_attn_bias", + "enc_output", +) +label_data_input_fields = ( + "lbl_word", + "lbl_weight", +) +# In fast decoder, trg_pos (only containing the current time step) is generated +# by ops and trg_slf_attn_bias is not needed. 
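+# At inference time infer.py feeds encoder_data_input_fields followed by
+# fast_decoder_data_input_fields, so the overall feed order is:
+#   src_word, src_pos, src_slf_attn_bias, src_phone, src_phone_mask,
+#   trg_word, init_score, init_idx, trg_src_attn_bias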
+fast_decoder_data_input_fields = ( + "trg_word", + "init_score", + "init_idx", + "trg_src_attn_bias", +) + +# Set seed for CE +dropout_seed = None diff --git a/PaddleNLP/Research/ACL2019-JEMT/infer.py b/PaddleNLP/Research/ACL2019-JEMT/infer.py new file mode 100644 index 00000000..afcc9fe0 --- /dev/null +++ b/PaddleNLP/Research/ACL2019-JEMT/infer.py @@ -0,0 +1,342 @@ +import argparse +import ast +import multiprocessing +import numpy as np +import os +import sys +from functools import partial + +import paddle +import paddle.fluid as fluid + +import reader +from config import * +from desc import * +from model import fast_decode as fast_decoder +from train import pad_batch_data, pad_phoneme_data, prepare_data_generator + + +def parse_args(): + parser = argparse.ArgumentParser("Training for Transformer.") + parser.add_argument( + "--src_vocab_fpath", + type=str, + required=True, + help="The path of vocabulary file of source language.") + parser.add_argument( + "--trg_vocab_fpath", + type=str, + required=True, + help="The path of vocabulary file of target language.") + parser.add_argument( + "--phoneme_vocab_fpath", + type=str, + required=True, + help="The path of vocabulary file of phonemes.") + parser.add_argument( + "--lexicon_fpath", + type=str, + required=True, + help="The path of lexicon of source language.") + parser.add_argument( + "--test_file_pattern", + type=str, + required=True, + help="The pattern to match test data files.") + parser.add_argument( + "--batch_size", + type=int, + default=50, + help="The number of examples in one run for sequence generation.") + parser.add_argument( + "--pool_size", + type=int, + default=10000, + help="The buffer size to pool data.") + parser.add_argument( + "--special_token", + type=str, + default=["", "", ""], + nargs=3, + help="The , and tokens in the dictionary.") + parser.add_argument( + "--token_delimiter", + type=lambda x: str(x.encode().decode("unicode-escape")), + default=" ", + help="The delimiter used to split tokens in source or target sentences. " + "For EN-DE BPE data we provided, use spaces as token delimiter. 
") + parser.add_argument( + "--use_mem_opt", + type=ast.literal_eval, + default=True, + help="The flag indicating whether to use memory optimization.") + parser.add_argument( + "--use_py_reader", + type=ast.literal_eval, + default=True, + help="The flag indicating whether to use py_reader.") + parser.add_argument( + "--use_parallel_exe", + type=ast.literal_eval, + default=False, + help="The flag indicating whether to use ParallelExecutor.") + parser.add_argument( + 'opts', + help='See config.py for all options', + default=None, + nargs=argparse.REMAINDER) + args = parser.parse_args() + # Append args related to dict + src_dict = reader.DataReader.load_dict(args.src_vocab_fpath) + trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath) + phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath) + dict_args = [ + "src_vocab_size", + str(len(src_dict)), "trg_vocab_size", + str(len(trg_dict)), "phone_vocab_size", + str(len(phone_dict)), "bos_idx", + str(src_dict[args.special_token[0]]), "eos_idx", + str(src_dict[args.special_token[1]]), "unk_idx", + str(src_dict[args.special_token[2]]) + ] + merge_cfg_from_list(args.opts + dict_args, + [InferTaskConfig, ModelHyperParams]) + return args + + +def post_process_seq(seq, + bos_idx=ModelHyperParams.bos_idx, + eos_idx=ModelHyperParams.eos_idx, + output_bos=InferTaskConfig.output_bos, + output_eos=InferTaskConfig.output_eos): + """ + Post-process the beam-search decoded sequence. Truncate from the first + and remove the and tokens currently. + """ + eos_pos = len(seq) - 1 + for i, idx in enumerate(seq): + if idx == eos_idx: + eos_pos = i + break + seq = [ + idx for idx in seq[:eos_pos + 1] + if (output_bos or idx != bos_idx) and (output_eos or idx != eos_idx) + ] + return seq + + +def prepare_batch_input(insts, data_input_names, src_pad_idx, phone_pad_idx, + bos_idx, n_head, d_model, place): + """ + Put all padded data needed by beam search decoder into a dict. + """ + src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( + [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) + src_word = src_word.reshape(-1, src_max_len, 1) + src_pos = src_pos.reshape(-1, src_max_len, 1) + src_phone, src_phone_mask, max_phone_len = pad_phoneme_data( + [inst[1] for inst in insts], phone_pad_idx, src_max_len) + + # start tokens + trg_word = np.asarray([[bos_idx]] * len(insts), dtype="int64") + trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], + [1, 1, 1, 1]).astype("float32") + trg_word = trg_word.reshape(-1, 1, 1) + + def to_lodtensor(data, place, lod=None): + data_tensor = fluid.LoDTensor() + data_tensor.set(data, place) + if lod is not None: + data_tensor.set_lod(lod) + return data_tensor + + # beamsearch_op must use tensors with lod + init_score = to_lodtensor( + np.zeros_like(trg_word, dtype="float32").reshape(-1, 1), place, + [range(trg_word.shape[0] + 1)] * 2) + trg_word = to_lodtensor(trg_word, place, + [range(trg_word.shape[0] + 1)] * 2) + init_idx = np.asarray(range(len(insts)), dtype="int32") + + data_input_dict = dict( + zip(data_input_names, [ + src_word, src_pos, src_slf_attn_bias, src_phone, src_phone_mask, + trg_word, init_score, init_idx, trg_src_attn_bias + ])) + return data_input_dict + + +def prepare_feed_dict_list(data_generator, count, place): + """ + Prepare the list of feed dict for multi-devices. 
+ """ + feed_dict_list = [] + if data_generator is not None: # use_py_reader == False + data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields + data = next(data_generator) + for idx, data_buffer in enumerate(data): + data_input_dict = prepare_batch_input( + data_buffer, data_input_names, ModelHyperParams.eos_idx, + ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx, + ModelHyperParams.n_head, ModelHyperParams.d_model, place) + feed_dict_list.append(data_input_dict) + return feed_dict_list if len(feed_dict_list) == count else None + + +def py_reader_provider_wrapper(data_reader, place): + """ + Data provider needed by fluid.layers.py_reader. + """ + + def py_reader_provider(): + data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields + for batch_id, data in enumerate(data_reader()): + data_input_dict = prepare_batch_input( + data, data_input_names, ModelHyperParams.eos_idx, + ModelHyperParams.phone_pad_idx, ModelHyperParams.bos_idx, + ModelHyperParams.n_head, ModelHyperParams.d_model, place) + yield [data_input_dict[item] for item in data_input_names] + + return py_reader_provider + + +def fast_infer(args): + """ + Inference by beam search decoder based solely on Fluid operators. + """ + out_ids, out_scores, pyreader = fast_decoder( + ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size, + ModelHyperParams.phone_vocab_size, + ModelHyperParams.max_length + 1, + ModelHyperParams.n_layer, + ModelHyperParams.n_head, + ModelHyperParams.d_key, + ModelHyperParams.d_value, + ModelHyperParams.d_model, + ModelHyperParams.d_inner_hid, + ModelHyperParams.prepostprocess_dropout, + ModelHyperParams.attention_dropout, + ModelHyperParams.relu_dropout, + ModelHyperParams.preprocess_cmd, + ModelHyperParams.postprocess_cmd, + ModelHyperParams.weight_sharing, + InferTaskConfig.beam_size, + InferTaskConfig.max_out_len, + ModelHyperParams.bos_idx, + ModelHyperParams.eos_idx, + beta=ModelHyperParams.beta, + use_py_reader=args.use_py_reader) + + # This is used here to set dropout to the test mode. 
+ infer_program = fluid.default_main_program().clone(for_test=True) + + if args.use_mem_opt: + fluid.memory_optimize(infer_program) + + if InferTaskConfig.use_gpu: + place = fluid.CUDAPlace(0) + dev_count = fluid.core.get_cuda_device_count() + else: + place = fluid.CPUPlace() + dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + fluid.io.load_vars( + exe, + InferTaskConfig.model_path, + vars=[ + var for var in infer_program.list_vars() + if isinstance(var, fluid.framework.Parameter) + ]) + + exec_strategy = fluid.ExecutionStrategy() + # For faster executor + exec_strategy.use_experimental_executor = True + exec_strategy.num_threads = 1 + build_strategy = fluid.BuildStrategy() + infer_exe = fluid.ParallelExecutor( + use_cuda=TrainTaskConfig.use_gpu, + main_program=infer_program, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + + # data reader settings for inference + args.train_file_pattern = args.test_file_pattern + args.use_token_batch = False + args.sort_type = reader.SortType.NONE + args.shuffle = False + args.shuffle_batch = False + test_data = prepare_data_generator( + args, + is_test=False, + count=dev_count, + pyreader=pyreader, + py_reader_provider_wrapper=py_reader_provider_wrapper, + place=place) + if args.use_py_reader: + pyreader.start() + data_generator = None + else: + data_generator = test_data() + trg_idx2word = reader.DataReader.load_dict( + dict_path=args.trg_vocab_fpath, reverse=True) + + while True: + try: + feed_dict_list = prepare_feed_dict_list(data_generator, dev_count, + place) + if args.use_parallel_exe: + seq_ids, seq_scores = infer_exe.run( + fetch_list=[out_ids.name, out_scores.name], + feed=feed_dict_list, + return_numpy=False) + else: + seq_ids, seq_scores = exe.run( + program=infer_program, + fetch_list=[out_ids.name, out_scores.name], + feed=feed_dict_list[0] + if feed_dict_list is not None else None, + return_numpy=False, + use_program_cache=True) + seq_ids_list, seq_scores_list = [ + seq_ids + ], [seq_scores] if isinstance( + seq_ids, paddle.fluid.LoDTensor) else (seq_ids, seq_scores) + for seq_ids, seq_scores in zip(seq_ids_list, seq_scores_list): + # How to parse the results: + # Suppose the lod of seq_ids is: + # [[0, 3, 6], [0, 12, 24, 40, 54, 67, 82]] + # then from lod[0]: + # there are 2 source sentences, beam width is 3. + # from lod[1]: + # the first source sentence has 3 hyps; the lengths are 12, 12, 16 + # the second source sentence has 3 hyps; the lengths are 14, 13, 15 + hyps = [[] for i in range(len(seq_ids.lod()[0]) - 1)] + scores = [[] for i in range(len(seq_scores.lod()[0]) - 1)] + for i in range(len(seq_ids.lod()[0]) - + 1): # for each source sentence + start = seq_ids.lod()[0][i] + end = seq_ids.lod()[0][i + 1] + for j in range(end - start): # for each candidate + sub_start = seq_ids.lod()[1][start + j] + sub_end = seq_ids.lod()[1][start + j + 1] + hyps[i].append(" ".join([ + trg_idx2word[idx] for idx in post_process_seq( + np.array(seq_ids)[sub_start:sub_end]) + ])) + scores[i].append(np.array(seq_scores)[sub_end - 1]) + print(hyps[i][-1]) + if len(hyps[i]) >= InferTaskConfig.n_best: + break + except (StopIteration, fluid.core.EOFException): + # The data pass is over. 
+ if args.use_py_reader: + pyreader.reset() + break + + +if __name__ == "__main__": + args = parse_args() + fast_infer(args) diff --git a/PaddleNLP/Research/ACL2019-JEMT/model.py b/PaddleNLP/Research/ACL2019-JEMT/model.py new file mode 100644 index 00000000..c0a9c375 --- /dev/null +++ b/PaddleNLP/Research/ACL2019-JEMT/model.py @@ -0,0 +1,984 @@ +from functools import partial +import numpy as np + +import paddle.fluid as fluid +import paddle.fluid.layers as layers + +from desc import * + + +def wrap_layer_with_block(layer, block_idx): + """ + Make layer define support indicating block, by which we can add layers + to other blocks within current block. This will make it easy to define + cache among while loop. + """ + + class BlockGuard(object): + """ + BlockGuard class. + + BlockGuard class is used to switch to the given block in a program by + using the Python `with` keyword. + """ + + def __init__(self, block_idx=None, main_program=None): + self.main_program = fluid.default_main_program( + ) if main_program is None else main_program + self.old_block_idx = self.main_program.current_block().idx + self.new_block_idx = block_idx + + def __enter__(self): + self.main_program.current_block_idx = self.new_block_idx + + def __exit__(self, exc_type, exc_val, exc_tb): + self.main_program.current_block_idx = self.old_block_idx + if exc_type is not None: + return False # re-raise exception + return True + + def layer_wrapper(*args, **kwargs): + with BlockGuard(block_idx): + return layer(*args, **kwargs) + + return layer_wrapper + + +def position_encoding_init(n_position, d_pos_vec): + """ + Generate the initial values for the sinusoid position encoding table. + """ + channels = d_pos_vec + position = np.arange(n_position) + num_timescales = channels // 2 + log_timescale_increment = ( + np.log(float(1e4) / float(1)) / (num_timescales - 1)) + inv_timescales = np.exp( + np.arange(num_timescales)) * -log_timescale_increment + scaled_time = np.expand_dims(position, 1) * np.expand_dims( + inv_timescales, 0) + signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1) + signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant') + position_enc = signal + return position_enc.astype("float32") + + +def multi_head_attention(queries, + keys, + values, + attn_bias, + d_key, + d_value, + d_model, + n_head=1, + dropout_rate=0., + cache=None, + gather_idx=None, + static_kv=False): + """ + Multi-Head Attention. Note that attn_bias is added to the logit before + computing softmax activiation to mask certain selected positions so that + they will not considered in attention weights. + """ + keys = queries if keys is None else keys + values = keys if values is None else values + + if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3): + raise ValueError( + "Inputs: quries, keys and values should all be 3-D tensors.") + + def __compute_qkv(queries, keys, values, n_head, d_key, d_value): + """ + Add linear projection to queries, keys, and values. + """ + q = layers.fc( + input=queries, + size=d_key * n_head, + bias_attr=False, + num_flatten_dims=2) + # For encoder-decoder attention in inference, insert the ops and vars + # into global block to use as cache among beam search. + fc_layer = wrap_layer_with_block( + layers.fc, + fluid.default_main_program().current_block(). 
+ parent_idx) if cache is not None and static_kv else layers.fc + k = fc_layer( + input=keys, + size=d_key * n_head, + bias_attr=False, + num_flatten_dims=2) + v = fc_layer( + input=values, + size=d_value * n_head, + bias_attr=False, + num_flatten_dims=2) + return q, k, v + + def __split_heads_qkv(queries, keys, values, n_head, d_key, d_value): + """ + Reshape input tensors at the last dimension to split multi-heads + and then transpose. Specifically, transform the input tensor with shape + [bs, max_sequence_length, n_head * hidden_dim] to the output tensor + with shape [bs, n_head, max_sequence_length, hidden_dim]. + """ + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. + reshaped_q = layers.reshape( + x=queries, shape=[0, 0, n_head, d_key], inplace=True) + # permuate the dimensions into: + # [batch_size, n_head, max_sequence_len, hidden_size_per_head] + q = layers.transpose(x=reshaped_q, perm=[0, 2, 1, 3]) + # For encoder-decoder attention in inference, insert the ops and vars + # into global block to use as cache among beam search. + reshape_layer = wrap_layer_with_block( + layers.reshape, + fluid.default_main_program().current_block(). + parent_idx) if cache is not None and static_kv else layers.reshape + transpose_layer = wrap_layer_with_block( + layers.transpose, + fluid.default_main_program().current_block().parent_idx + ) if cache is not None and static_kv else layers.transpose + reshaped_k = reshape_layer( + x=keys, shape=[0, 0, n_head, d_key], inplace=True) + k = transpose_layer(x=reshaped_k, perm=[0, 2, 1, 3]) + reshaped_v = reshape_layer( + x=values, shape=[0, 0, n_head, d_value], inplace=True) + v = transpose_layer(x=reshaped_v, perm=[0, 2, 1, 3]) + + if cache is not None: # only for faster inference + if static_kv: # For encoder-decoder attention in inference + cache_k, cache_v = cache["static_k"], cache["static_v"] + # To init the static_k and static_v in cache. + # Maybe we can use condition_op(if_else) to do these at the first + # step in while loop to replace these, however it might be less + # efficient. + static_cache_init = wrap_layer_with_block( + layers.assign, + fluid.default_main_program().current_block().parent_idx) + static_cache_init(k, cache_k) + static_cache_init(v, cache_v) + else: # For decoder self-attention in inference + cache_k, cache_v = cache["k"], cache["v"] + # gather cell states corresponding to selected parent + select_k = layers.gather(cache_k, index=gather_idx) + select_v = layers.gather(cache_v, index=gather_idx) + if not static_kv: + # For self attention in inference, use cache and concat time steps. + select_k = layers.concat([select_k, k], axis=2) + select_v = layers.concat([select_v, v], axis=2) + # update cell states(caches) cached in global block + layers.assign(select_k, cache_k) + layers.assign(select_v, cache_v) + return q, select_k, select_v + return q, k, v + + def __combine_heads(x): + """ + Transpose and then reshape the last two dimensions of inpunt tensor x + so that it becomes one dimension, which is reverse to __split_heads. + """ + if len(x.shape) != 4: + raise ValueError("Input(x) should be a 4-D Tensor.") + + trans_x = layers.transpose(x, perm=[0, 2, 1, 3]) + # The value 0 in shape attr means copying the corresponding dimension + # size of the input as the output dimension size. 
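+        # e.g. with n_head=8 and d_key=d_value=64 this turns
+        # [batch_size, 8, max_sequence_len, 64] into [batch_size, max_sequence_len, 512]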
+ return layers.reshape( + x=trans_x, + shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]], + inplace=True) + + def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate): + """ + Scaled Dot-Product Attention + """ + product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_key**-0.5) + if attn_bias: + product += attn_bias + weights = layers.softmax(product) + if dropout_rate: + weights = layers.dropout( + weights, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False) + out = layers.matmul(weights, v) + return out + + q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value) + q, k, v = __split_heads_qkv(q, k, v, n_head, d_key, d_value) + + ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_model, + dropout_rate) + + out = __combine_heads(ctx_multiheads) + + # Project back to the model size. + proj_out = layers.fc( + input=out, size=d_model, bias_attr=False, num_flatten_dims=2) + return proj_out + + +def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate): + """ + Position-wise Feed-Forward Networks. + This module consists of two linear transformations with a ReLU activation + in between, which is applied to each position separately and identically. + """ + hidden = layers.fc( + input=x, size=d_inner_hid, num_flatten_dims=2, act="relu") + if dropout_rate: + hidden = layers.dropout( + hidden, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False) + out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2) + return out + + +def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.): + """ + Add residual connection, layer normalization and droput to the out tensor + optionally according to the value of process_cmd. + This will be used before or after multi-head attention and position-wise + feed-forward networks. + """ + for cmd in process_cmd: + if cmd == "a": # add residual connection + out = out + prev_out if prev_out else out + elif cmd == "n": # add layer normalization + out = layers.layer_norm( + out, + begin_norm_axis=len(out.shape) - 1, + param_attr=fluid.initializer.Constant(1.), + bias_attr=fluid.initializer.Constant(0.)) + elif cmd == "d": # add dropout + if dropout_rate: + out = layers.dropout( + out, + dropout_prob=dropout_rate, + seed=dropout_seed, + is_test=False) + return out + + +pre_process_layer = partial(pre_post_process_layer, None) +post_process_layer = pre_post_process_layer + + +def prepare_encoder(src_word, + src_pos, + src_vocab_size, + src_phone, + src_phone_mask, + phone_vocab_size, + src_emb_dim, + src_max_len, + beta=0.0, + dropout_rate=0., + bos_idx=0, + phone_pad_idx=-1, + word_emb_param_name=None): + """Add word embeddings and position encodings. + The output tensor has a shape of: + [batch_size, max_src_length_in_batch, d_model]. + This module is used at the bottom of the encoder stacks. 
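+
+    The phoneme embeddings of every source token are averaged over its valid
+    (masked) phoneme positions and mixed with the word embedding as
+        (1 - beta) * word_emb + beta * mean_phone_emb + pos_enc,
+    so beta = 0 recovers the plain Transformer input and a larger beta puts
+    more weight on pronunciation (the joint embedding used in the paper).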
+ """ + src_word_emb = layers.embedding( + src_word, + size=[src_vocab_size, src_emb_dim], + padding_idx=bos_idx, # set embedding of bos to 0 + param_attr=fluid.ParamAttr( + name=word_emb_param_name, + initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) + + # shape [batch_size, max_seq_len, max_phone_len, dim] + src_phone_emb = layers.embedding( + src_phone, + size=[phone_vocab_size, src_emb_dim], + padding_idx=phone_pad_idx, # set embedding of phone_pad_idx to 0 + param_attr=fluid.ParamAttr( + name=phone_emb_param_name, + initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) + sum_phone_emb = layers.reduce_sum(src_phone_emb, dim=2) + float_mask = layers.cast(src_phone_mask, dtype='float32') + sum_mask = layers.reduce_sum(float_mask, dim=2) + 1e-9 + mean_phone_emb = layers.elementwise_div(sum_phone_emb, sum_mask, axis=0) + + src_pos_enc = layers.embedding( + src_pos, + size=[src_max_len, src_emb_dim], + param_attr=fluid.ParamAttr( + name=pos_enc_param_names[0], trainable=False)) + src_pos_enc.stop_gradient = True + enc_input = ( + 1 - beta) * src_word_emb + beta * mean_phone_emb + src_pos_enc + return layers.dropout( + enc_input, dropout_prob=dropout_rate, seed=dropout_seed, + is_test=False) if dropout_rate else enc_input + + +def prepare_decoder(src_word, + src_pos, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0., + bos_idx=0, + word_emb_param_name=None): + """Add word embeddings and position encodings. + The output tensor has a shape of: + [batch_size, max_src_length_in_batch, d_model]. + This module is used at the bottom of the encoder stacks. + """ + src_word_emb = layers.embedding( + src_word, + size=[src_vocab_size, src_emb_dim], + padding_idx=bos_idx, # set embedding of bos to 0 + param_attr=fluid.ParamAttr( + name=word_emb_param_name, + initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5))) + + src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5) + src_pos_enc = layers.embedding( + src_pos, + size=[src_max_len, src_emb_dim], + param_attr=fluid.ParamAttr( + name=pos_enc_param_names[1], trainable=False)) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + return layers.dropout( + enc_input, dropout_prob=dropout_rate, seed=dropout_seed, + is_test=False) if dropout_rate else enc_input + + +def encoder_layer(enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + """The encoder layers that can be stacked to form a deep encoder. + This module consits of a multi-head (self) attention followed by + position-wise feed-forward networks and both the two components companied + with the post_process_layer to add residual connection, layer normalization + and droput. 
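+
+    With preprocess_cmd="n" and postprocess_cmd="da" (the defaults in
+    config.py), each sub-layer computes
+        y = x + dropout(sublayer(layer_norm(x))),
+    i.e. the pre-norm formulation of the Transformer block.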
+ """ + attn_output = multi_head_attention( + pre_process_layer(enc_input, preprocess_cmd, + prepostprocess_dropout), None, None, attn_bias, + d_key, d_value, d_model, n_head, attention_dropout) + attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd, + prepostprocess_dropout) + ffd_output = positionwise_feed_forward( + pre_process_layer(attn_output, preprocess_cmd, prepostprocess_dropout), + d_inner_hid, d_model, relu_dropout) + return post_process_layer(attn_output, ffd_output, postprocess_cmd, + prepostprocess_dropout) + + +def encoder(enc_input, + attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + """ + The encoder is composed of a stack of identical layers returned by calling + encoder_layer. + """ + for i in range(n_layer): + enc_output = encoder_layer( + enc_input, + attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + ) + enc_input = enc_output + enc_output = pre_process_layer(enc_output, preprocess_cmd, + prepostprocess_dropout) + return enc_output + + +def decoder_layer(dec_input, + enc_output, + slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + cache=None, + gather_idx=None): + """ The layer to be stacked in decoder part. + The structure of this module is similar to that in the encoder part except + a multi-head attention is added to implement encoder-decoder attention. + """ + slf_attn_output = multi_head_attention( + pre_process_layer(dec_input, preprocess_cmd, prepostprocess_dropout), + None, + None, + slf_attn_bias, + d_key, + d_value, + d_model, + n_head, + attention_dropout, + cache=cache, + gather_idx=gather_idx) + slf_attn_output = post_process_layer( + dec_input, + slf_attn_output, + postprocess_cmd, + prepostprocess_dropout, + ) + enc_attn_output = multi_head_attention( + pre_process_layer(slf_attn_output, preprocess_cmd, + prepostprocess_dropout), + enc_output, + enc_output, + dec_enc_attn_bias, + d_key, + d_value, + d_model, + n_head, + attention_dropout, + cache=cache, + gather_idx=gather_idx, + static_kv=True) + enc_attn_output = post_process_layer( + slf_attn_output, + enc_attn_output, + postprocess_cmd, + prepostprocess_dropout, + ) + ffd_output = positionwise_feed_forward( + pre_process_layer(enc_attn_output, preprocess_cmd, + prepostprocess_dropout), + d_inner_hid, + d_model, + relu_dropout, + ) + dec_output = post_process_layer( + enc_attn_output, + ffd_output, + postprocess_cmd, + prepostprocess_dropout, + ) + return dec_output + + +def decoder(dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + caches=None, + gather_idx=None): + """ + The decoder is composed of a stack of identical decoder_layer layers. 
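+
+    When caches is given (fast decoding), it holds one dict per layer with the
+    accumulated self-attention keys/values and the projected encoder output,
+    and gather_idx selects the beam-search parent states to be reused.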
+ """ + for i in range(n_layer): + dec_output = decoder_layer( + dec_input, + enc_output, + dec_slf_attn_bias, + dec_enc_attn_bias, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + cache=None if caches is None else caches[i], + gather_idx=gather_idx) + dec_input = dec_output + dec_output = pre_process_layer(dec_output, preprocess_cmd, + prepostprocess_dropout) + return dec_output + + +def make_all_inputs(input_fields): + """ + Define the input data layers for the transformer model. + """ + inputs = [] + for input_field in input_fields: + input_var = layers.data( + name=input_field, + shape=input_descs[input_field][0], + dtype=input_descs[input_field][1], + lod_level=input_descs[input_field][2] + if len(input_descs[input_field]) == 3 else 0, + append_batch_size=False) + inputs.append(input_var) + + return inputs + + +def make_all_py_reader_inputs(input_fields, is_test=False): + reader = layers.py_reader( + capacity=20, + name="test_reader" if is_test else "train_reader", + shapes=[input_descs[input_field][0] for input_field in input_fields], + dtypes=[input_descs[input_field][1] for input_field in input_fields], + lod_levels=[ + input_descs[input_field][2] + if len(input_descs[input_field]) == 3 else 0 + for input_field in input_fields + ]) + return layers.read_file(reader), reader + + +def transformer(src_vocab_size, + trg_vocab_size, + phone_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + label_smooth_eps, + beta, + bos_idx=0, + use_py_reader=False, + is_test=False): + if weight_sharing: + assert src_vocab_size == trg_vocab_size, ( + "Vocabularies in source and target should be same for weight sharing." + ) + + data_input_names = encoder_data_input_fields + \ + decoder_data_input_fields[:-1] + label_data_input_fields + + if use_py_reader: + all_inputs, reader = make_all_py_reader_inputs(data_input_names, + is_test) + else: + all_inputs = make_all_inputs(data_input_names) + + enc_inputs_len = len(encoder_data_input_fields) + dec_inputs_len = len(decoder_data_input_fields[:-1]) + enc_inputs = all_inputs[0:enc_inputs_len] + dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len] + label = all_inputs[-2] + weights = all_inputs[-1] + + enc_output = wrap_encoder( + src_vocab_size, + phone_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + beta, + enc_inputs, + ) + + predict = wrap_decoder( + trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + dec_inputs, + enc_output, + ) + + # Padding index do not contribute to the total loss. The weights is used to + # cancel padding index in calculating the loss. 
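+    # With label_smooth_eps > 0 the one-hot targets are smoothed towards the
+    # uniform distribution, roughly (1 - eps) * one_hot + eps / trg_vocab_size,
+    # before the soft-label cross entropy below.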
+ if label_smooth_eps: + label = layers.label_smooth( + label=layers.one_hot(input=label, depth=trg_vocab_size), + epsilon=label_smooth_eps) + + cost = layers.softmax_with_cross_entropy( + logits=predict, + label=label, + soft_label=True if label_smooth_eps else False) + weighted_cost = cost * weights + sum_cost = layers.reduce_sum(weighted_cost) + token_num = layers.reduce_sum(weights) + token_num.stop_gradient = True + avg_cost = sum_cost / token_num + return sum_cost, avg_cost, predict, token_num, reader if use_py_reader else None + + +def wrap_encoder(src_vocab_size, + phone_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + beta, + enc_inputs=None, + bos_idx=0): + """ + The wrapper assembles together all needed layers for the encoder. + """ + if enc_inputs is None: + # This is used to implement independent encoder program in inference. + enc_inputs = make_all_inputs(encoder_data_input_fields) + + src_word = enc_inputs[0] + src_pos = enc_inputs[1] + src_slf_attn_bias = enc_inputs[2] + src_phone = enc_inputs[3] + src_phone_mask = enc_inputs[4] + + enc_input = prepare_encoder( + src_word, + src_pos, + src_vocab_size, + src_phone, + src_phone_mask, + phone_vocab_size, + d_model, + max_length, + beta, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name=word_emb_param_names[0]) + enc_output = encoder( + enc_input, + src_slf_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + ) + return enc_output + + +def wrap_decoder(trg_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + dec_inputs=None, + enc_output=None, + caches=None, + gather_idx=None, + bos_idx=0): + """ + The wrapper assembles together all needed layers for the decoder. + """ + if dec_inputs is None: + # This is used to implement independent decoder program in inference. + dec_inputs = make_all_inputs(decoder_data_input_fields) + trg_word = dec_inputs[0] + trg_pos = dec_inputs[1] + trg_slf_attn_bias = dec_inputs[2] + trg_src_attn_bias = dec_inputs[3] + + dec_input = prepare_decoder( + trg_word, + trg_pos, + trg_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name=word_emb_param_names[0] + if weight_sharing else word_emb_param_names[1]) + dec_output = decoder( + dec_input, + enc_output, + trg_slf_attn_bias, + trg_src_attn_bias, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + caches=caches, + gather_idx=gather_idx) + # Reshape to 2D tensor to use GEMM instead of BatchedGEMM + dec_output = layers.reshape( + dec_output, shape=[-1, dec_output.shape[-1]], inplace=True) + if weight_sharing: + predict = layers.matmul( + x=dec_output, + y=fluid.default_main_program().global_block().var( + word_emb_param_names[0]), + transpose_y=True) + else: + predict = layers.fc( + input=dec_output, size=trg_vocab_size, bias_attr=False) + if dec_inputs is None: + # Return probs for independent decoder program. 
+ predict = layers.softmax(predict) + return predict + + +def fast_decode(src_vocab_size, + trg_vocab_size, + phone_vocab_size, + max_in_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + beam_size, + max_out_len, + bos_idx, + eos_idx, + beta=0.0, + use_py_reader=False): + """ + Use beam search to decode. Caches will be used to store states of history + steps which can make the decoding faster. + """ + data_input_names = encoder_data_input_fields + fast_decoder_data_input_fields + + if use_py_reader: + all_inputs, reader = make_all_py_reader_inputs(data_input_names) + else: + all_inputs = make_all_inputs(data_input_names) + + enc_inputs_len = len(encoder_data_input_fields) + dec_inputs_len = len(fast_decoder_data_input_fields) + enc_inputs = all_inputs[0:enc_inputs_len] + dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len] + + enc_output = wrap_encoder( + src_vocab_size, + phone_vocab_size, + max_in_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + beta, + enc_inputs, + bos_idx=bos_idx) + + start_tokens, init_scores, parent_idx, trg_src_attn_bias = dec_inputs + + def beam_search(): + max_len = layers.fill_constant( + shape=[1], + dtype=start_tokens.dtype, + value=max_out_len, + force_cpu=True) + step_idx = layers.fill_constant( + shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True) + cond = layers.less_than( + x=step_idx, y=max_len) # default force_cpu=True + while_op = layers.While(cond) + # array states will be stored for each step. + ids = layers.array_write( + layers.reshape(start_tokens, (-1, 1)), step_idx) + scores = layers.array_write(init_scores, step_idx) + # cell states will be overwrited at each step. + # caches contains states of history steps in decoder self-attention + # and static encoder output projections in encoder-decoder attention + # to reduce redundant computation. + caches = [ + { + "k": # for self attention + layers.fill_constant_batch_size_like( + input=start_tokens, + shape=[-1, n_head, 0, d_key], + dtype=enc_output.dtype, + value=0), + "v": # for self attention + layers.fill_constant_batch_size_like( + input=start_tokens, + shape=[-1, n_head, 0, d_value], + dtype=enc_output.dtype, + value=0), + "static_k": # for encoder-decoder attention + layers.create_tensor(dtype=enc_output.dtype), + "static_v": # for encoder-decoder attention + layers.create_tensor(dtype=enc_output.dtype) + } for i in range(n_layer) + ] + + with while_op.block(): + pre_ids = layers.array_read(array=ids, i=step_idx) + # Since beam_search_op dosen't enforce pre_ids' shape, we can do + # inplace reshape here which actually change the shape of pre_ids. 
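+            # Each step: run one decoder step with the cached states, take the
+            # per-beam top-k, add the log-probabilities to the accumulated
+            # scores, and let beam_search prune back to beam_size hypotheses,
+            # writing the survivors into the ids/scores arrays for the next step.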
+ pre_ids = layers.reshape(pre_ids, (-1, 1, 1), inplace=True) + pre_scores = layers.array_read(array=scores, i=step_idx) + # gather cell states corresponding to selected parent + pre_src_attn_bias = layers.gather( + trg_src_attn_bias, index=parent_idx) + pre_pos = layers.elementwise_mul( + x=layers.fill_constant_batch_size_like( + input=pre_src_attn_bias, # cann't use lod tensor here + value=1, + shape=[-1, 1, 1], + dtype=pre_ids.dtype), + y=step_idx, + axis=0) + logits = wrap_decoder( + trg_vocab_size, + max_in_len, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + dec_inputs=(pre_ids, pre_pos, None, pre_src_attn_bias), + enc_output=enc_output, + caches=caches, + gather_idx=parent_idx, + bos_idx=bos_idx) + # intra-beam topK + topk_scores, topk_indices = layers.topk( + input=layers.softmax(logits), k=beam_size) + accu_scores = layers.elementwise_add( + x=layers.log(topk_scores), y=pre_scores, axis=0) + # beam_search op uses lod to differentiate branches. + accu_scores = layers.lod_reset(accu_scores, pre_ids) + # topK reduction across beams, also contain special handle of + # end beams and end sentences(batch reduction) + selected_ids, selected_scores, gather_idx = layers.beam_search( + pre_ids=pre_ids, + pre_scores=pre_scores, + ids=topk_indices, + scores=accu_scores, + beam_size=beam_size, + end_id=eos_idx, + return_parent_idx=True) + layers.increment(x=step_idx, value=1.0, in_place=True) + # cell states(caches) have been updated in wrap_decoder, + # only need to update beam search states here. + layers.array_write(selected_ids, i=step_idx, array=ids) + layers.array_write(selected_scores, i=step_idx, array=scores) + layers.assign(gather_idx, parent_idx) + layers.assign(pre_src_attn_bias, trg_src_attn_bias) + length_cond = layers.less_than(x=step_idx, y=max_len) + finish_cond = layers.logical_not(layers.is_empty(x=selected_ids)) + layers.logical_and(x=length_cond, y=finish_cond, out=cond) + + finished_ids, finished_scores = layers.beam_search_decode( + ids, scores, beam_size=beam_size, end_id=eos_idx) + return finished_ids, finished_scores + + finished_ids, finished_scores = beam_search() + return finished_ids, finished_scores, reader if use_py_reader else None diff --git a/PaddleNLP/Research/ACL2019-JEMT/reader.py b/PaddleNLP/Research/ACL2019-JEMT/reader.py new file mode 100644 index 00000000..26a486c8 --- /dev/null +++ b/PaddleNLP/Research/ACL2019-JEMT/reader.py @@ -0,0 +1,385 @@ +import glob +import six +import os +import random +import tarfile + +import numpy as np + + +class SortType(object): + GLOBAL = 'global' + POOL = 'pool' + NONE = "none" + + +class SrcConverter(object): + def __init__(self, vocab, end, unk, delimiter, lexicon): + self._vocab = vocab + self._end = end + self._unk = unk + self._delimiter = delimiter + self._lexicon = lexicon + + def __call__(self, sentence): + src_seqs = [] + src_ph_seqs = [] + unk_phs = self._lexicon[''] + for w in sentence.split(self._delimiter): + src_seqs.append(self._vocab.get(w, self._unk)) + ph_groups = self._lexicon.get(w, unk_phs) + src_ph_seqs.append(random.choice(ph_groups)) + src_seqs.append(self._end) + src_ph_seqs.append(unk_phs[0]) + + return src_seqs, src_ph_seqs + + +class TgtConverter(object): + def __init__(self, vocab, beg, end, unk, delimiter): + self._vocab = vocab + self._beg = beg + self._end = end + self._unk = unk + self._delimiter = delimiter + + def __call__(self, sentence): + return 
[self._beg] + [ + self._vocab.get(w, self._unk) + for w in sentence.split(self._delimiter) + ] + [self._end] + + +class ComposedConverter(object): + def __init__(self, converters): + self._converters = converters + + def __call__(self, parallel_sentence): + return [ + self._converters[i](parallel_sentence[i]) + for i in range(len(self._converters)) + ] + + +class SentenceBatchCreator(object): + def __init__(self, batch_size): + self.batch = [] + self._batch_size = batch_size + + def append(self, info): + self.batch.append(info) + if len(self.batch) == self._batch_size: + tmp = self.batch + self.batch = [] + return tmp + + +class TokenBatchCreator(object): + def __init__(self, batch_size): + self.batch = [] + self.max_len = -1 + self._batch_size = batch_size + + def append(self, info): + cur_len = info.max_len + max_len = max(self.max_len, cur_len) + if max_len * (len(self.batch) + 1) > self._batch_size: + result = self.batch + self.batch = [info] + self.max_len = cur_len + return result + else: + self.max_len = max_len + self.batch.append(info) + + +class SampleInfo(object): + def __init__(self, i, max_len, min_len): + self.i = i + self.min_len = min_len + self.max_len = max_len + + +class MinMaxFilter(object): + def __init__(self, max_len, min_len, underlying_creator): + self._min_len = min_len + self._max_len = max_len + self._creator = underlying_creator + + def append(self, info): + if info.max_len > self._max_len or info.min_len < self._min_len: + return + else: + return self._creator.append(info) + + @property + def batch(self): + return self._creator.batch + + +class DataReader(object): + """ + The data reader loads all data from files and produces batches of data + in the way corresponding to settings. + + An example of returning a generator producing data batches whose data + is shuffled in each pass and sorted in each pool: + + ``` + train_data = DataReader( + src_vocab_fpath='data/src_vocab_file', + trg_vocab_fpath='data/trg_vocab_file', + fpattern='data/part-*', + use_token_batch=True, + batch_size=2000, + pool_size=10000, + sort_type=SortType.POOL, + shuffle=True, + shuffle_batch=True, + start_mark='', + end_mark='', + unk_mark='', + clip_last_batch=False).batch_generator + ``` + + :param src_vocab_fpath: The path of vocabulary file of source language. + :type src_vocab_fpath: basestring + :param trg_vocab_fpath: The path of vocabulary file of target language. + :type trg_vocab_fpath: basestring + :param fpattern: The pattern to match data files. + :type fpattern: basestring + :param batch_size: The number of sequences contained in a mini-batch. + or the maximum number of tokens (include paddings) contained in a + mini-batch. + :type batch_size: int + :param pool_size: The size of pool buffer. + :type pool_size: int + :param sort_type: The grain to sort by length: 'global' for all + instances; 'pool' for instances in pool; 'none' for no sort. + :type sort_type: basestring + :param clip_last_batch: Whether to clip the last uncompleted batch. + :type clip_last_batch: bool + :param tar_fname: The data file in tar if fpattern matches a tar file. + :type tar_fname: basestring + :param min_length: The minimum length used to filt sequences. + :type min_length: int + :param max_length: The maximum length used to filt sequences. + :type max_length: int + :param shuffle: Whether to shuffle all instances. + :type shuffle: bool + :param shuffle_batch: Whether to shuffle the generated batches. 
+ :type shuffle_batch: bool + :param use_token_batch: Whether to produce batch data according to + token number. + :type use_token_batch: bool + :param field_delimiter: The delimiter used to split source and target in + each line of data file. + :type field_delimiter: basestring + :param token_delimiter: The delimiter used to split tokens in source or + target sentences. + :type token_delimiter: basestring + :param start_mark: The token representing for the beginning of + sentences in dictionary. + :type start_mark: basestring + :param end_mark: The token representing for the end of sentences + in dictionary. + :type end_mark: basestring + :param unk_mark: The token representing for unknown word in dictionary. + :type unk_mark: basestring + :param seed: The seed for random. + :type seed: int + """ + + def __init__(self, + src_vocab_fpath, + trg_vocab_fpath, + fpattern, + phoneme_vocab_fpath, + lexicon_fpath, + batch_size, + pool_size, + sort_type=SortType.GLOBAL, + clip_last_batch=True, + tar_fname=None, + min_length=0, + max_length=100, + shuffle=True, + shuffle_batch=False, + use_token_batch=False, + field_delimiter="\t", + token_delimiter=" ", + start_mark="", + end_mark="", + unk_mark="", + seed=0): + self._src_vocab = self.load_dict(src_vocab_fpath) + self._only_src = True + if trg_vocab_fpath is not None: + self._trg_vocab = self.load_dict(trg_vocab_fpath) + self._only_src = False + + self._phoneme_vocab = self.load_dict(phoneme_vocab_fpath) + self._lexicon = self.load_lexicon(lexicon_fpath, self._phoneme_vocab) + + self._pool_size = pool_size + self._batch_size = batch_size + self._use_token_batch = use_token_batch + self._sort_type = sort_type + self._clip_last_batch = clip_last_batch + self._shuffle = shuffle + self._shuffle_batch = shuffle_batch + self._min_length = min_length + self._max_length = max_length + self._field_delimiter = field_delimiter + self._token_delimiter = token_delimiter + self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname, + unk_mark) + self._random = np.random + self._random.seed(seed) + + def load_lexicon(self, lexicon_path, phoneme_vocab): + lexicon = {} + with open(lexicon_path) as fp: + for line in fp: + tokens = line.strip().split() + word = tokens[0] + all_phone_str = ' '.join(tokens[1:]) + phone_strs = all_phone_str.split('|') + phone_groups = [] + for phone_str in phone_strs: + cur_phone_seq = [ + phoneme_vocab[x] for x in phone_str.split() + ] + phone_groups.append(cur_phone_seq) + lexicon[word] = phone_groups + + lexicon[''] = [[phoneme_vocab['']]] + return lexicon + + def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname, + unk_mark): + converters = [ + SrcConverter( + vocab=self._src_vocab, + end=self._src_vocab[end_mark], + unk=self._src_vocab[unk_mark], + delimiter=self._token_delimiter, + lexicon=self._lexicon) + ] + if not self._only_src: + converters.append( + TgtConverter( + vocab=self._trg_vocab, + beg=self._trg_vocab[start_mark], + end=self._trg_vocab[end_mark], + unk=self._trg_vocab[unk_mark], + delimiter=self._token_delimiter)) + + converters = ComposedConverter(converters) + + self._src_seq_ids = [] + self._src_phone_ids = [] + self._trg_seq_ids = None if self._only_src else [] + self._sample_infos = [] + + for i, line in enumerate(self._load_lines(fpattern, tar_fname)): + src_trg_ids = converters(line) + self._src_seq_ids.append(src_trg_ids[0][0]) + self._src_phone_ids.append(src_trg_ids[0][1]) + lens = [len(src_trg_ids[0][0])] + if not self._only_src: + self._trg_seq_ids.append(src_trg_ids[1]) + 
lens.append(len(src_trg_ids[1])) + self._sample_infos.append(SampleInfo(i, max(lens), min(lens))) + + def _load_lines(self, fpattern, tar_fname): + fpaths = glob.glob(fpattern) + + if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]): + if tar_fname is None: + raise Exception("If tar file provided, please set tar_fname.") + + f = tarfile.open(fpaths[0], "r") + for line in f.extractfile(tar_fname): + fields = line.strip("\n").split(self._field_delimiter) + if (not self._only_src + and len(fields) == 2) or (self._only_src + and len(fields) == 1): + yield fields + else: + for fpath in fpaths: + if not os.path.isfile(fpath): + raise IOError("Invalid file: %s" % fpath) + + with open(fpath, "rb") as f: + for line in f: + if six.PY3: + line = line.decode() + fields = line.strip("\n").split(self._field_delimiter) + if (not self._only_src and len(fields) == 2) or ( + self._only_src and len(fields) == 1): + yield fields + + @staticmethod + def load_dict(dict_path, reverse=False): + word_dict = {} + with open(dict_path, "rb") as fdict: + for idx, line in enumerate(fdict): + if six.PY3: + line = line.decode() + if reverse: + word_dict[idx] = line.strip("\n") + else: + word_dict[line.strip("\n")] = idx + return word_dict + + def batch_generator(self): + # global sort or global shuffle + if self._sort_type == SortType.GLOBAL: + infos = sorted(self._sample_infos, key=lambda x: x.max_len) + else: + if self._shuffle: + infos = self._sample_infos + self._random.shuffle(infos) + else: + infos = self._sample_infos + + if self._sort_type == SortType.POOL: + reverse = True + for i in range(0, len(infos), self._pool_size): + # to avoid placing short next to long sentences + reverse = not reverse + infos[i:i + self._pool_size] = sorted( + infos[i:i + self._pool_size], + key=lambda x: x.max_len, + reverse=reverse) + + # concat batch + batches = [] + batch_creator = TokenBatchCreator( + self._batch_size + ) if self._use_token_batch else SentenceBatchCreator(self._batch_size) + batch_creator = MinMaxFilter(self._max_length, self._min_length, + batch_creator) + + for info in infos: + batch = batch_creator.append(info) + if batch is not None: + batches.append(batch) + + if not self._clip_last_batch and len(batch_creator.batch) != 0: + batches.append(batch_creator.batch) + + if self._shuffle_batch: + self._random.shuffle(batches) + + for batch in batches: + batch_ids = [info.i for info in batch] + + if self._only_src: + yield [[(self._src_seq_ids[idx], self._src_phone_ids[idx])] + for idx in batch_ids] + else: + yield [(self._src_seq_ids[idx], self._src_phone_ids[idx], + self._trg_seq_ids[idx][:-1], + self._trg_seq_ids[idx][1:]) for idx in batch_ids] diff --git a/PaddleNLP/Research/ACL2019-JEMT/train.py b/PaddleNLP/Research/ACL2019-JEMT/train.py new file mode 100644 index 00000000..cfef76e4 --- /dev/null +++ b/PaddleNLP/Research/ACL2019-JEMT/train.py @@ -0,0 +1,826 @@ +import argparse +import ast +import copy +import logging +import multiprocessing +import os +import six +import sys +import time + +import numpy as np +import paddle.fluid as fluid + +import reader +from config import * +from desc import * +from model import transformer, position_encoding_init + + +def parse_args(): + parser = argparse.ArgumentParser("Training for Transformer.") + parser.add_argument( + "--src_vocab_fpath", + type=str, + required=True, + help="The path of vocabulary file of source language.") + parser.add_argument( + "--trg_vocab_fpath", + type=str, + required=True, + help="The path of vocabulary file of target language.") + 
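+ # The next two arguments are specific to the joint phonetic embedding: a phoneme vocabulary and a
+ # word-to-pronunciation lexicon, which reader.DataReader uses to attach phoneme ids to every source token.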
parser.add_argument( + "--phoneme_vocab_fpath", + type=str, + required=True, + help="The path of the phoneme vocabulary file.") + parser.add_argument( + "--lexicon_fpath", + type=str, + required=True, + help="The path of the pronunciation lexicon of the source language.") + parser.add_argument( + "--train_file_pattern", + type=str, + required=True, + help="The pattern to match training data files.") + parser.add_argument( + "--val_file_pattern", + type=str, + help="The pattern to match validation data files.") + parser.add_argument( + "--use_token_batch", + type=ast.literal_eval, + default=True, + help="The flag indicating whether to " + "produce batch data according to token number.") + parser.add_argument( + "--batch_size", + type=int, + default=4096, + help="The number of sequences contained in a mini-batch, or the maximum " + "number of tokens (including paddings) contained in a mini-batch. Note " + "that this is the number on a single device; the effective batch size " + "for multiple devices is this value multiplied by the device count.") + parser.add_argument( + "--pool_size", + type=int, + default=200000, + help="The buffer size to pool data.") + parser.add_argument( + "--sort_type", + default="pool", + choices=("global", "pool", "none"), + help="The grain to sort by length: global for all instances; pool for " + "instances in pool; none for no sort.") + parser.add_argument( + "--shuffle", + type=ast.literal_eval, + default=True, + help="The flag indicating whether to shuffle instances in each pass.") + parser.add_argument( + "--shuffle_batch", + type=ast.literal_eval, + default=True, + help="The flag indicating whether to shuffle the data batches.") + parser.add_argument( + "--special_token", + type=str, + default=["<s>", "<e>", "<unk>"], + nargs=3, + help="The <bos>, <eos> and <unk> tokens in the dictionary.") + parser.add_argument( + "--token_delimiter", + type=lambda x: str(x.encode().decode("unicode-escape")), + default=" ", + help="The delimiter used to split tokens in source or target sentences. " + "For EN-DE BPE data we provided, use spaces as token delimiter. 
") + parser.add_argument( + 'opts', + help='See config.py for all options', + default=None, + nargs=argparse.REMAINDER) + parser.add_argument( + '--local', + type=ast.literal_eval, + default=True, + help='Whether to run as local mode.') + parser.add_argument( + '--device', + type=str, + default='GPU', + choices=['CPU', 'GPU'], + help="The device type.") + parser.add_argument( + '--update_method', + choices=("pserver", "nccl2"), + default="pserver", + help='Update method.') + parser.add_argument( + '--sync', type=ast.literal_eval, default=True, help="sync mode.") + parser.add_argument( + "--enable_ce", + type=ast.literal_eval, + default=False, + help="The flag indicating whether to run the task " + "for continuous evaluation.") + parser.add_argument( + "--use_mem_opt", + type=ast.literal_eval, + default=True, + help="The flag indicating whether to use memory optimization.") + parser.add_argument( + "--use_py_reader", + type=ast.literal_eval, + default=True, + help="The flag indicating whether to use py_reader.") + parser.add_argument( + "--fetch_steps", + type=int, + default=100, + help="The frequency to fetch and print output.") + + args = parser.parse_args() + # Append args related to dict + src_dict = reader.DataReader.load_dict(args.src_vocab_fpath) + trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath) + phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath) + dict_args = [ + "src_vocab_size", + str(len(src_dict)), "trg_vocab_size", + str(len(trg_dict)), "phone_vocab_size", + str(len(phone_dict)), "bos_idx", + str(src_dict[args.special_token[0]]), "eos_idx", + str(src_dict[args.special_token[1]]), "unk_idx", + str(src_dict[args.special_token[2]]) + ] + merge_cfg_from_list(args.opts + dict_args, + [TrainTaskConfig, ModelHyperParams]) + + return args + + +def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints, + current_endpoint): + assert (trainer_id >= 0 and len(worker_endpoints) > 1 + and current_endpoint in worker_endpoints) + eps = copy.deepcopy(worker_endpoints) + eps.remove(current_endpoint) + nccl_id_var = startup_prog.global_block().create_var( + name="NCCLID", persistable=True, type=fluid.core.VarDesc.VarType.RAW) + startup_prog.global_block().append_op( + type="gen_nccl_id", + inputs={}, + outputs={"NCCLID": nccl_id_var}, + attrs={ + "endpoint": current_endpoint, + "endpoint_list": eps, + "trainer_id": trainer_id + }) + return nccl_id_var + + +def pad_phoneme_data(phoneme_seqs, pad_idx, max_seq_len): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. 
+ """ + ph_seq_lens = [] + for ps in phoneme_seqs: + cur_seq_lens = [len(x) for x in ps] + ph_seq_lens.append(max(cur_seq_lens)) + max_ph_seq_len = max(ph_seq_lens) + + batch_size = len(phoneme_seqs) + phoneme_data = pad_idx * np.ones( + (batch_size, max_seq_len, max_ph_seq_len), dtype=np.int64) + phoneme_mask = np.zeros((batch_size, max_seq_len, max_ph_seq_len), + dtype=np.int64) + + for i in range(batch_size): + cur_ph_seq = phoneme_seqs[i] + for j, cur_word_phs in enumerate(cur_ph_seq): + word_phs_len = len(cur_word_phs) + phoneme_data[i, j, :word_phs_len] = cur_word_phs + phoneme_mask[i, j, :word_phs_len] = 1 + + phoneme_data = np.reshape(phoneme_data, [batch_size, max_seq_len, -1, 1]) + + return phoneme_data, phoneme_mask, max_ph_seq_len + + +def pad_batch_data(insts, + pad_idx, + n_head, + is_target=False, + is_label=False, + return_attn_bias=True, + return_max_len=True, + return_num_token=False): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. + """ + return_list = [] + max_len = max(len(inst) for inst in insts) + # Any token included in dict can be used to pad, since the paddings' loss + # will be masked out by weights and make no effect on parameter gradients. + inst_data = np.array( + [inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) + return_list += [inst_data.astype("int64").reshape([-1, 1])] + if is_label: # label weight + inst_weight = np.array([[1.] * len(inst) + [0.] * (max_len - len(inst)) + for inst in insts]) + return_list += [inst_weight.astype("float32").reshape([-1, 1])] + else: # position data + inst_pos = np.array([ + list(range(0, len(inst))) + [0] * (max_len - len(inst)) + for inst in insts + ]) + return_list += [inst_pos.astype("int64").reshape([-1, 1])] + if return_attn_bias: + if is_target: + # This is used to avoid attention on paddings and subsequent + # words. + slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, + max_len)) + slf_attn_bias_data = np.triu(slf_attn_bias_data, + 1).reshape([-1, 1, max_len, max_len]) + slf_attn_bias_data = np.tile(slf_attn_bias_data, + [1, n_head, 1, 1]) * [-1e9] + else: + # This is used to avoid attention on paddings. + slf_attn_bias_data = np.array( + [[0] * len(inst) + [-1e9] * (max_len - len(inst)) + for inst in insts]) + slf_attn_bias_data = np.tile( + slf_attn_bias_data.reshape([-1, 1, 1, max_len]), + [1, n_head, max_len, 1]) + return_list += [slf_attn_bias_data.astype("float32")] + if return_max_len: + return_list += [max_len] + if return_num_token: + num_token = 0 + for inst in insts: + num_token += len(inst) + return_list += [num_token] + return return_list if len(return_list) > 1 else return_list[0] + + +def prepare_batch_input(insts, data_input_names, src_pad_idx, phone_pad_idx, + trg_pad_idx, n_head, d_model): + """ + Put all padded data needed by training into a dict. 
+ """ + src_word, src_pos, src_slf_attn_bias, src_max_len = pad_batch_data( + [inst[0] for inst in insts], src_pad_idx, n_head, is_target=False) + src_word = src_word.reshape(-1, src_max_len, 1) + src_pos = src_pos.reshape(-1, src_max_len, 1) + src_phone, src_phone_mask, max_phone_len = pad_phoneme_data( + [inst[1] for inst in insts], phone_pad_idx, src_max_len) + trg_word, trg_pos, trg_slf_attn_bias, trg_max_len = pad_batch_data( + [inst[2] for inst in insts], trg_pad_idx, n_head, is_target=True) + trg_word = trg_word.reshape(-1, trg_max_len, 1) + trg_pos = trg_pos.reshape(-1, trg_max_len, 1) + + trg_src_attn_bias = np.tile(src_slf_attn_bias[:, :, ::src_max_len, :], + [1, 1, trg_max_len, 1]).astype("float32") + + lbl_word, lbl_weight, num_token = pad_batch_data( + [inst[3] for inst in insts], + trg_pad_idx, + n_head, + is_target=False, + is_label=True, + return_attn_bias=False, + return_max_len=False, + return_num_token=True) + + data_input_dict = dict( + zip(data_input_names, [ + src_word, src_pos, src_slf_attn_bias, src_phone, src_phone_mask, + trg_word, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, lbl_word, + lbl_weight + ])) + + return data_input_dict, np.asarray([num_token], dtype="float32") + + +def prepare_data_generator(args, + is_test, + count, + pyreader, + py_reader_provider_wrapper, + place=None): + """ + Data generator wrapper for DataReader. If use py_reader, set the data + provider for py_reader + """ + data_reader = reader.DataReader( + phoneme_vocab_fpath=args.phoneme_vocab_fpath, + lexicon_fpath=args.lexicon_fpath, + fpattern=args.val_file_pattern if is_test else args.train_file_pattern, + src_vocab_fpath=args.src_vocab_fpath, + trg_vocab_fpath=args.trg_vocab_fpath, + token_delimiter=args.token_delimiter, + use_token_batch=args.use_token_batch, + batch_size=args.batch_size * (1 if args.use_token_batch else count), + pool_size=args.pool_size, + sort_type=args.sort_type, + shuffle=args.shuffle, + shuffle_batch=args.shuffle_batch, + start_mark=args.special_token[0], + end_mark=args.special_token[1], + unk_mark=args.special_token[2], + # count start and end tokens out + max_length=ModelHyperParams.max_length - 2, + clip_last_batch=False).batch_generator + + def stack(data_reader, count, clip_last=True): + def __impl__(): + res = [] + for item in data_reader(): + res.append(item) + if len(res) == count: + yield res + res = [] + if len(res) == count: + yield res + elif not clip_last: + data = [] + for item in res: + data += item + if len(data) > count: + inst_num_per_part = len(data) // count + yield [ + data[inst_num_per_part * i:inst_num_per_part * (i + 1)] + for i in range(count) + ] + + return __impl__ + + def split(data_reader, count): + def __impl__(): + for item in data_reader(): + inst_num_per_part = len(item) // count + for i in range(count): + yield item[inst_num_per_part * i:inst_num_per_part * + (i + 1)] + + return __impl__ + + if not args.use_token_batch: + # to make data on each device have similar token number + data_reader = split(data_reader, count) + if args.use_py_reader: + pyreader.decorate_tensor_provider( + py_reader_provider_wrapper(data_reader, place)) + data_reader = None + else: # Data generator for multi-devices + data_reader = stack(data_reader, count) + return data_reader + + +def prepare_feed_dict_list(data_generator, init_flag, count): + """ + Prepare the list of feed dict for multi-devices. 
+ """ + feed_dict_list = [] + if data_generator is not None: # use_py_reader == False + data_input_names = encoder_data_input_fields + \ + decoder_data_input_fields[:-1] + label_data_input_fields + data = next(data_generator) + for idx, data_buffer in enumerate(data): + data_input_dict, num_token = prepare_batch_input( + data_buffer, data_input_names, ModelHyperParams.eos_idx, + ModelHyperParams.phone_pad_idx, ModelHyperParams.eos_idx, + ModelHyperParams.n_head, ModelHyperParams.d_model) + feed_dict_list.append(data_input_dict) + if init_flag: + for idx in range(count): + pos_enc_tables = dict() + for pos_enc_param_name in pos_enc_param_names: + pos_enc_tables[pos_enc_param_name] = position_encoding_init( + ModelHyperParams.max_length + 1, ModelHyperParams.d_model) + if len(feed_dict_list) <= idx: + feed_dict_list.append(pos_enc_tables) + else: + feed_dict_list[idx] = dict( + list(pos_enc_tables.items()) + + list(feed_dict_list[idx].items())) + return feed_dict_list if len(feed_dict_list) == count else None + + +def py_reader_provider_wrapper(data_reader, place): + """ + Data provider needed by fluid.layers.py_reader. + """ + + def py_reader_provider(): + data_input_names = encoder_data_input_fields + \ + decoder_data_input_fields[:-1] + label_data_input_fields + for batch_id, data in enumerate(data_reader()): + data_input_dict, num_token = prepare_batch_input( + data, data_input_names, ModelHyperParams.eos_idx, + ModelHyperParams.phone_pad_idx, ModelHyperParams.eos_idx, + ModelHyperParams.n_head, ModelHyperParams.d_model) + yield [data_input_dict[item] for item in data_input_names] + + return py_reader_provider + + +def test_context(exe, train_exe, dev_count): + # Context to do validation. + test_prog = fluid.Program() + startup_prog = fluid.Program() + if args.enable_ce: + test_prog.random_seed = 1000 + startup_prog.random_seed = 1000 + with fluid.program_guard(test_prog, startup_prog): + with fluid.unique_name.guard(): + sum_cost, avg_cost, predict, token_num, pyreader = transformer( + ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size, + ModelHyperParams.max_length + 1, + ModelHyperParams.n_layer, + ModelHyperParams.n_head, + ModelHyperParams.d_key, + ModelHyperParams.d_value, + ModelHyperParams.d_model, + ModelHyperParams.d_inner_hid, + ModelHyperParams.prepostprocess_dropout, + ModelHyperParams.attention_dropout, + ModelHyperParams.relu_dropout, + ModelHyperParams.preprocess_cmd, + ModelHyperParams.postprocess_cmd, + ModelHyperParams.weight_sharing, + TrainTaskConfig.label_smooth_eps, + use_py_reader=args.use_py_reader, + beta=ModelHyperParams.beta, + is_test=True) + test_prog = test_prog.clone(for_test=True) + test_data = prepare_data_generator( + args, + is_test=True, + count=dev_count, + pyreader=pyreader, + py_reader_provider_wrapper=py_reader_provider_wrapper) + + exe.run(startup_prog) # to init pyreader for testing + if TrainTaskConfig.ckpt_path: + fluid.io.load_persistables( + exe, TrainTaskConfig.ckpt_path, main_program=test_prog) + + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.use_experimental_executor = True + build_strategy = fluid.BuildStrategy() + test_exe = fluid.ParallelExecutor( + use_cuda=TrainTaskConfig.use_gpu, + main_program=test_prog, + build_strategy=build_strategy, + exec_strategy=exec_strategy, + share_vars_from=train_exe) + + def test(exe=test_exe, pyreader=pyreader): + test_total_cost = 0 + test_total_token = 0 + + if args.use_py_reader: + pyreader.start() + data_generator = None + else: + data_generator = test_data() + while 
True: + try: + feed_dict_list = prepare_feed_dict_list( + data_generator, False, dev_count) + outs = test_exe.run( + fetch_list=[sum_cost.name, token_num.name], + feed=feed_dict_list) + except (StopIteration, fluid.core.EOFException): + # The current pass is over. + if args.use_py_reader: + pyreader.reset() + break + sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1]) + test_total_cost += sum_cost_val.sum() + test_total_token += token_num_val.sum() + test_avg_cost = test_total_cost / test_total_token + test_ppl = np.exp([min(test_avg_cost, 100)]) + return test_avg_cost, test_ppl + + return test + + +def train_loop(exe, + train_prog, + startup_prog, + dev_count, + sum_cost, + avg_cost, + token_num, + predict, + pyreader, + nccl2_num_trainers=1, + nccl2_trainer_id=0): + # Initialize the parameters. + if TrainTaskConfig.ckpt_path: + exe.run(startup_prog) # to init pyreader for training + logging.info("load checkpoint from {}".format( + TrainTaskConfig.ckpt_path)) + fluid.io.load_persistables( + exe, TrainTaskConfig.ckpt_path, main_program=train_prog) + else: + logging.info("init fluid.framework.default_startup_program") + exe.run(startup_prog) + + logging.info("begin reader") + train_data = prepare_data_generator( + args, + is_test=False, + count=dev_count, + pyreader=pyreader, + py_reader_provider_wrapper=py_reader_provider_wrapper) + + # For faster executor + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.use_experimental_executor = True + exec_strategy.num_iteration_per_drop_scope = int(args.fetch_steps) + build_strategy = fluid.BuildStrategy() + # Since the token number differs among devices, customize gradient scale to + # use token average cost among multi-devices. and the gradient scale is + # `1 / token_number` for average cost. + # build_strategy.gradient_scale_strategy = fluid.BuildStrategy.GradientScaleStrategy.Customized + + logging.info("begin executor") + train_exe = fluid.ParallelExecutor( + use_cuda=TrainTaskConfig.use_gpu, + loss_name=avg_cost.name, + main_program=train_prog, + build_strategy=build_strategy, + exec_strategy=exec_strategy, + num_trainers=nccl2_num_trainers, + trainer_id=nccl2_trainer_id) + + if args.val_file_pattern is not None: + test = test_context(exe, train_exe, dev_count) + + # the best cross-entropy value with label smoothing + loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log( + (1. 
- TrainTaskConfig.label_smooth_eps)) + + TrainTaskConfig.label_smooth_eps * + np.log(TrainTaskConfig.label_smooth_eps / + (ModelHyperParams.trg_vocab_size - 1) + 1e-20)) + + step_idx = 0 + init_flag = True + + logging.info("begin train") + for pass_id in six.moves.xrange(TrainTaskConfig.pass_num): + pass_start_time = time.time() + + if args.use_py_reader: + pyreader.start() + data_generator = None + else: + data_generator = train_data() + + batch_id = 0 + while True: + try: + feed_dict_list = prepare_feed_dict_list( + data_generator, init_flag, dev_count) + outs = train_exe.run( + fetch_list=[sum_cost.name, token_num.name] + if step_idx % args.fetch_steps == 0 else [], + feed=feed_dict_list) + + if step_idx % args.fetch_steps == 0: + sum_cost_val, token_num_val = np.array(outs[0]), np.array( + outs[1]) + # sum the cost from multi-devices + total_sum_cost = sum_cost_val.sum() + total_token_num = token_num_val.sum() + total_avg_cost = total_sum_cost / total_token_num + + if step_idx == 0: + logging.info( + "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " + "normalized loss: %f, ppl: %f" % + (step_idx, pass_id, batch_id, total_avg_cost, + total_avg_cost - loss_normalizer, + np.exp([min(total_avg_cost, 100)]))) + avg_batch_time = time.time() + else: + logging.info( + "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, " + "normalized loss: %f, ppl: %f, speed: %.2f step/s" + % (step_idx, pass_id, batch_id, total_avg_cost, + total_avg_cost - loss_normalizer, + np.exp([min(total_avg_cost, 100) + ]), args.fetch_steps / + (time.time() - avg_batch_time))) + avg_batch_time = time.time() + + if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0: + fluid.io.save_persistables( + exe, + os.path.join(TrainTaskConfig.ckpt_dir, + "latest.checkpoint"), train_prog) + fluid.io.save_params( + exe, + os.path.join(TrainTaskConfig.model_dir, + "iter_" + str(step_idx) + ".infer.model"), + train_prog) + + init_flag = False + batch_id += 1 + step_idx += 1 + except (StopIteration, fluid.core.EOFException): + # The current pass is over. + if args.use_py_reader: + pyreader.reset() + break + + time_consumed = time.time() - pass_start_time + # Validate and save the persistable. 
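+ # Validation runs only when --val_file_pattern was given; the per-pass checkpoint below is
+ # written with save_persistables (parameters plus optimizer state) and is skipped in CE mode.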
+ if args.val_file_pattern is not None: + val_avg_cost, val_ppl = test() + logging.info( + "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f," + " consumed %fs" % (pass_id, val_avg_cost, val_avg_cost - + loss_normalizer, val_ppl, time_consumed)) + else: + logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed)) + if not args.enable_ce: + fluid.io.save_persistables( + exe, + os.path.join(TrainTaskConfig.ckpt_dir, + "pass_" + str(pass_id) + ".checkpoint"), + train_prog) + + if args.enable_ce: # For CE + print("kpis\ttrain_cost_card%d\t%f" % (dev_count, total_avg_cost)) + if args.val_file_pattern is not None: + print("kpis\ttest_cost_card%d\t%f" % (dev_count, val_avg_cost)) + print("kpis\ttrain_duration_card%d\t%f" % (dev_count, time_consumed)) + + +def train(args): + # priority: ENV > args > config + is_local = os.getenv("PADDLE_IS_LOCAL", "1") + if is_local == '0': + args.local = False + logging.info(args) + + if args.device == 'CPU': + TrainTaskConfig.use_gpu = False + + training_role = os.getenv("TRAINING_ROLE", "TRAINER") + + if training_role == "PSERVER" or (not TrainTaskConfig.use_gpu): + place = fluid.CPUPlace() + dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + else: + place = fluid.CUDAPlace(0) + dev_count = fluid.core.get_cuda_device_count() + + exe = fluid.Executor(place) + + train_prog = fluid.Program() + startup_prog = fluid.Program() + + if args.enable_ce: + train_prog.random_seed = 1000 + startup_prog.random_seed = 1000 + + with fluid.program_guard(train_prog, startup_prog): + with fluid.unique_name.guard(): + sum_cost, avg_cost, predict, token_num, pyreader = transformer( + ModelHyperParams.src_vocab_size, + ModelHyperParams.trg_vocab_size, + ModelHyperParams.phone_vocab_size, + ModelHyperParams.max_length + 1, + ModelHyperParams.n_layer, + ModelHyperParams.n_head, + ModelHyperParams.d_key, + ModelHyperParams.d_value, + ModelHyperParams.d_model, + ModelHyperParams.d_inner_hid, + ModelHyperParams.prepostprocess_dropout, + ModelHyperParams.attention_dropout, + ModelHyperParams.relu_dropout, + ModelHyperParams.preprocess_cmd, + ModelHyperParams.postprocess_cmd, + ModelHyperParams.weight_sharing, + TrainTaskConfig.label_smooth_eps, + ModelHyperParams.beta, + ModelHyperParams.bos_idx, + use_py_reader=args.use_py_reader, + is_test=False) + + optimizer = None + if args.sync: + lr_decay = fluid.layers.learning_rate_scheduler.noam_decay( + ModelHyperParams.d_model, TrainTaskConfig.warmup_steps) + logging.info("before adam") + + with fluid.default_main_program()._lr_schedule_guard(): + learning_rate = lr_decay * TrainTaskConfig.learning_rate + + optimizer = fluid.optimizer.Adam( + learning_rate=learning_rate, + beta1=TrainTaskConfig.beta1, + beta2=TrainTaskConfig.beta2, + epsilon=TrainTaskConfig.eps) + else: + optimizer = fluid.optimizer.SGD(0.003) + optimizer.minimize(avg_cost) + + if args.use_mem_opt: + fluid.memory_optimize(train_prog) + + if args.local: + logging.info("local start_up:") + train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, + avg_cost, token_num, predict, pyreader) + else: + if args.update_method == "nccl2": + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + port = os.getenv("PADDLE_PORT") + worker_ips = os.getenv("PADDLE_TRAINERS") + worker_endpoints = [] + for ip in worker_ips.split(","): + worker_endpoints.append(':'.join([ip, port])) + trainers_num = len(worker_endpoints) + current_endpoint = os.getenv("POD_IP") + ":" + port + if trainer_id == 0: + logging.info("train_id == 0, sleep 60s") + 
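+ # heuristic wait so the other trainers' endpoints are up before the NCCL ring is initialized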
time.sleep(60) + logging.info("trainers_num:{}".format(trainers_num)) + logging.info("worker_endpoints:{}".format(worker_endpoints)) + logging.info("current_endpoint:{}".format(current_endpoint)) + append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints, + current_endpoint) + train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, + avg_cost, token_num, predict, pyreader, trainers_num, + trainer_id) + return + + port = os.getenv("PADDLE_PORT", "6174") + pserver_ips = os.getenv("PADDLE_PSERVERS") # ip,ip... + eplist = [] + for ip in pserver_ips.split(","): + eplist.append(':'.join([ip, port])) + pserver_endpoints = ",".join(eplist) # ip:port,ip:port... + trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "0")) + current_endpoint = os.getenv("POD_IP") + ":" + port + trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) + + logging.info("pserver_endpoints:{}".format(pserver_endpoints)) + logging.info("current_endpoint:{}".format(current_endpoint)) + logging.info("trainer_id:{}".format(trainer_id)) + logging.info("pserver_ips:{}".format(pserver_ips)) + logging.info("port:{}".format(port)) + + t = fluid.DistributeTranspiler() + t.transpile( + trainer_id, + pservers=pserver_endpoints, + trainers=trainers, + program=train_prog, + startup_program=startup_prog) + + if training_role == "PSERVER": + logging.info("distributed: pserver started") + current_endpoint = os.getenv("POD_IP") + ":" + os.getenv( + "PADDLE_PORT") + if not current_endpoint: + logging.critical("need env SERVER_ENDPOINT") + exit(1) + pserver_prog = t.get_pserver_program(current_endpoint) + pserver_startup = t.get_startup_program(current_endpoint, + pserver_prog) + + exe.run(pserver_startup) + exe.run(pserver_prog) + elif training_role == "TRAINER": + logging.info("distributed: trainer started") + trainer_prog = t.get_trainer_program() + + train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, + avg_cost, token_num, predict, pyreader) + else: + logging.critical( + "environment var TRAINER_ROLE should be TRAINER os PSERVER") + exit(1) + + +if __name__ == "__main__": + LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s" + logging.basicConfig( + stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT) + logging.getLogger().setLevel(logging.INFO) + + args = parse_args() + train(args) -- GitLab