From 6e375fda77b7eda9b44110a012cdfe180f80738e Mon Sep 17 00:00:00 2001 From: Xing Wu <1160386409@qq.com> Date: Mon, 28 Oct 2019 14:52:54 +0800 Subject: [PATCH] PaddleTextGEN Cherry pick from 1.6 (#3785) * add new seq2seq example: rnn_search (#3574) * re-commit seq2eq example: rnn search * add new_seq2seq examples: rnn_search and vae_text * modify dataset description in vae_text/README.md * modify vae_text/README.md and vae_text/download.py * update vae_text/infer.sh and vae_text/README.md * add multi-gpu support for rnn_search; fix windows support for rnn_search and vae_text * remove parallel argument in args.py * remove hard codes * remove hard codes * update download.py for windows download * fix model path for windows in vae_text/infer.py * fix python2 utf-encoding * fix utf-8 error in python2 when predict (#3721) * change rnn_search to seq2seq; change vae_text to variational seq2seq (#3734) * re-commit seq2eq example: rnn search * add new_seq2seq examples: rnn_search and vae_text * modify dataset description in vae_text/README.md * modify vae_text/README.md and vae_text/download.py * update vae_text/infer.sh and vae_text/README.md * add multi-gpu support for rnn_search; fix windows support for rnn_search and vae_text * remove parallel argument in args.py * remove hard codes * remove hard codes * update download.py for windows download * fix model path for windows in vae_text/infer.py * fix python2 utf-encoding * change rnn_search to seq2seq; change vae_text to variational seq2seq * change rnn_search to seq2seq; change vae_text to variational seq2seq * mv variational seq2seq to variational_seq2seq * remove old projects (#3736) * fix use_gpu=False error (#3775) * update_seq2seq_readme (#3783) --- .../dgu/utils/py23.py | 25 ++ .../dialogue_general_understanding/predict.py | 68 ++-- PaddleNLP/PaddleTextGEN/seq2seq/README.md | 127 ++++++ PaddleNLP/PaddleTextGEN/seq2seq/__init__.py | 0 PaddleNLP/PaddleTextGEN/seq2seq/args.py | 127 ++++++ .../PaddleTextGEN/seq2seq/attention_model.py | 156 ++++++++ PaddleNLP/PaddleTextGEN/seq2seq/base_model.py | 226 +++++++++++ PaddleNLP/PaddleTextGEN/seq2seq/download.py | 55 +++ PaddleNLP/PaddleTextGEN/seq2seq/infer.py | 184 +++++++++ PaddleNLP/PaddleTextGEN/seq2seq/infer.sh | 22 ++ PaddleNLP/PaddleTextGEN/seq2seq/reader.py | 211 ++++++++++ PaddleNLP/PaddleTextGEN/seq2seq/run.sh | 20 + PaddleNLP/PaddleTextGEN/seq2seq/train.py | 277 +++++++++++++ .../variational_seq2seq/README.md | 107 +++++ .../variational_seq2seq/__init__.py | 0 .../PaddleTextGEN/variational_seq2seq/args.py | 163 ++++++++ .../variational_seq2seq/download.py | 92 +++++ .../variational_seq2seq/infer.py | 128 ++++++ .../variational_seq2seq/infer.sh | 16 + .../variational_seq2seq/model.py | 369 ++++++++++++++++++ .../variational_seq2seq/reader.py | 206 ++++++++++ .../PaddleTextGEN/variational_seq2seq/run.sh | 15 + .../variational_seq2seq/train.py | 313 +++++++++++++++ 23 files changed, 2873 insertions(+), 34 deletions(-) create mode 100644 PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/README.md create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/__init__.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/args.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/attention_model.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/base_model.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/download.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/infer.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/infer.sh create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/reader.py create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/run.sh create mode 100644 PaddleNLP/PaddleTextGEN/seq2seq/train.py create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/README.md create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/__init__.py create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/args.py create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/download.py create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.py create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.sh create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/model.py create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/reader.py create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/run.sh create mode 100644 PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py diff --git a/PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py new file mode 100644 index 00000000..0d84ddfa --- /dev/null +++ b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/dgu/utils/py23.py @@ -0,0 +1,25 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +if sys.version[0] == '2': + rt_tok = u'\n' + tab_tok = u'\t' + space_tok = u' ' +else: + rt_tok = '\n' + tab_tok = '\t' + space_tok = ' ' diff --git a/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py index bab34ed2..8cc64f1b 100644 --- a/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py +++ b/PaddleNLP/PaddleDialogue/dialogue_general_understanding/predict.py @@ -24,16 +24,17 @@ import paddle.fluid as fluid import dgu.reader as reader from dgu_net import create_net -import dgu.define_paradigm as define_paradigm +import dgu.define_paradigm as define_paradigm import dgu.define_predict_pack as define_predict_pack from dgu.utils.configure import PDConfig from dgu.utils.input_field import InputField from dgu.utils.model_check import check_cuda import dgu.utils.save_load_io as save_load_io +from dgu.utils.py23 import tab_tok, rt_tok -def do_predict(args): +def do_predict(args): """predict function""" task_name = args.task_name.lower() @@ -63,34 +64,35 @@ def do_predict(args): num_labels = len(processors[task_name].get_labels()) src_ids = fluid.data( - name='src_ids', shape=[-1, args.max_seq_len], dtype='int64') + name='src_ids', shape=[-1, args.max_seq_len], dtype='int64') pos_ids = fluid.data( - name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64') + name='pos_ids', shape=[-1, args.max_seq_len], dtype='int64') sent_ids = fluid.data( - name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64') + name='sent_ids', shape=[-1, args.max_seq_len], dtype='int64') input_mask = fluid.data( - name='input_mask', shape=[-1, args.max_seq_len], dtype='float32') - if args.task_name == 'atis_slot': + name='input_mask', + shape=[-1, args.max_seq_len], + dtype='float32') + if args.task_name == 'atis_slot': labels = fluid.data( - name='labels', shape=[-1, args.max_seq_len], dtype='int64') + name='labels', shape=[-1, args.max_seq_len], dtype='int64') elif args.task_name in ['dstc2', 'dstc2_asr', 'multi-woz']: labels = fluid.data( - name='labels', shape=[-1, num_labels], dtype='int64') - else: - labels = fluid.data( - name='labels', shape=[-1, 1], dtype='int64') - + name='labels', shape=[-1, num_labels], dtype='int64') + else: + labels = fluid.data(name='labels', shape=[-1, 1], dtype='int64') + input_inst = [src_ids, pos_ids, sent_ids, input_mask, labels] input_field = InputField(input_inst) - data_reader = fluid.io.PyReader(feed_list=input_inst, - capacity=4, iterable=False) - + data_reader = fluid.io.PyReader( + feed_list=input_inst, capacity=4, iterable=False) + results = create_net( - is_training=False, - model_input=input_field, - num_labels=num_labels, - paradigm_inst=paradigm_inst, - args=args) + is_training=False, + model_input=input_field, + num_labels=num_labels, + paradigm_inst=paradigm_inst, + args=args) probs = results.get("probs", None) @@ -117,7 +119,7 @@ def do_predict(args): save_load_io.init_from_pretrain_model(args, exe, test_prog) compiled_test_prog = fluid.CompiledProgram(test_prog) - + processor = processors[task_name](data_dir=args.data_dir, vocab_path=args.vocab_path, max_seq_len=args.max_seq_len, @@ -126,34 +128,32 @@ def do_predict(args): task_name=task_name, random_seed=args.random_seed) batch_generator = processor.data_generator( - batch_size=args.batch_size, - phase='test', - shuffle=False) + batch_size=args.batch_size, phase='test', shuffle=False) - data_reader.decorate_batch_generator(batch_generator) + data_reader.decorate_batch_generator(batch_generator) data_reader.start() - + all_results = [] - while True: - try: + while True: + try: results = exe.run(compiled_test_prog, fetch_list=fetch_list) all_results.extend(results[0]) - except fluid.core.EOFException: + except fluid.core.EOFException: data_reader.reset() break np.set_printoptions(precision=4, suppress=True) print("Write the predicted results into the output_prediction_file") - + fw = io.open(args.output_prediction_file, 'w', encoding="utf8") - if task_name not in ['atis_slot']: + if task_name not in ['atis_slot']: for index, result in enumerate(all_results): tags = pred_func(result) - fw.write("%s\t%s\n" % (index, tags)) + fw.write("%s%s%s%s" % (index, tab_tok, tags, rt_tok)) else: tags = pred_func(all_results, args.max_seq_len) for index, tag in enumerate(tags): - fw.write("%s\t%s\n" % (index, tag)) + fw.write("%s%s%s%s" % (index, tab_tok, tag, rt_tok)) if __name__ == "__main__": diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/README.md b/PaddleNLP/PaddleTextGEN/seq2seq/README.md new file mode 100644 index 00000000..ad0fa962 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/README.md @@ -0,0 +1,127 @@ +运行本目录下的范例模型需要安装PaddlePaddle Fluid 1.6版。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](https://www.paddlepaddle.org.cn/#quick-start)中的说明更新 PaddlePaddle 安装版本。 + +# Sequence to Sequence (Seq2Seq) + +以下是本范例模型的简要目录结构及说明: + +``` +. +├── README.md # 文档,本文件 +├── args.py # 训练、预测以及模型参数配置程序 +├── reader.py # 数据读入程序 +├── download.py # 数据下载程序 +├── train.py # 训练主程序 +├── infer.py # 预测主程序 +├── run.sh # 默认配置的启动脚本 +├── infer.sh # 默认配置的解码脚本 +├── attention_model.py # 带注意力机制的翻译模型程序 +└── base_model.py # 无注意力机制的翻译模型程序 +``` + +## 简介 + +Sequence to Sequence (Seq2Seq),使用编码器-解码器(Encoder-Decoder)结构,用编码器将源序列编码成vector,再用解码器将该vector解码为目标序列。Seq2Seq 广泛应用于机器翻译,自动对话机器人,文档摘要自动生成,图片描述自动生成等任务中。 + +本目录包含Seq2Seq的一个经典样例:机器翻译,实现了一个base model(不带attention机制),一个带attention机制的翻译模型。Seq2Seq翻译模型,模拟了人类在进行翻译类任务时的行为:先解析源语言,理解其含义,再根据该含义来写出目标语言的语句。更多关于机器翻译的具体原理和数学表达式,我们推荐参考[深度学习101](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/basics/machine_translation/index.html)。 + +**本目录旨在展示如何用Paddle Fluid 1.6 新增的Seq2Seq API** ,新 Seq2Seq API 组网更简单,从1.6版本开始不推荐使用low-level的API。如果您确实需要使用low-level的API来实现自己模型,样例可参看1.5版本的样例 [RNN Search](https://github.com/PaddlePaddle/models/tree/release/1.5/PaddleNLP/unarchived/neural_machine_translation/rnn_search)。 + +## 模型概览 + +本模型中,在编码器方面,我们采用了基于LSTM的多层的RNN encoder;在解码器方面,我们使用了带注意力(Attention)机制的RNN decoder,并同时提供了一个不带注意力机制的解码器实现作为对比。在预测时我们使用柱搜索(beam search)算法来生成翻译的目标语句。以下将分别介绍用到的这些方法。 + +## 数据介绍 + +本教程使用[IWSLT'15 English-Vietnamese data ](https://nlp.stanford.edu/projects/nmt/)数据集中的英语到越南语的数据作为训练语料,tst2012的数据作为开发集,tst2013的数据作为测试集 + +### 数据获取 + +``` +python download.py +``` + +## 模型训练 + +`run.sh`包含训练程序的主函数,要使用默认参数开始训练,只需要简单地执行: + +``` +sh run.sh +``` + +默认使用带有注意力机制的RNN模型,可以通过修改 --attention 为False来训练不带注意力机制的RNN模型。 + +``` +python train.py \ + --src_lang en --tar_lang vi \ + --attention True \ + --num_layers 2 \ + --hidden_size 512 \ + --src_vocab_size 17191 \ + --tar_vocab_size 7709 \ + --batch_size 128 \ + --dropout 0.2 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --train_data_prefix data/en-vi/train \ + --eval_data_prefix data/en-vi/tst2012 \ + --test_data_prefix data/en-vi/tst2013 \ + --vocab_prefix data/en-vi/vocab \ + --use_gpu True \ + --model_path ./attention_models +``` + +训练程序会在每个epoch训练结束之后,save一次模型。 + +## 模型预测 + +当模型训练完成之后, 可以利用infer.sh的脚本进行预测,默认使用beam search的方法进行预测,加载第10个epoch的模型进行预测,对test的数据集进行解码 + +``` +sh infer.sh +``` + +如果想预测别的数据文件,只需要将 --infer_file参数进行修改。 + +``` +python infer.py \ + --attention True \ + --src_lang en --tar_lang vi \ + --num_layers 2 \ + --hidden_size 512 \ + --src_vocab_size 17191 \ + --tar_vocab_size 7709 \ + --batch_size 128 \ + --dropout 0.2 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --vocab_prefix data/en-vi/vocab \ + --infer_file data/en-vi/tst2013.en \ + --reload_model attention_models/epoch_10/ \ + --infer_output_file attention_infer_output/infer_output.txt \ + --beam_size 10 \ + --use_gpu True +``` + +## 效果评价 + +使用 [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) 工具来评价模型预测的翻译质量,使用方法如下: + +```sh +mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt +``` + +每个模型分别训练了10次,单次取第10个epoch保存的模型进行预测,取beam_size=10。效果如下(为了便于观察,对10次结果按照升序进行了排序): + +``` +> no attention +tst2012 BLEU: +[10.75 10.85 10.9 10.94 10.97 11.01 11.01 11.04 11.13 11.4] +tst2013 BLEU: +[10.71 10.71 10.74 10.76 10.91 10.94 11.02 11.16 11.21 11.44] + +> with attention +tst2012 BLEU: +[21.14 22.34 22.54 22.65 22.71 22.71 23.08 23.15 23.3 23.4] +tst2013 BLEU: +[23.41 24.79 25.11 25.12 25.19 25.24 25.39 25.61 25.61 25.63] +``` diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/__init__.py b/PaddleNLP/PaddleTextGEN/seq2seq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/args.py b/PaddleNLP/PaddleTextGEN/seq2seq/args.py new file mode 100644 index 00000000..ee056e33 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/args.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import distutils.util + + +def parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--train_data_prefix", type=str, help="file prefix for train data") + parser.add_argument( + "--eval_data_prefix", type=str, help="file prefix for eval data") + parser.add_argument( + "--test_data_prefix", type=str, help="file prefix for test data") + parser.add_argument( + "--vocab_prefix", type=str, help="file prefix for vocab") + parser.add_argument("--src_lang", type=str, help="source language suffix") + parser.add_argument("--tar_lang", type=str, help="target language suffix") + + parser.add_argument( + "--attention", + type=eval, + default=False, + help="Whether use attention model") + + parser.add_argument( + "--optimizer", + type=str, + default='adam', + help="optimizer to use, only supprt[sgd|adam]") + + parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="learning rate for optimizer") + + parser.add_argument( + "--num_layers", + type=int, + default=1, + help="layers number of encoder and decoder") + parser.add_argument( + "--hidden_size", + type=int, + default=100, + help="hidden size of encoder and decoder") + parser.add_argument("--src_vocab_size", type=int, help="source vocab size") + parser.add_argument("--tar_vocab_size", type=int, help="target vocab size") + + parser.add_argument( + "--batch_size", type=int, help="batch size of each step") + + parser.add_argument( + "--max_epoch", type=int, default=12, help="max epoch for the training") + + parser.add_argument( + "--max_len", + type=int, + default=50, + help="max length for source and target sentence") + parser.add_argument( + "--dropout", type=float, default=0.0, help="drop probability") + parser.add_argument( + "--init_scale", + type=float, + default=0.0, + help="init scale for parameter") + parser.add_argument( + "--max_grad_norm", + type=float, + default=5.0, + help="max grad norm for global norm clip") + + parser.add_argument( + "--model_path", + type=str, + default='model', + help="model path for model to save") + + parser.add_argument( + "--reload_model", type=str, help="reload model to inference") + + parser.add_argument( + "--infer_file", type=str, help="file name for inference") + parser.add_argument( + "--infer_output_file", + type=str, + default='infer_output', + help="file name for inference output") + parser.add_argument( + "--beam_size", type=int, default=10, help="file name for inference") + + parser.add_argument( + '--use_gpu', + type=eval, + default=False, + help='Whether using gpu [True|False]') + + parser.add_argument( + "--enable_ce", + action='store_true', + help="The flag indicating whether to run the task " + "for continuous evaluation.") + + parser.add_argument( + "--profile", action='store_true', help="Whether enable the profile.") + + args = parser.parse_args() + return args diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/attention_model.py b/PaddleNLP/PaddleTextGEN/seq2seq/attention_model.py new file mode 100644 index 00000000..4f53aa97 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/attention_model.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid.layers as layers +import paddle.fluid as fluid +import numpy as np +from paddle.fluid import ParamAttr +from paddle.fluid.contrib.layers import basic_lstm, BasicLSTMUnit +from base_model import BaseModel, DecoderCell +from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode + +INF = 1. * 1e5 +alpha = 0.6 + + +class AttentionDecoderCell(DecoderCell): + def __init__(self, num_layers, hidden_size, dropout_prob=0., + init_scale=0.1): + super(AttentionDecoderCell, self).__init__(num_layers, hidden_size, + dropout_prob, init_scale) + + def attention(self, query, enc_output, mask=None): + query = layers.unsqueeze(query, [1]) + memory = layers.fc(enc_output, + self.hidden_size, + num_flatten_dims=2, + param_attr=ParamAttr( + initializer=fluid.initializer.UniformInitializer( + low=-self.init_scale, high=self.init_scale)), + bias_attr=False) + attn = layers.matmul(query, memory, transpose_y=True) + + if mask: + attn = layers.transpose(attn, [1, 0, 2]) + attn = layers.elementwise_add(attn, mask * 1000000000, -1) + attn = layers.transpose(attn, [1, 0, 2]) + weight = layers.softmax(attn) + weight_memory = layers.matmul(weight, memory) + + return weight_memory + + def call(self, step_input, states, enc_output, enc_padding_mask=None): + lstm_states, input_feed = states + new_lstm_states = [] + step_input = layers.concat([step_input, input_feed], 1) + for i in range(self.num_layers): + out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i]) + step_input = layers.dropout( + out, + self.dropout_prob, + dropout_implementation='upscale_in_train' + ) if self.dropout_prob > 0 else out + new_lstm_states.append(new_lstm_state) + dec_att = self.attention(step_input, enc_output, enc_padding_mask) + dec_att = layers.squeeze(dec_att, [1]) + concat_att_out = layers.concat([dec_att, step_input], 1) + out = layers.fc(concat_att_out, + self.hidden_size, + param_attr=ParamAttr( + initializer=fluid.initializer.UniformInitializer( + low=-self.init_scale, high=self.init_scale)), + bias_attr=False) + return out, [new_lstm_states, out] + + +class AttentionModel(BaseModel): + def __init__(self, + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=1, + init_scale=0.1, + dropout=None, + beam_start_token=1, + beam_end_token=2, + beam_max_step_num=100): + super(AttentionModel, self).__init__( + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=num_layers, + init_scale=init_scale, + dropout=dropout) + + def _build_decoder(self, enc_final_state, mode='train', beam_size=10): + output_layer = lambda x: layers.fc(x, + size=self.tar_vocab_size, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name="output_w", + initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale)), + bias_attr=False) + + dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size, + self.dropout, self.init_scale) + dec_initial_states = [ + enc_final_state, dec_cell.get_initial_states( + batch_ref=self.enc_output, shape=[self.hidden_size]) + ] + max_src_seq_len = layers.shape(self.src)[1] + src_mask = layers.sequence_mask( + self.src_sequence_length, maxlen=max_src_seq_len, dtype='float32') + enc_padding_mask = (src_mask - 1.0) + if mode == 'train': + dec_output, _ = rnn(cell=dec_cell, + inputs=self.tar_emb, + initial_states=dec_initial_states, + sequence_length=None, + enc_output=self.enc_output, + enc_padding_mask=enc_padding_mask) + + dec_output = output_layer(dec_output) + + elif mode == 'beam_search': + output_layer = lambda x: layers.fc(x, + size=self.tar_vocab_size, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name="output_w"), + bias_attr=False) + beam_search_decoder = BeamSearchDecoder( + dec_cell, + self.beam_start_token, + self.beam_end_token, + beam_size, + embedding_fn=self.tar_embeder, + output_fn=output_layer) + enc_output = beam_search_decoder.tile_beam_merge_with_batch( + self.enc_output, beam_size) + enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch( + enc_padding_mask, beam_size) + outputs, _ = dynamic_decode( + beam_search_decoder, + inits=dec_initial_states, + max_step_num=self.beam_max_step_num, + enc_output=enc_output, + enc_padding_mask=enc_padding_mask) + return outputs + + return dec_output diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/base_model.py b/PaddleNLP/PaddleTextGEN/seq2seq/base_model.py new file mode 100644 index 00000000..f4ee1b95 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/base_model.py @@ -0,0 +1,226 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid.layers as layers +import paddle.fluid as fluid +import numpy as np +from paddle.fluid import ParamAttr +from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode + +INF = 1. * 1e5 +alpha = 0.6 +uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x) +zero_constant = fluid.initializer.Constant(0.0) + + +class EncoderCell(RNNCell): + def __init__( + self, + num_layers, + hidden_size, + dropout_prob=0., + init_scale=0.1, ): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + self.lstm_cells = [] + + param_attr = ParamAttr(initializer=uniform_initializer(init_scale)) + bias_attr = ParamAttr(initializer=zero_constant) + for i in range(num_layers): + self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr)) + + def call(self, step_input, states): + new_states = [] + for i in range(self.num_layers): + out, new_state = self.lstm_cells[i](step_input, states[i]) + step_input = layers.dropout( + out, + self.dropout_prob, + dropout_implementation="upscale_in_train" + ) if self.dropout_prob > 0. else out + new_states.append(new_state) + return step_input, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.lstm_cells] + + +class DecoderCell(RNNCell): + def __init__(self, num_layers, hidden_size, dropout_prob=0., + init_scale=0.1): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + self.lstm_cells = [] + self.init_scale = init_scale + param_attr = ParamAttr(initializer=uniform_initializer(init_scale)) + bias_attr = ParamAttr(initializer=zero_constant) + for i in range(num_layers): + self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr)) + + def call(self, step_input, states): + new_lstm_states = [] + for i in range(self.num_layers): + out, new_lstm_state = self.lstm_cells[i](step_input, states[i]) + step_input = layers.dropout( + out, + self.dropout_prob, + dropout_implementation="upscale_in_train" + ) if self.dropout_prob > 0. else out + new_lstm_states.append(new_lstm_state) + return step_input, new_lstm_states + + +class BaseModel(object): + def __init__(self, + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=1, + init_scale=0.1, + dropout=None, + beam_start_token=1, + beam_end_token=2, + beam_max_step_num=100): + + self.hidden_size = hidden_size + self.src_vocab_size = src_vocab_size + self.tar_vocab_size = tar_vocab_size + self.batch_size = batch_size + self.num_layers = num_layers + self.init_scale = init_scale + self.dropout = dropout + self.beam_start_token = beam_start_token + self.beam_end_token = beam_end_token + self.beam_max_step_num = beam_max_step_num + self.src_embeder = lambda x: fluid.embedding( + input=x, + size=[self.src_vocab_size, self.hidden_size], + dtype='float32', + is_sparse=False, + param_attr=fluid.ParamAttr( + name='source_embedding', + initializer=uniform_initializer(init_scale))) + + self.tar_embeder = lambda x: fluid.embedding( + input=x, + size=[self.tar_vocab_size, self.hidden_size], + dtype='float32', + is_sparse=False, + param_attr=fluid.ParamAttr( + name='target_embedding', + initializer=uniform_initializer(init_scale))) + + def _build_data(self): + self.src = fluid.data(name="src", shape=[None, None], dtype='int64') + self.src_sequence_length = fluid.data( + name="src_sequence_length", shape=[None], dtype='int32') + + self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64') + self.tar_sequence_length = fluid.data( + name="tar_sequence_length", shape=[None], dtype='int32') + self.label = fluid.data( + name="label", shape=[None, None, 1], dtype='int64') + + def _emebdding(self): + self.src_emb = self.src_embeder(self.src) + self.tar_emb = self.tar_embeder(self.tar) + + def _build_encoder(self): + enc_cell = EncoderCell(self.num_layers, self.hidden_size, self.dropout, + self.init_scale) + self.enc_output, enc_final_state = rnn( + cell=enc_cell, + inputs=self.src_emb, + sequence_length=self.src_sequence_length) + return self.enc_output, enc_final_state + + def _build_decoder(self, enc_final_state, mode='train', beam_size=10): + + dec_cell = DecoderCell(self.num_layers, self.hidden_size, self.dropout, + self.init_scale) + output_layer = lambda x: layers.fc(x, + size=self.tar_vocab_size, + num_flatten_dims=len(x.shape) - 1, + param_attr=fluid.ParamAttr(name="output_w", + initializer=uniform_initializer(self.init_scale)), + bias_attr=False) + + if mode == 'train': + dec_output, dec_final_state = rnn(cell=dec_cell, + inputs=self.tar_emb, + initial_states=enc_final_state) + + dec_output = output_layer(dec_output) + + return dec_output + elif mode == 'beam_search': + beam_search_decoder = BeamSearchDecoder( + dec_cell, + self.beam_start_token, + self.beam_end_token, + beam_size, + embedding_fn=self.tar_embeder, + output_fn=output_layer) + + outputs, _ = dynamic_decode( + beam_search_decoder, + inits=enc_final_state, + max_step_num=self.beam_max_step_num) + return outputs + + def _compute_loss(self, dec_output): + loss = layers.softmax_with_cross_entropy( + logits=dec_output, label=self.label, soft_label=False) + loss = layers.unsqueeze(loss, axes=[2]) + + max_tar_seq_len = layers.shape(self.tar)[1] + tar_mask = layers.sequence_mask( + self.tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32') + loss = loss * tar_mask + loss = layers.reduce_mean(loss, dim=[0]) + loss = layers.reduce_sum(loss) + return loss + + def _beam_search(self, enc_last_hidden, enc_last_cell): + pass + + def build_graph(self, mode='train', beam_size=10): + if mode == 'train' or mode == 'eval': + self._build_data() + self._emebdding() + enc_output, enc_final_state = self._build_encoder() + dec_output = self._build_decoder(enc_final_state) + + loss = self._compute_loss(dec_output) + return loss + elif mode == "beam_search" or mode == 'greedy_search': + self._build_data() + self._emebdding() + enc_output, enc_final_state = self._build_encoder() + dec_output = self._build_decoder( + enc_final_state, mode=mode, beam_size=beam_size) + + return dec_output + else: + print("not support mode ", mode) + raise Exception("not support mode: " + mode) diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/download.py b/PaddleNLP/PaddleTextGEN/seq2seq/download.py new file mode 100644 index 00000000..4dd1466d --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/download.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +Script for downloading training data. +''' +import os +import urllib +import sys + +if sys.version_info >= (3, 0): + import urllib.request +import zipfile + +URLLIB = urllib +if sys.version_info >= (3, 0): + URLLIB = urllib.request + +remote_path = 'https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi' +base_path = 'data' +tar_path = os.path.join(base_path, 'en-vi') +filenames = [ + 'train.en', 'train.vi', 'tst2012.en', 'tst2012.vi', 'tst2013.en', + 'tst2013.vi', 'vocab.en', 'vocab.vi' +] + + +def main(arguments): + print("Downloading data......") + + if not os.path.exists(tar_path): + if not os.path.exists(base_path): + os.mkdir(base_path) + os.mkdir(tar_path) + + for filename in filenames: + url = remote_path + '/' + filename + tar_file = os.path.join(tar_path, filename) + URLLIB.urlretrieve(url, tar_file) + print("Downloaded sucess......") + + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/infer.py b/PaddleNLP/PaddleTextGEN/seq2seq/infer.py new file mode 100644 index 00000000..47240429 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/infer.py @@ -0,0 +1,184 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import time +import os +import random +import logging +import math +import io +import paddle +import paddle.fluid as fluid +import paddle.fluid.framework as framework +from paddle.fluid.executor import Executor + +import reader + +import sys +line_tok = '\n' +space_tok = ' ' +if sys.version[0] == '2': + reload(sys) + sys.setdefaultencoding("utf-8") + line_tok = u'\n' + space_tok = u' ' + +logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger("fluid") +logger.setLevel(logging.INFO) + +from args import * +import logging +import pickle + +from attention_model import AttentionModel +from base_model import BaseModel + + +def infer(): + args = parse_args() + + num_layers = args.num_layers + src_vocab_size = args.src_vocab_size + tar_vocab_size = args.tar_vocab_size + batch_size = args.batch_size + dropout = args.dropout + init_scale = args.init_scale + max_grad_norm = args.max_grad_norm + hidden_size = args.hidden_size + # inference process + + print("src", src_vocab_size) + + # dropout type using upscale_in_train, dropout can be remove in inferecen + # So we can set dropout to 0 + if args.attention: + model = AttentionModel( + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=num_layers, + init_scale=init_scale, + dropout=0.0) + else: + model = BaseModel( + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=num_layers, + init_scale=init_scale, + dropout=0.0) + + beam_size = args.beam_size + trans_res = model.build_graph(mode='beam_search', beam_size=beam_size) + # clone from default main program and use it as the validation program + main_program = fluid.default_main_program() + main_program = main_program.clone(for_test=True) + print([param.name for param in main_program.blocks[0].all_parameters()]) + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = Executor(place) + exe.run(framework.default_startup_program()) + + source_vocab_file = args.vocab_prefix + "." + args.src_lang + infer_file = args.infer_file + + infer_data = reader.raw_mono_data(source_vocab_file, infer_file) + + def prepare_input(batch, epoch_id=0, with_lr=True): + src_ids, src_mask, tar_ids, tar_mask = batch + res = {} + src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1])) + in_tar = tar_ids[:, :-1] + label_tar = tar_ids[:, 1:] + + in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1])) + in_tar = np.zeros_like(in_tar, dtype='int64') + label_tar = label_tar.reshape( + (label_tar.shape[0], label_tar.shape[1], 1)) + label_tar = np.zeros_like(label_tar, dtype='int64') + + res['src'] = src_ids + res['tar'] = in_tar + res['label'] = label_tar + res['src_sequence_length'] = src_mask + res['tar_sequence_length'] = tar_mask + + return res, np.sum(tar_mask) + + dir_name = args.reload_model + print("dir name", dir_name) + fluid.io.load_params(exe, dir_name) + + train_data_iter = reader.get_data_iter(infer_data, 1, mode='eval') + + tar_id2vocab = [] + tar_vocab_file = args.vocab_prefix + "." + args.tar_lang + with io.open(tar_vocab_file, "r", encoding='utf-8') as f: + for line in f.readlines(): + tar_id2vocab.append(line.strip()) + + infer_output_file = args.infer_output_file + infer_output_dir = infer_output_file.split('/')[0] + if not os.path.exists(infer_output_dir): + os.mkdir(infer_output_dir) + + with io.open(infer_output_file, 'w', encoding='utf-8') as out_file: + + for batch_id, batch in enumerate(train_data_iter): + input_data_feed, word_num = prepare_input(batch, epoch_id=0) + fetch_outs = exe.run(program=main_program, + feed=input_data_feed, + fetch_list=[trans_res.name], + use_program_cache=False) + + for ins in fetch_outs[0]: + res = [tar_id2vocab[e] for e in ins[:, 0].reshape(-1)] + new_res = [] + for ele in res: + if ele == "": + break + new_res.append(ele) + + out_file.write(space_tok.join(new_res)) + out_file.write(line_tok) + + +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) + + +if __name__ == '__main__': + check_version() + infer() diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/infer.sh b/PaddleNLP/PaddleTextGEN/seq2seq/infer.sh new file mode 100644 index 00000000..6b62b013 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/infer.sh @@ -0,0 +1,22 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=0 + +python infer.py \ + --attention True \ + --src_lang en --tar_lang vi \ + --num_layers 2 \ + --hidden_size 512 \ + --src_vocab_size 17191 \ + --tar_vocab_size 7709 \ + --batch_size 128 \ + --dropout 0.2 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --vocab_prefix data/en-vi/vocab \ + --infer_file data/en-vi/tst2013.en \ + --reload_model attention_models/epoch_10/ \ + --infer_output_file attention_infer_output/infer_output.txt \ + --beam_size 10 \ + --use_gpu True + + diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/reader.py b/PaddleNLP/PaddleTextGEN/seq2seq/reader.py new file mode 100644 index 00000000..661546f7 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/reader.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for parsing PTB text files.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import io +import sys +import numpy as np + +Py3 = sys.version_info[0] == 3 + +UNK_ID = 0 + + +def _read_words(filename): + data = [] + with io.open(filename, "r", encoding='utf-8') as f: + if Py3: + return f.read().replace("\n", "").split() + else: + return f.read().decode("utf-8").replace(u"\n", u"").split() + + +def read_all_line(filenam): + data = [] + with io.open(filename, "r", encoding='utf-8') as f: + for line in f.readlines(): + data.append(line.strip()) + + +def _build_vocab(filename): + + vocab_dict = {} + ids = 0 + with io.open(filename, "r", encoding='utf-8') as f: + for line in f.readlines(): + vocab_dict[line.strip()] = ids + ids += 1 + + print("vocab word num", ids) + + return vocab_dict + + +def _para_file_to_ids(src_file, tar_file, src_vocab, tar_vocab): + + src_data = [] + with io.open(src_file, "r", encoding='utf-8') as f_src: + for line in f_src.readlines(): + arra = line.strip().split() + ids = [src_vocab[w] if w in src_vocab else UNK_ID for w in arra] + ids = ids + + src_data.append(ids) + + tar_data = [] + with io.open(tar_file, "r", encoding='utf-8') as f_tar: + for line in f_tar.readlines(): + arra = line.strip().split() + ids = [tar_vocab[w] if w in tar_vocab else UNK_ID for w in arra] + + ids = [1] + ids + [2] + + tar_data.append(ids) + + return src_data, tar_data + + +def filter_len(src, tar, max_sequence_len=50): + new_src = [] + new_tar = [] + + for id1, id2 in zip(src, tar): + if len(id1) > max_sequence_len: + id1 = id1[:max_sequence_len] + if len(id2) > max_sequence_len + 2: + id2 = id2[:max_sequence_len + 2] + + new_src.append(id1) + new_tar.append(id2) + + return new_src, new_tar + + +def raw_data(src_lang, + tar_lang, + vocab_prefix, + train_prefix, + eval_prefix, + test_prefix, + max_sequence_len=50): + + src_vocab_file = vocab_prefix + "." + src_lang + tar_vocab_file = vocab_prefix + "." + tar_lang + + src_train_file = train_prefix + "." + src_lang + tar_train_file = train_prefix + "." + tar_lang + + src_eval_file = eval_prefix + "." + src_lang + tar_eval_file = eval_prefix + "." + tar_lang + + src_test_file = test_prefix + "." + src_lang + tar_test_file = test_prefix + "." + tar_lang + + src_vocab = _build_vocab(src_vocab_file) + tar_vocab = _build_vocab(tar_vocab_file) + + train_src, train_tar = _para_file_to_ids( src_train_file, tar_train_file, \ + src_vocab, tar_vocab ) + train_src, train_tar = filter_len( + train_src, train_tar, max_sequence_len=max_sequence_len) + eval_src, eval_tar = _para_file_to_ids( src_eval_file, tar_eval_file, \ + src_vocab, tar_vocab ) + + test_src, test_tar = _para_file_to_ids( src_test_file, tar_test_file, \ + src_vocab, tar_vocab ) + + return ( train_src, train_tar), (eval_src, eval_tar), (test_src, test_tar),\ + (src_vocab, tar_vocab) + + +def raw_mono_data(vocab_file, file_path): + + src_vocab = _build_vocab(vocab_file) + + test_src, test_tar = _para_file_to_ids( file_path, file_path, \ + src_vocab, src_vocab ) + + return (test_src, test_tar) + + +def get_data_iter(raw_data, + batch_size, + mode='train', + enable_ce=False, + cache_num=20): + + src_data, tar_data = raw_data + + data_len = len(src_data) + + index = np.arange(data_len) + if mode == "train" and not enable_ce: + np.random.shuffle(index) + + def to_pad_np(data, source=False): + max_len = 0 + for ele in data: + if len(ele) > max_len: + max_len = len(ele) + + ids = np.ones((batch_size, max_len), dtype='int64') * 2 + mask = np.zeros((batch_size), dtype='int32') + + for i, ele in enumerate(data): + ids[i, :len(ele)] = ele + if not source: + mask[i] = len(ele) - 1 + else: + mask[i] = len(ele) + + return ids, mask + + b_src = [] + + if mode != "train": + cache_num = 1 + for j in range(data_len): + if len(b_src) == batch_size * cache_num: + # build batch size + + # sort + new_cache = sorted(b_src, key=lambda k: len(k[0])) + + for i in range(cache_num): + batch_data = new_cache[i * batch_size:(i + 1) * batch_size] + src_cache = [w[0] for w in batch_data] + tar_cache = [w[1] for w in batch_data] + src_ids, src_mask = to_pad_np(src_cache, source=True) + tar_ids, tar_mask = to_pad_np(tar_cache) + yield (src_ids, src_mask, tar_ids, tar_mask) + + b_src = [] + + b_src.append((src_data[index[j]], tar_data[index[j]])) + if len(b_src) == batch_size * cache_num: + new_cache = sorted(b_src, key=lambda k: len(k[0])) + + for i in range(cache_num): + batch_data = new_cache[i * batch_size:(i + 1) * batch_size] + src_cache = [w[0] for w in batch_data] + tar_cache = [w[1] for w in batch_data] + src_ids, src_mask = to_pad_np(src_cache, source=True) + tar_ids, tar_mask = to_pad_np(tar_cache) + yield (src_ids, src_mask, tar_ids, tar_mask) diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/run.sh b/PaddleNLP/PaddleTextGEN/seq2seq/run.sh new file mode 100644 index 00000000..25bc78a3 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/run.sh @@ -0,0 +1,20 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=0 + +python train.py \ + --src_lang en --tar_lang vi \ + --attention True \ + --num_layers 2 \ + --hidden_size 512 \ + --src_vocab_size 17191 \ + --tar_vocab_size 7709 \ + --batch_size 128 \ + --dropout 0.2 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --train_data_prefix data/en-vi/train \ + --eval_data_prefix data/en-vi/tst2012 \ + --test_data_prefix data/en-vi/tst2013 \ + --vocab_prefix data/en-vi/vocab \ + --use_gpu True \ + --model_path attention_models diff --git a/PaddleNLP/PaddleTextGEN/seq2seq/train.py b/PaddleNLP/PaddleTextGEN/seq2seq/train.py new file mode 100644 index 00000000..51d4d29e --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/seq2seq/train.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import time +import os +import logging +import random +import math +import contextlib + +import paddle +import paddle.fluid as fluid +import paddle.fluid.framework as framework +import paddle.fluid.profiler as profiler +from paddle.fluid.executor import Executor + +import reader + +import sys +if sys.version[0] == '2': + reload(sys) + sys.setdefaultencoding("utf-8") + +from args import * +from base_model import BaseModel +from attention_model import AttentionModel +import logging +import pickle + + +@contextlib.contextmanager +def profile_context(profile=True): + if profile: + with profiler.profiler('All', 'total', 'seq2seq.profile'): + yield + else: + yield + + +def main(): + args = parse_args() + print(args) + num_layers = args.num_layers + src_vocab_size = args.src_vocab_size + tar_vocab_size = args.tar_vocab_size + batch_size = args.batch_size + dropout = args.dropout + init_scale = args.init_scale + max_grad_norm = args.max_grad_norm + hidden_size = args.hidden_size + + if args.enable_ce: + fluid.default_main_program().random_seed = 102 + framework.default_startup_program().random_seed = 102 + + train_program = fluid.Program() + startup_program = fluid.Program() + + with fluid.program_guard(train_program, startup_program): + # Training process + + if args.attention: + model = AttentionModel( + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=num_layers, + init_scale=init_scale, + dropout=dropout) + else: + model = BaseModel( + hidden_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=num_layers, + init_scale=init_scale, + dropout=dropout) + loss = model.build_graph() + inference_program = train_program.clone(for_test=True) + fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm( + clip_norm=max_grad_norm)) + lr = args.learning_rate + opt_type = args.optimizer + if opt_type == "sgd": + optimizer = fluid.optimizer.SGD(lr) + elif opt_type == "adam": + optimizer = fluid.optimizer.Adam(lr) + else: + print("only support [sgd|adam]") + raise Exception("opt type not support") + + optimizer.minimize(loss) + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = Executor(place) + exe.run(startup_program) + + device_count = len(fluid.cuda_places()) if args.use_gpu else len( + fluid.cpu_places()) + + CompiledProgram = fluid.CompiledProgram(train_program).with_data_parallel( + loss_name=loss.name) + + train_data_prefix = args.train_data_prefix + eval_data_prefix = args.eval_data_prefix + test_data_prefix = args.test_data_prefix + vocab_prefix = args.vocab_prefix + src_lang = args.src_lang + tar_lang = args.tar_lang + print("begin to load data") + raw_data = reader.raw_data(src_lang, tar_lang, vocab_prefix, + train_data_prefix, eval_data_prefix, + test_data_prefix, args.max_len) + print("finished load data") + train_data, valid_data, test_data, _ = raw_data + + def prepare_input(batch, epoch_id=0, with_lr=True): + src_ids, src_mask, tar_ids, tar_mask = batch + res = {} + src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1])) + in_tar = tar_ids[:, :-1] + label_tar = tar_ids[:, 1:] + + in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1])) + label_tar = label_tar.reshape( + (label_tar.shape[0], label_tar.shape[1], 1)) + + res['src'] = src_ids + res['tar'] = in_tar + res['label'] = label_tar + res['src_sequence_length'] = src_mask + res['tar_sequence_length'] = tar_mask + + return res, np.sum(tar_mask) + + # get train epoch size + def eval(data, epoch_id=0): + eval_data_iter = reader.get_data_iter(data, batch_size, mode='eval') + total_loss = 0.0 + word_count = 0.0 + for batch_id, batch in enumerate(eval_data_iter): + input_data_feed, word_num = prepare_input( + batch, epoch_id, with_lr=False) + fetch_outs = exe.run(inference_program, + feed=input_data_feed, + fetch_list=[loss.name], + use_program_cache=False) + + cost_train = np.array(fetch_outs[0]) + + total_loss += cost_train * batch_size + word_count += word_num + + ppl = np.exp(total_loss / word_count) + + return ppl + + def train(): + ce_time = [] + ce_ppl = [] + max_epoch = args.max_epoch + for epoch_id in range(max_epoch): + start_time = time.time() + if args.enable_ce: + train_data_iter = reader.get_data_iter( + train_data, batch_size, enable_ce=True) + else: + train_data_iter = reader.get_data_iter(train_data, batch_size) + + total_loss = 0 + word_count = 0.0 + batch_times = [] + for batch_id, batch in enumerate(train_data_iter): + batch_start_time = time.time() + input_data_feed, word_num = prepare_input( + batch, epoch_id=epoch_id) + word_count += word_num + fetch_outs = exe.run(program=CompiledProgram, + feed=input_data_feed, + fetch_list=[loss.name], + use_program_cache=True) + + cost_train = np.mean(fetch_outs[0]) + # print(cost_train) + total_loss += cost_train * batch_size + batch_end_time = time.time() + batch_time = batch_end_time - batch_start_time + batch_times.append(batch_time) + + if batch_id > 0 and batch_id % 100 == 0: + print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" % + (epoch_id, batch_id, batch_time, + np.exp(total_loss / word_count))) + ce_ppl.append(np.exp(total_loss / word_count)) + total_loss = 0.0 + word_count = 0.0 + + end_time = time.time() + epoch_time = end_time - start_time + ce_time.append(epoch_time) + print( + "\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n" + % (epoch_id, epoch_time, sum(batch_times) / len(batch_times))) + + if not args.profile: + dir_name = os.path.join(args.model_path, + "epoch_" + str(epoch_id)) + print("begin to save", dir_name) + fluid.io.save_params(exe, dir_name, main_program=train_program) + print("save finished") + dev_ppl = eval(valid_data) + print("dev ppl", dev_ppl) + test_ppl = eval(test_data) + print("test ppl", test_ppl) + + if args.enable_ce: + card_num = get_cards() + _ppl = 0 + _time = 0 + try: + _time = ce_time[-1] + _ppl = ce_ppl[-1] + except: + print("ce info error") + print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time)) + print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl)) + + with profile_context(args.profile): + train() + + +def get_cards(): + num = 0 + cards = os.environ.get('CUDA_VISIBLE_DEVICES', '') + if cards != '': + num = len(cards.split(",")) + return num + + +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) + + +if __name__ == '__main__': + check_version() + main() diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/README.md b/PaddleNLP/PaddleTextGEN/variational_seq2seq/README.md new file mode 100644 index 00000000..e367496c --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/README.md @@ -0,0 +1,107 @@ +运行本目录下的范例模型需要安装PaddlePaddle Fluid 1.6版。如果您的 PaddlePaddle 安装版本低于此要求,请按照[安装文档](https://www.paddlepaddle.org.cn/#quick-start)中的说明更新 PaddlePaddle 安装版本。 + +# Variational Autoencoder (VAE) for Text Generation + +以下是本范例模型的简要目录结构及说明: + +```text +. +├── README.md # 文档,本文件 +├── args.py # 训练、预测以及模型参数配置程序 +├── reader.py # 数据读入程序 +├── download.py # 数据下载程序 +├── train.py # 训练主程序 +├── infer.py # 预测主程序 +├── run.sh # 默认配置的启动脚本 +├── infer.sh # 默认配置的解码脚本 +└── model.py # VAE模型配置程序 + +``` + +## 简介 +本目录下此范例模型的实现,旨在展示如何用Paddle Fluid的 **新Seq2Seq API** 构建用于文本生成的VAE示例,其中LSTM作为编码器和解码器。 分别对官方PTB数据和SWDA数据进行培训。 + +关于VAE的详细介绍参照: [(Bowman et al., 2015) Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf) + +## 数据介绍 + +本教程使用了两个文本数据集: + +PTB dataset,原始下载地址为: http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz。 + +SWDA dataset,来源于[Knowledge-Guided CVAE for dialog generation](https://arxiv.org/pdf/1703.10960.pdf),原始数据集下载地址为:https://github.com/snakeztc/NeuralDialog-CVAE ,感谢作者@[snakeztc](https://github.com/snakeztc)。我们过滤了数据集中长度小于5的短句子。 + +### 数据获取 + +``` +python download.py --task ptb/swda +``` + +## 模型训练 + +`run.sh`包含训练程序的主函数,要使用默认参数开始训练,只需要简单地执行: + +``` +sh run.sh ptb/swda +``` + +如果需要修改模型的参数设置,也可以通过下面命令配置: + +``` +python train.py \ + --vocab_size 10003 \ + --batch_size 32 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --dataset_prefix data/${dataset}/${dataset} \ + --model_path ${dataset}_model\ + --use_gpu True \ + --max_epoch 50 \ +``` + +训练程序采用了 Early Stopping,会在每个epoch根据ppl的表现来决定是否保存模型。 + +## 模型预测 + +当模型训练完成之后, 可以利用infer.sh的脚本进行预测,选择加载模型保存目录下的第 k 个epoch的模型进行预测,生成batch_size条短文本。 + +``` +sh infer.sh ptb/swda k +``` + +如果需要修改模型预测输出的参数设置,也可以通过下面命令配置: + +``` +python infer.py \ + --vocab_size 10003 \ + --batch_size 32 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --dataset_prefix data/${dataset}/${dataset} \ + --use_gpu True \ + --reload_model ${dataset}_model/epoch_${k} \ +``` + +## 效果评价 + +```sh +PTB数据集: +Test PPL: 102.24 +Test NLL: 108.22 + +SWDA数据集: +Test PPL: 64.21 +Test NLL: 81.92 +``` + +## 生成样例 + +the movie are discovered in the u.s. industry that on aircraft variations for a aircraft that was repaired + +the percentage of treasury bonds rose to N N at N N up N N from the two days N N and a premium over N + +he could n't plunge as a factor that attention now has n't picked up for the state according to mexico + +take the remark we need to do then support for the market to tell it + +i think that it believes the core of the company in its first quarter of heavy mid-october to fuel earnings after prices on friday diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/__init__.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/args.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/args.py new file mode 100644 index 00000000..a19f65cb --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/args.py @@ -0,0 +1,163 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import distutils.util + + +def parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--dataset_prefix", type=str, help="file prefix for train data") + + parser.add_argument( + "--optimizer", + type=str, + default='adam', + help="optimizer to use, only supprt[sgd|adam]") + + parser.add_argument( + "--learning_rate", + type=float, + default=0.001, + help="learning rate for optimizer") + + parser.add_argument( + "--num_layers", + type=int, + default=1, + help="layers number of encoder and decoder") + parser.add_argument( + "--hidden_size", + type=int, + default=256, + help="hidden size of encoder and decoder") + parser.add_argument("--vocab_size", type=int, help="source vocab size") + + parser.add_argument( + "--batch_size", type=int, help="batch size of each step") + + parser.add_argument( + "--max_epoch", type=int, default=20, help="max epoch for the training") + + parser.add_argument( + "--max_len", + type=int, + default=1280, + help="max length for source and target sentence") + parser.add_argument( + "--dec_dropout_in", + type=float, + default=0.5, + help="decoder input drop probability") + parser.add_argument( + "--dec_dropout_out", + type=float, + default=0.5, + help="decoder output drop probability") + parser.add_argument( + "--enc_dropout_in", + type=float, + default=0., + help="encoder input drop probability") + parser.add_argument( + "--enc_dropout_out", + type=float, + default=0., + help="encoder output drop probability") + parser.add_argument( + "--word_keep_prob", + type=float, + default=0.5, + help="word keep probability") + parser.add_argument( + "--init_scale", + type=float, + default=0.0, + help="init scale for parameter") + parser.add_argument( + "--max_grad_norm", + type=float, + default=5.0, + help="max grad norm for global norm clip") + + parser.add_argument( + "--model_path", + type=str, + default='model', + help="model path for model to save") + + parser.add_argument( + "--reload_model", type=str, help="reload model to inference") + + parser.add_argument( + "--infer_output_file", + type=str, + default='infer_output.txt', + help="file name for inference output") + + parser.add_argument( + "--beam_size", type=int, default=10, help="file name for inference") + + parser.add_argument( + '--use_gpu', + type=eval, + default=False, + help='Whether using gpu [True|False]') + + parser.add_argument( + "--enable_ce", + action='store_true', + help="The flag indicating whether to run the task " + "for continuous evaluation.") + + parser.add_argument( + "--profile", action='store_true', help="Whether enable the profile.") + + parser.add_argument( + "--warm_up", + type=int, + default=10, + help='number of warm up epochs for KL') + + parser.add_argument( + "--kl_start", type=float, default=0.1, help='KL start value, upto 1.0') + + parser.add_argument( + "--attr_init", + type=str, + default='normal_initializer', + help="initializer for paramters") + + parser.add_argument( + "--cache_num", type=int, default=1, help='cache num for reader') + + parser.add_argument( + "--max_decay", + type=int, + default=5, + help='max decay tries (if exceeds, early stop)') + + parser.add_argument( + "--sort_cache", + action='store_true', + help='sort cache before batch to accelerate training') + + args = parser.parse_args() + return args diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/download.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/download.py new file mode 100644 index 00000000..ba4acacd --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/download.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +''' +Script for downloading training data. +''' +import os +import sys +import shutil +import argparse +import tempfile +import urllib +import tarfile +import io +if sys.version_info >= (3, 0): + import urllib.request +import zipfile + +URLLIB = urllib +if sys.version_info >= (3, 0): + URLLIB = urllib.request + +TASKS = ['ptb', 'swda'] +TASK2PATH = { + 'ptb': 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz', + 'swda': 'https://baidu-nlp.bj.bcebos.com/TextGen/swda.tar.gz' +} + + +def un_tar(tar_name, dir_name): + try: + t = tarfile.open(tar_name) + t.extractall(path=dir_name) + return True + except Exception as e: + print(e) + return False + + +def download_and_extract(task, data_path): + print('Downloading and extracting %s...' % task) + data_file = os.path.join(data_path, TASK2PATH[task].split('/')[-1]) + URLLIB.urlretrieve(TASK2PATH[task], data_file) + un_tar(data_file, data_path) + os.remove(data_file) + if task == 'ptb': + src_dir = os.path.join(data_path, 'simple-examples') + dst_dir = os.path.join(data_path, 'ptb') + if not os.path.exists(dst_dir): + os.mkdir(dst_dir) + shutil.copy(os.path.join(src_dir, 'data/ptb.train.txt'), dst_dir) + shutil.copy(os.path.join(src_dir, 'data/ptb.valid.txt'), dst_dir) + shutil.copy(os.path.join(src_dir, 'data/ptb.test.txt'), dst_dir) + shutil.rmtree(src_dir) + print('\tCompleted!') + + +def main(arguments): + parser = argparse.ArgumentParser() + parser.add_argument( + '-d', + '--data_dir', + help='directory to save data to', + type=str, + default='data') + parser.add_argument( + '-t', + '--task', + help='tasks to download data for as a comma separated string', + type=str, + default='ptb') + args = parser.parse_args(arguments) + + if not os.path.isdir(args.data_dir): + os.mkdir(args.data_dir) + + download_and_extract(args.task, args.data_dir) + + +if __name__ == '__main__': + sys.exit(main(sys.argv[1:])) diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.py new file mode 100644 index 00000000..c21fff3a --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import time +import os +import random +import io +import math + +import paddle +import paddle.fluid as fluid +import paddle.fluid.framework as framework +from paddle.fluid.executor import Executor + +import reader + +import sys +line_tok = '\n' +space_tok = ' ' +if sys.version[0] == '2': + reload(sys) + sys.setdefaultencoding("utf-8") + line_tok = u'\n' + space_tok = u' ' + +import os + +from args import * +#from . import lm_model +import logging +import pickle + +from model import VAE +from reader import BOS_ID, EOS_ID, get_vocab + + +def infer(): + args = parse_args() + + num_layers = args.num_layers + src_vocab_size = args.vocab_size + tar_vocab_size = args.vocab_size + batch_size = args.batch_size + init_scale = args.init_scale + max_grad_norm = args.max_grad_norm + hidden_size = args.hidden_size + attr_init = args.attr_init + latent_size = 32 + + if args.enable_ce: + fluid.default_main_program().random_seed = 102 + framework.default_startup_program().random_seed = 102 + + model = VAE(hidden_size, + latent_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=num_layers, + init_scale=init_scale, + attr_init=attr_init) + + beam_size = args.beam_size + trans_res = model.build_graph(mode='sampling', beam_size=beam_size) + # clone from default main program and use it as the validation program + main_program = fluid.default_main_program() + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = Executor(place) + exe.run(framework.default_startup_program()) + + dir_name = args.reload_model + print("dir name", dir_name) + fluid.io.load_params(exe, dir_name) + vocab, tar_id2vocab = get_vocab(args.dataset_prefix) + infer_output = np.ones((batch_size, 1), dtype='int64') * BOS_ID + + fetch_outs = exe.run(feed={'tar': infer_output}, + fetch_list=[trans_res.name], + use_program_cache=False) + + with io.open(args.infer_output_file, 'w', encoding='utf-8') as out_file: + + for line in fetch_outs[0]: + end_id = -1 + if EOS_ID in line: + end_id = np.where(line == EOS_ID)[0][0] + new_line = [tar_id2vocab[e[0]] for e in line[1:end_id]] + out_file.write(space_tok.join(new_line)) + out_file.write(line_tok) + + +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) + + +if __name__ == '__main__': + check_version() + infer() diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.sh b/PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.sh new file mode 100644 index 00000000..17940a04 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/infer.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -x +export CUDA_VISIBLE_DEVICES=0 +dataset=$1 +k=$2 +python infer.py \ + --vocab_size 10003 \ + --batch_size 32 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --dataset_prefix data/${dataset}/${dataset} \ + --use_gpu True \ + --reload_model ${dataset}_model/epoch_${k} \ + + diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/model.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/model.py new file mode 100644 index 00000000..a582ff9e --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/model.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle.fluid.layers as layers +import paddle.fluid as fluid +import numpy as np +from paddle.fluid import ParamAttr +from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode +from paddle.fluid.contrib.layers import basic_lstm, BasicLSTMUnit + +from reader import BOS_ID, EOS_ID + +INF = 1. * 1e5 +alpha = 0.6 +normal_initializer = lambda x: fluid.initializer.NormalInitializer(loc=0., scale=x**-0.5) +uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x) +zero_constant = fluid.initializer.Constant(0.0) + + +class EncoderCell(RNNCell): + def __init__(self, + num_layers, + hidden_size, + param_attr_initializer, + param_attr_scale, + dropout_prob=0.): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + self.lstm_cells = [] + + for i in range(num_layers): + lstm_name = "enc_layers_" + str(i) + self.lstm_cells.append( + LSTMCell( + hidden_size, forget_bias=0., name=lstm_name)) + + def call(self, step_input, states): + new_states = [] + for i in range(self.num_layers): + out, new_state = self.lstm_cells[i](step_input, states[i]) + step_input = layers.dropout( + out, + self.dropout_prob, + dropout_implementation="upscale_in_train" + ) if self.dropout_prob > 0 else out + new_states.append(new_state) + return step_input, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.lstm_cells] + + +class DecoderCell(RNNCell): + def __init__(self, + num_layers, + hidden_size, + latent_z, + param_attr_initializer, + param_attr_scale, + dropout_prob=0.): + self.num_layers = num_layers + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + self.latent_z = latent_z + self.lstm_cells = [] + + param_attr = ParamAttr( + initializer=param_attr_initializer(param_attr_scale)) + + for i in range(num_layers): + lstm_name = "dec_layers_" + str(i) + self.lstm_cells.append( + LSTMCell( + hidden_size, param_attr, forget_bias=0., name=lstm_name)) + + def call(self, step_input, states): + lstm_states = states + new_lstm_states = [] + step_input = layers.concat([step_input, self.latent_z], 1) + for i in range(self.num_layers): + out, lstm_state = self.lstm_cells[i](step_input, lstm_states[i]) + step_input = layers.dropout( + out, + self.dropout_prob, + dropout_implementation="upscale_in_train" + ) if self.dropout_prob > 0 else out + new_lstm_states.append(lstm_state) + return step_input, new_lstm_states + + +class VAE(object): + def __init__(self, + hidden_size, + latent_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=1, + init_scale=0.1, + dec_dropout_in=0.5, + dec_dropout_out=0.5, + enc_dropout_in=0., + enc_dropout_out=0., + word_keep_prob=0.5, + batch_first=True, + attr_init="normal_initializer"): + + self.hidden_size = hidden_size + self.latent_size = latent_size + self.src_vocab_size = src_vocab_size + self.tar_vocab_size = tar_vocab_size + self.batch_size = batch_size + self.num_layers = num_layers + self.init_scale = init_scale + self.dec_dropout_in = dec_dropout_in + self.dec_dropout_out = dec_dropout_out + self.enc_dropout_in = enc_dropout_in + self.enc_dropout_out = enc_dropout_out + self.word_keep_prob = word_keep_prob + self.batch_first = batch_first + + if attr_init == "normal_initializer": + self.param_attr_initializer = normal_initializer + self.param_attr_scale = hidden_size + elif attr_init == "uniform_initializer": + self.param_attr_initializer = uniform_initializer + self.param_attr_scale = init_scale + else: + raise TypeError("The type of 'attr_initializer' is not supported") + + self.src_embeder = lambda x: fluid.embedding( + input=x, + size=[self.src_vocab_size, self.hidden_size], + dtype='float32', + is_sparse=False, + param_attr=fluid.ParamAttr( + name='source_embedding', + initializer=self.param_attr_initializer(self.param_attr_scale))) + + self.tar_embeder = lambda x: fluid.embedding( + input=x, + size=[self.tar_vocab_size, self.hidden_size], + dtype='float32', + is_sparse=False, + param_attr=fluid.ParamAttr( + name='target_embedding', + initializer=self.param_attr_initializer(self.param_attr_scale))) + + def _build_data(self, mode='train'): + if mode == 'train': + self.src = fluid.data(name="src", shape=[None, None], dtype='int64') + self.src_sequence_length = fluid.data( + name="src_sequence_length", shape=[None], dtype='int32') + self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64') + self.tar_sequence_length = fluid.data( + name="tar_sequence_length", shape=[None], dtype='int32') + self.label = fluid.data( + name="label", shape=[None, None, 1], dtype='int64') + self.kl_weight = fluid.data( + name='kl_weight', shape=[1], dtype='float32') + else: + self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64') + + def _emebdding(self, mode='train'): + if mode == 'train': + self.src_emb = self.src_embeder(self.src) + self.tar_emb = self.tar_embeder(self.tar) + + def _build_encoder(self): + self.enc_input = layers.dropout( + self.src_emb, + dropout_prob=self.enc_dropout_in, + dropout_implementation="upscale_in_train") + enc_cell = EncoderCell(self.num_layers, self.hidden_size, + self.param_attr_initializer, + self.param_attr_scale, self.enc_dropout_out) + enc_output, enc_final_state = rnn( + cell=enc_cell, + inputs=self.enc_input, + sequence_length=self.src_sequence_length) + return enc_output, enc_final_state + + def _build_distribution(self, enc_final_state=None): + enc_hidden = [ + layers.concat( + state, axis=-1) for state in enc_final_state + ] + enc_hidden = layers.concat(enc_hidden, axis=-1) + z_mean_log_var = layers.fc(input=enc_hidden, + size=self.latent_size * 2, + name='fc_dist') + z_mean, z_log_var = layers.split(z_mean_log_var, 2, -1) + return z_mean, z_log_var + + def _build_decoder(self, + z_mean=None, + z_log_var=None, + enc_output=None, + mode='train', + beam_size=10): + dec_input = layers.dropout( + self.tar_emb, + dropout_prob=self.dec_dropout_in, + dropout_implementation="upscale_in_train") + + # `output_layer` will be used within BeamSearchDecoder + output_layer = lambda x: layers.fc(x, + size=self.tar_vocab_size, + num_flatten_dims=len(x.shape) - 1, + name="output_w") + + # `sample_output_layer` samples an id from the logits distribution instead of argmax(logits) + # it will be used within BeamSearchDecoder + sample_output_layer = lambda x: layers.unsqueeze(layers.one_hot( + layers.unsqueeze( + layers.sampling_id( + layers.softmax( + layers.squeeze(output_layer(x),[1]) + ),dtype='int'), [1]), + depth=self.tar_vocab_size), [1]) + + if mode == 'train': + latent_z = self._sampling(z_mean, z_log_var) + else: + latent_z = layers.gaussian_random_batch_size_like( + self.tar, shape=[-1, self.latent_size]) + dec_first_hidden_cell = layers.fc(latent_z, + 2 * self.hidden_size * + self.num_layers, + name='fc_hc') + dec_first_hidden, dec_first_cell = layers.split(dec_first_hidden_cell, + 2) + if self.num_layers > 1: + dec_first_hidden = layers.split(dec_first_hidden, self.num_layers) + dec_first_cell = layers.split(dec_first_cell, self.num_layers) + else: + dec_first_hidden = [dec_first_hidden] + dec_first_cell = [dec_first_cell] + dec_initial_states = [[h, c] + for h, c in zip(dec_first_hidden, dec_first_cell)] + dec_cell = DecoderCell(self.num_layers, self.hidden_size, latent_z, + self.param_attr_initializer, + self.param_attr_scale, self.dec_dropout_out) + + if mode == 'train': + dec_output, _ = rnn(cell=dec_cell, + inputs=dec_input, + initial_states=dec_initial_states, + sequence_length=self.tar_sequence_length) + dec_output = output_layer(dec_output) + + return dec_output + elif mode == 'greedy': + start_token = 1 + end_token = 2 + max_length = 100 + beam_search_decoder = BeamSearchDecoder( + dec_cell, + start_token, + end_token, + beam_size=1, + embedding_fn=self.tar_embeder, + output_fn=output_layer) + outputs, _ = dynamic_decode( + beam_search_decoder, + inits=dec_initial_states, + max_step_num=max_length) + return outputs + + elif mode == 'sampling': + start_token = 1 + end_token = 2 + max_length = 100 + beam_search_decoder = BeamSearchDecoder( + dec_cell, + start_token, + end_token, + beam_size=1, + embedding_fn=self.tar_embeder, + output_fn=sample_output_layer) + + outputs, _ = dynamic_decode( + beam_search_decoder, + inits=dec_initial_states, + max_step_num=max_length) + return outputs + else: + print("mode not supprt", mode) + + def _sampling(self, z_mean, z_log_var): + """reparameterization trick + """ + # by default, random_normal has mean=0 and std=1.0 + epsilon = layers.gaussian_random_batch_size_like( + self.tar, shape=[-1, self.latent_size]) + epsilon.stop_gradient = True + return z_mean + layers.exp(0.5 * z_log_var) * epsilon + + def _kl_dvg(self, means, logvars): + """compute the KL divergence between Gaussian distribution + """ + kl_cost = -0.5 * (logvars - fluid.layers.square(means) - + fluid.layers.exp(logvars) + 1.0) + kl_cost = fluid.layers.reduce_mean(kl_cost, 0) + + return fluid.layers.reduce_sum(kl_cost) + + def _compute_loss(self, mean, logvars, dec_output): + + kl_loss = self._kl_dvg(mean, logvars) + + rec_loss = layers.softmax_with_cross_entropy( + logits=dec_output, label=self.label, soft_label=False) + + rec_loss = layers.reshape(rec_loss, shape=[self.batch_size, -1]) + + max_tar_seq_len = layers.shape(self.tar)[1] + tar_mask = layers.sequence_mask( + self.tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32') + rec_loss = rec_loss * tar_mask + rec_loss = layers.reduce_mean(rec_loss, dim=[0]) + rec_loss = layers.reduce_sum(rec_loss) + + loss = kl_loss * self.kl_weight + rec_loss + + return loss, kl_loss, rec_loss + + def _beam_search(self, enc_last_hidden, enc_last_cell): + pass + + def build_graph(self, mode='train', beam_size=10): + if mode == 'train' or mode == 'eval': + self._build_data() + self._emebdding() + enc_output, enc_final_state = self._build_encoder() + z_mean, z_log_var = self._build_distribution(enc_final_state) + dec_output = self._build_decoder(z_mean, z_log_var, enc_output) + + loss, kl_loss, rec_loss = self._compute_loss(z_mean, z_log_var, + dec_output) + return loss, kl_loss, rec_loss + + elif mode == "sampling" or mode == 'greedy': + self._build_data(mode) + self._emebdding(mode) + dec_output = self._build_decoder(mode=mode, beam_size=1) + + return dec_output + else: + print("not support mode ", mode) + raise Exception("not support mode: " + mode) diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/reader.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/reader.py new file mode 100644 index 00000000..86bb6afd --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/reader.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for parsing PTB text files.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import sys +import numpy as np +import io +from collections import Counter +Py3 = sys.version_info[0] == 3 + +if Py3: + line_tok = '\n' +else: + line_tok = u'\n' + +PAD_ID = 0 +BOS_ID = 1 +EOS_ID = 2 +UNK_ID = 3 + + +def _read_words(filename): + data = [] + with io.open(filename, "r", encoding='utf-8') as f: + if Py3: + return f.read().replace("\n", "").split() + else: + return f.read().decode("utf-8").replace(u"\n", u"").split() + + +def read_all_line(filename): + data = [] + with io.open(filename, "r", encoding='utf-8') as f: + for line in f.readlines(): + data.append(line.strip()) + return data + + +def _vocab(vocab_file, train_file, max_vocab_cnt): + lines = read_all_line(train_file) + + all_words = [] + for line in lines: + all_words.extend(line.split()) + vocab_count = Counter(all_words).most_common() + raw_vocab_size = min(len(vocab_count), max_vocab_cnt) + with io.open(vocab_file, "w", encoding='utf-8') as f: + for voc, fre in vocab_count[0:max_vocab_cnt]: + f.write(voc) + f.write(line_tok) + + +def _build_vocab(vocab_file, train_file=None, max_vocab_cnt=-1): + if not os.path.exists(vocab_file): + _vocab(vocab_file, train_file, max_vocab_cnt) + vocab_dict = {"": 0, "": 1, "": 2, "": 3} + ids = 4 + with io.open(vocab_file, "r", encoding='utf-8') as f: + for line in f.readlines(): + vocab_dict[line.strip()] = ids + ids += 1 + # rev_vocab = {value:key for key, value in vocab_dict.items()} + print("vocab word num", ids) + + return vocab_dict + + +def _para_file_to_ids(src_file, src_vocab): + + src_data = [] + with io.open(src_file, "r", encoding='utf-8') as f_src: + for line in f_src.readlines(): + arra = line.strip().split() + ids = [BOS_ID] + ids.extend( + [src_vocab[w] if w in src_vocab else UNK_ID for w in arra]) + ids.append(EOS_ID) + src_data.append(ids) + + return src_data + + +def filter_len(src, max_sequence_len=128): + new_src = [] + + for id1 in src: + if len(id1) > max_sequence_len: + id1 = id1[:max_sequence_len] + + new_src.append(id1) + + return new_src + + +def raw_data(dataset_prefix, max_sequence_len=50, max_vocab_cnt=-1): + + src_vocab_file = dataset_prefix + ".vocab.txt" + src_train_file = dataset_prefix + ".train.txt" + src_eval_file = dataset_prefix + ".valid.txt" + src_test_file = dataset_prefix + ".test.txt" + + src_vocab = _build_vocab(src_vocab_file, src_train_file, max_vocab_cnt) + + train_src = _para_file_to_ids(src_train_file, src_vocab) + train_src = filter_len(train_src, max_sequence_len=max_sequence_len) + eval_src = _para_file_to_ids(src_eval_file, src_vocab) + + test_src = _para_file_to_ids(src_test_file, src_vocab) + + return train_src, eval_src, test_src, src_vocab + + +def get_vocab(dataset_prefix, max_sequence_len=50): + src_vocab_file = dataset_prefix + ".vocab.txt" + src_vocab = _build_vocab(src_vocab_file) + rev_vocab = {} + for key, value in src_vocab.items(): + rev_vocab[value] = key + + return src_vocab, rev_vocab + + +def raw_mono_data(vocab_file, file_path): + + src_vocab = _build_vocab(vocab_file) + test_src, test_tar = _para_file_to_ids( file_path, file_path, \ + src_vocab, src_vocab ) + + return (test_src, test_tar) + + +def get_data_iter(raw_data, + batch_size, + sort_cache=False, + cache_num=1, + mode='train', + enable_ce=False): + + src_data = raw_data + + data_len = len(src_data) + + index = np.arange(data_len) + if mode == "train" and not enable_ce: + np.random.shuffle(index) + + def to_pad_np(data): + max_len = 0 + for ele in data: + if len(ele) > max_len: + max_len = len(ele) + + ids = np.ones( + (batch_size, max_len), dtype='int64') * PAD_ID # PAD_ID = 0 + mask = np.zeros((batch_size), dtype='int32') + + for i, ele in enumerate(data): + ids[i, :len(ele)] = ele + mask[i] = len(ele) + + return ids, mask + + b_src = [] + + if mode != "train": + cache_num = 1 + for j in range(data_len): + if len(b_src) == batch_size * cache_num: + if sort_cache: + new_cache = sorted(b_src, key=lambda k: len(k)) + new_cache = b_src + for i in range(cache_num): + batch_data = new_cache[i * batch_size:(i + 1) * batch_size] + src_ids, src_mask = to_pad_np(batch_data) + yield (src_ids, src_mask) + + b_src = [] + + b_src.append(src_data[index[j]]) + + if len(b_src) > 0: + if sort_cache: + new_cache = sorted(b_src, key=lambda k: len(k)) + new_cache = b_src + for i in range(0, len(b_src), batch_size): + end_index = min((i + 1) * batch_size, len(b_src)) + batch_data = new_cache[i * batch_size:end_index] + src_ids, src_mask = to_pad_np(batch_data) + yield (src_ids, src_mask) diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/run.sh b/PaddleNLP/PaddleTextGEN/variational_seq2seq/run.sh new file mode 100644 index 00000000..af0db8a1 --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/run.sh @@ -0,0 +1,15 @@ +#!/bin/bash +set -x +export CUDA_VISIBLE_DEVICES=0 +dataset=$1 +python train.py \ + --vocab_size 10003 \ + --batch_size 32 \ + --init_scale 0.1 \ + --max_grad_norm 5.0 \ + --dataset_prefix data/${dataset}/${dataset} \ + --model_path ${dataset}_model\ + --use_gpu True \ + --max_epoch 200 \ + + diff --git a/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py b/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py new file mode 100644 index 00000000..08b9e1ae --- /dev/null +++ b/PaddleNLP/PaddleTextGEN/variational_seq2seq/train.py @@ -0,0 +1,313 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import time +import os +import random +import math +import contextlib + +import paddle +import paddle.fluid as fluid +import paddle.fluid.framework as framework +import paddle.fluid.profiler as profiler +from paddle.fluid.executor import Executor +import reader + +import sys +if sys.version[0] == '2': + reload(sys) + sys.setdefaultencoding("utf-8") +import os + +from args import * +from model import VAE +import logging +import pickle + + +@contextlib.contextmanager +def profile_context(profile=True): + if profile: + with profiler.profiler('All', 'total', 'seq2seq.profile'): + yield + else: + yield + + +def main(): + args = parse_args() + print(args) + num_layers = args.num_layers + src_vocab_size = args.vocab_size + tar_vocab_size = args.vocab_size + batch_size = args.batch_size + init_scale = args.init_scale + max_grad_norm = args.max_grad_norm + hidden_size = args.hidden_size + attr_init = args.attr_init + latent_size = 32 + + main_program = fluid.Program() + startup_program = fluid.Program() + if args.enable_ce: + fluid.default_main_program().random_seed = 123 + framework.default_startup_program().random_seed = 123 + + # Training process + with fluid.program_guard(main_program, startup_program): + with fluid.unique_name.guard(): + model = VAE(hidden_size, + latent_size, + src_vocab_size, + tar_vocab_size, + batch_size, + num_layers=num_layers, + init_scale=init_scale, + attr_init=attr_init) + + loss, kl_loss, rec_loss = model.build_graph() + # clone from default main program and use it as the validation program + main_program = fluid.default_main_program() + inference_program = fluid.default_main_program().clone( + for_test=True) + + fluid.clip.set_gradient_clip( + clip=fluid.clip.GradientClipByGlobalNorm( + clip_norm=max_grad_norm)) + + learning_rate = fluid.layers.create_global_var( + name="learning_rate", + shape=[1], + value=float(args.learning_rate), + dtype="float32", + persistable=True) + + opt_type = args.optimizer + if opt_type == "sgd": + optimizer = fluid.optimizer.SGD(learning_rate) + elif opt_type == "adam": + optimizer = fluid.optimizer.Adam(learning_rate) + else: + print("only support [sgd|adam]") + raise Exception("opt type not support") + + optimizer.minimize(loss) + + place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() + exe = Executor(place) + exe.run(startup_program) + + train_program = fluid.compiler.CompiledProgram(main_program) + + dataset_prefix = args.dataset_prefix + print("begin to load data") + raw_data = reader.raw_data(dataset_prefix, args.max_len) + print("finished load data") + train_data, valid_data, test_data, _ = raw_data + + anneal_r = 1.0 / (args.warm_up * len(train_data) / args.batch_size) + + def prepare_input(batch, kl_weight=1.0, lr=None): + src_ids, src_mask = batch + res = {} + src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1])) + in_tar = src_ids[:, :-1] + label_tar = src_ids[:, 1:] + + in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1])) + label_tar = label_tar.reshape( + (label_tar.shape[0], label_tar.shape[1], 1)) + + res['src'] = src_ids + res['tar'] = in_tar + res['label'] = label_tar + res['src_sequence_length'] = src_mask + res['tar_sequence_length'] = src_mask - 1 + res['kl_weight'] = np.array([kl_weight]).astype(np.float32) + if lr is not None: + res['learning_rate'] = np.array([lr]).astype(np.float32) + + return res, np.sum(src_mask), np.sum(src_mask - 1) + + # get train epoch size + def eval(data): + eval_data_iter = reader.get_data_iter(data, batch_size, mode='eval') + total_loss = 0.0 + word_count = 0.0 + batch_count = 0.0 + for batch_id, batch in enumerate(eval_data_iter): + input_data_feed, src_word_num, dec_word_sum = prepare_input(batch) + fetch_outs = exe.run(inference_program, + feed=input_data_feed, + fetch_list=[loss.name], + use_program_cache=False) + + cost_train = np.array(fetch_outs[0]) + + total_loss += cost_train * batch_size + word_count += dec_word_sum + batch_count += batch_size + + nll = total_loss / batch_count + ppl = np.exp(total_loss / word_count) + + return nll, ppl + + def train(): + ce_time = [] + ce_ppl = [] + max_epoch = args.max_epoch + kl_w = args.kl_start + lr_w = args.learning_rate + best_valid_nll = 1e100 # +inf + best_epoch_id = -1 + decay_cnt = 0 + max_decay = args.max_decay + decay_factor = 0.5 + decay_ts = 2 + steps_not_improved = 0 + for epoch_id in range(max_epoch): + start_time = time.time() + if args.enable_ce: + train_data_iter = reader.get_data_iter( + train_data, + batch_size, + args.sort_cache, + args.cache_num, + enable_ce=True) + else: + train_data_iter = reader.get_data_iter( + train_data, batch_size, args.sort_cache, args.cache_num) + + total_loss = 0 + total_rec_loss = 0 + total_kl_loss = 0 + word_count = 0.0 + batch_count = 0.0 + batch_times = [] + batch_start_time = time.time() + for batch_id, batch in enumerate(train_data_iter): + kl_w = min(1.0, kl_w + anneal_r) + kl_weight = kl_w + input_data_feed, src_word_num, dec_word_sum = prepare_input( + batch, kl_weight, lr_w) + fetch_outs = exe.run( + program=train_program, + feed=input_data_feed, + fetch_list=[loss.name, kl_loss.name, rec_loss.name], + use_program_cache=False) + + cost_train = np.array(fetch_outs[0]) + kl_cost_train = np.array(fetch_outs[1]) + rec_cost_train = np.array(fetch_outs[2]) + + total_loss += cost_train * batch_size + total_rec_loss += rec_cost_train * batch_size + total_kl_loss += kl_cost_train * batch_size + word_count += dec_word_sum + batch_count += batch_size + batch_end_time = time.time() + batch_time = batch_end_time - batch_start_time + batch_times.append(batch_time) + + if batch_id > 0 and batch_id % 200 == 0: + print("-- Epoch:[%d]; Batch:[%d]; Time: %.4f s; " + "kl_weight: %.4f; kl_loss: %.4f; rec_loss: %.4f; " + "nll: %.4f; ppl: %.4f" % + (epoch_id, batch_id, batch_time, kl_w, total_kl_loss / + batch_count, total_rec_loss / batch_count, total_loss + / batch_count, np.exp(total_loss / word_count))) + ce_ppl.append(np.exp(total_loss / word_count)) + + end_time = time.time() + epoch_time = end_time - start_time + ce_time.append(epoch_time) + print( + "\nTrain epoch:[%d]; Epoch Time: %.4f; avg_time: %.4f s/step\n" + % (epoch_id, epoch_time, sum(batch_times) / len(batch_times))) + + val_nll, val_ppl = eval(valid_data) + print("dev ppl", val_ppl) + test_nll, test_ppl = eval(test_data) + print("test ppl", test_ppl) + + if val_nll < best_valid_nll: + best_valid_nll = val_nll + steps_not_improved = 0 + best_nll = test_nll + best_ppl = test_ppl + best_epoch_id = epoch_id + dir_name = os.path.join(args.model_path, + "epoch_" + str(best_epoch_id)) + print("save model {}".format(dir_name)) + fluid.io.save_params(exe, dir_name, main_program) + else: + steps_not_improved += 1 + if steps_not_improved == decay_ts: + old_lr = lr_w + lr_w *= decay_factor + steps_not_improved = 0 + new_lr = lr_w + + print('-----\nchange lr, old lr: %f, new lr: %f\n-----' % + (old_lr, new_lr)) + + dir_name = args.model_path + "/epoch_" + str(best_epoch_id) + fluid.io.load_params(exe, dir_name) + + decay_cnt += 1 + if decay_cnt == max_decay: + break + + print('\nbest testing nll: %.4f, best testing ppl %.4f\n' % + (best_nll, best_ppl)) + + with profile_context(args.profile): + train() + + +def get_cards(): + num = 0 + cards = os.environ.get('CUDA_VISIBLE_DEVICES', '') + if cards != '': + num = len(cards.split(",")) + return num + + +def check_version(): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version 1.6 or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + + try: + fluid.require_version('1.6.0') + except Exception as e: + logger.error(err) + sys.exit(1) + + +if __name__ == '__main__': + check_version() + main() -- GitLab