From 0442a432aeef59597719d05603dc7ac0e8623836 Mon Sep 17 00:00:00 2001
From: Xing Wu <1160386409@qq.com>
Date: Tue, 22 Oct 2019 21:44:59 +0800
Subject: [PATCH] add new seq2seq example: rnn_search (#3574)
* re-commit seq2seq example: rnn search
* add new_seq2seq examples: rnn_search and vae_text
* modify dataset description in vae_text/README.md
* modify vae_text/README.md and vae_text/download.py
* update vae_text/infer.sh and vae_text/README.md
* add multi-gpu support for rnn_search; fix windows support for rnn_search and vae_text
* remove parallel argument in args.py
* remove hard codes
* remove hard codes
* update download.py for windows download
* fix model path for windows in vae_text/infer.py
* fix python2 utf-encoding
---
PaddleNLP/PaddleTextGEN/rnn_search/README.md | 129 ++++++
.../PaddleTextGEN/rnn_search/__init__.py | 0
PaddleNLP/PaddleTextGEN/rnn_search/args.py | 127 ++++++
.../rnn_search/attention_model.py | 156 ++++++++
.../PaddleTextGEN/rnn_search/base_model.py | 226 +++++++++++
.../PaddleTextGEN/rnn_search/download.py | 55 +++
PaddleNLP/PaddleTextGEN/rnn_search/infer.py | 184 +++++++++
PaddleNLP/PaddleTextGEN/rnn_search/infer.sh | 22 ++
PaddleNLP/PaddleTextGEN/rnn_search/reader.py | 211 ++++++++++
PaddleNLP/PaddleTextGEN/rnn_search/run.sh | 20 +
PaddleNLP/PaddleTextGEN/rnn_search/train.py | 277 +++++++++++++
PaddleNLP/PaddleTextGEN/vae_text/README.md | 107 +++++
PaddleNLP/PaddleTextGEN/vae_text/__init__.py | 0
PaddleNLP/PaddleTextGEN/vae_text/args.py | 163 ++++++++
PaddleNLP/PaddleTextGEN/vae_text/download.py | 92 +++++
PaddleNLP/PaddleTextGEN/vae_text/infer.py | 128 ++++++
PaddleNLP/PaddleTextGEN/vae_text/infer.sh | 16 +
PaddleNLP/PaddleTextGEN/vae_text/model.py | 369 ++++++++++++++++++
PaddleNLP/PaddleTextGEN/vae_text/reader.py | 206 ++++++++++
PaddleNLP/PaddleTextGEN/vae_text/run.sh | 15 +
PaddleNLP/PaddleTextGEN/vae_text/train.py | 313 +++++++++++++++
21 files changed, 2816 insertions(+)
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/README.md
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/__init__.py
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/args.py
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/attention_model.py
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/base_model.py
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/download.py
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/infer.py
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/infer.sh
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/reader.py
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/run.sh
create mode 100644 PaddleNLP/PaddleTextGEN/rnn_search/train.py
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/README.md
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/__init__.py
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/args.py
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/download.py
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/infer.py
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/infer.sh
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/model.py
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/reader.py
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/run.sh
create mode 100644 PaddleNLP/PaddleTextGEN/vae_text/train.py
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/README.md b/PaddleNLP/PaddleTextGEN/rnn_search/README.md
new file mode 100644
index 00000000..9da6ef3e
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/README.md
@@ -0,0 +1,129 @@
+The example model in this directory requires PaddlePaddle Fluid 1.6. If your installed version of PaddlePaddle is older than this, please follow the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start) to upgrade.
+
+# Machine Translation: RNN Search
+
+The directory structure of this example model is as follows:
+
+```
+.
+├── README.md          # Documentation (this file)
+├── args.py            # Training, inference, and model hyper-parameter configuration
+├── reader.py          # Data reading utilities
+├── download.py        # Data download script
+├── train.py           # Main training program
+├── infer.py           # Main inference program
+├── run.sh             # Launch script with the default training configuration
+├── infer.sh           # Decoding script with the default configuration
+├── attention_model.py # Translation model with attention
+└── base_model.py      # Translation model without attention
+```
+
+## Introduction
+
+Machine translation (MT) uses computers to translate between natural languages. The language being translated from is called the source language, and the language translated into is called the target language. Machine translation, i.e., the transformation from source language to target language, is an important research area in natural language processing.
+
+In recent years, deep learning has brought continuous breakthroughs to machine translation. End-to-end neural machine translation (NMT) models, which map the source language directly to the target language with a neural network, have gradually become mainstream.
+
+This directory contains two classic translation models: a base model (without attention) and a model with attention. Their performance has since been surpassed by many newer models (such as the [Transformer](https://arxiv.org/abs/1706.03762)). Beyond machine translation, however, this architecture is the foundation of many sequence-to-sequence (Seq2Seq) models, and models for many other NLP problems build on it, so it remains important in NLP and is widely used as a baseline.
+
+The implementation in this directory demonstrates how to use the **new Seq2Seq API** of Paddle Fluid to build an RNN model with attention for Seq2Seq problems, and how to use a decoder with beam search. If you simply need a model with strong translation quality, consider the [Paddle Fluid implementation of the Transformer](https://github.com/PaddlePaddle/models/tree/develop/fluid/neural_machine_translation/transformer).
+
+**The new Seq2Seq API makes network construction simpler, and the low-level API is no longer recommended as of version 1.6. If you do need the low-level API to implement your own model, see the 1.5 release [RNN Search](https://github.com/PaddlePaddle/models/tree/release/1.5/PaddleNLP/unarchived/neural_machine_translation/rnn_search).**
+
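+As a quick orientation before the full code, here is a minimal sketch of how the new API pieces fit together: the same `LSTMCell` is unrolled with `layers.rnn` for encoding and wrapped in a `BeamSearchDecoder` for inference. The vocabulary size (10000), hidden size (512), and token ids below are illustrative, not the values used by this example.
+
+```python
+import paddle.fluid as fluid
+import paddle.fluid.layers as layers
+
+src = fluid.data(name="src", shape=[None, None], dtype="int64")
+src_len = fluid.data(name="src_len", shape=[None], dtype="int32")
+
+embedder = lambda x: fluid.embedding(x, size=[10000, 512])
+
+# encode: run an LSTM cell over the embedded source sequence
+enc_cell = layers.LSTMCell(hidden_size=512)
+enc_output, enc_final_state = layers.rnn(
+    cell=enc_cell, inputs=embedder(src), sequence_length=src_len)
+
+# decode with beam search: wrap a cell in BeamSearchDecoder and unroll it
+dec_cell = layers.LSTMCell(hidden_size=512)
+output_fn = lambda x: layers.fc(
+    x, size=10000, num_flatten_dims=len(x.shape) - 1)
+decoder = layers.BeamSearchDecoder(
+    dec_cell, start_token=1, end_token=2, beam_size=10,
+    embedding_fn=embedder, output_fn=output_fn)
+outputs, _ = layers.dynamic_decode(
+    decoder, inits=enc_final_state, max_step_num=100)
+```
+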
+## Model Overview
+
+The RNN Search model uses the classic encoder-decoder framework to solve Seq2Seq problems. The encoder first encodes the source sequence into a vector, and the decoder then decodes that vector into the target sequence. This mimics how humans perform translation: first parse the source sentence and understand its meaning, then write a target-language sentence expressing that meaning. Both the encoder and the decoder are usually implemented with RNNs. For the underlying theory and equations, see [Deep Learning 101](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/basics/machine_translation/index.html).
+
+In this model, the encoder is a multi-layer LSTM; the decoder is an RNN with an attention mechanism, and a decoder without attention is provided for comparison. For inference, beam search is used to generate the target sentence. These components are described below.
+
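+For intuition, the attention used in `attention_model.py` is a dot-product attention over projected encoder states: score each encoder position against the current decoder state, mask out padding, take a softmax, and return the weighted sum as the context vector. A numpy sketch of a single decoding step (the shapes are illustrative):
+
+```python
+import numpy as np
+
+def dot_attention(query, memory, pad_mask):
+    """query: [batch, hidden]; memory: [batch, src_len, hidden];
+    pad_mask: 0.0 at real tokens, -1.0 at padding (as in the model code)."""
+    scores = np.einsum("bh,bsh->bs", query, memory)  # alignment scores
+    scores = scores + pad_mask * 1e9                 # padded positions -> -inf
+    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
+    weights /= weights.sum(axis=1, keepdims=True)    # softmax over src_len
+    return np.einsum("bs,bsh->bh", weights, memory)  # context vector
+```
+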
+## Data
+
+This tutorial uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) as the training corpus, the tst2012 data as the development set, and the tst2013 data as the test set.
+
+### Downloading the data
+
+```
+python download.py
+```
+
+## Training
+
+`run.sh` wraps the training entry point. To start training with the default configuration, simply run:
+
+```
+sh run.sh
+```
+
+The RNN model with attention is used by default; set `--attention False` to train the model without attention.
+
+```
+python train.py \
+ --src_lang en --tar_lang vi \
+ --attention True \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --train_data_prefix data/en-vi/train \
+ --eval_data_prefix data/en-vi/tst2012 \
+ --test_data_prefix data/en-vi/tst2013 \
+ --vocab_prefix data/en-vi/vocab \
+ --use_gpu True \
+ --model_path ./attention_models
+```
+
+The training program saves the model at the end of every epoch.
+
+## Inference
+
+Once training is complete, run the `infer.sh` script for inference. By default it loads the model from the 10th epoch and decodes the test set with beam search:
+
+```
+sh infer.sh
+```
+
+To run inference on a different data file, simply change the `--infer_file` argument:
+
+```
+python infer.py \
+ --attention True \
+ --src_lang en --tar_lang vi \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --vocab_prefix data/en-vi/vocab \
+ --infer_file data/en-vi/tst2013.en \
+ --reload_model attention_models/epoch_10/ \
+ --infer_output_file attention_infer_output/infer_output.txt \
+ --beam_size 10 \
+ --use_gpu True
+```
+
+## Evaluation
+
+Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) tool to evaluate the translation quality of the model's predictions:
+
+```sh
+mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
+```
+
+Results of a single model with beam_size = 10:
+
+```
+> no attention
+tst2012 BLEU: 10.99
+tst2013 BLEU: 11.23
+
+> with attention
+tst2012 BLEU: 22.85
+tst2013 BLEU: 25.68
+```
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/__init__.py b/PaddleNLP/PaddleTextGEN/rnn_search/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/args.py b/PaddleNLP/PaddleTextGEN/rnn_search/args.py
new file mode 100644
index 00000000..ee056e33
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/args.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import distutils.util
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--train_data_prefix", type=str, help="file prefix for train data")
+ parser.add_argument(
+ "--eval_data_prefix", type=str, help="file prefix for eval data")
+ parser.add_argument(
+ "--test_data_prefix", type=str, help="file prefix for test data")
+ parser.add_argument(
+ "--vocab_prefix", type=str, help="file prefix for vocab")
+ parser.add_argument("--src_lang", type=str, help="source language suffix")
+ parser.add_argument("--tar_lang", type=str, help="target language suffix")
+
+ parser.add_argument(
+ "--attention",
+ type=eval,
+ default=False,
+ help="Whether use attention model")
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default='adam',
+ help="optimizer to use, only supprt[sgd|adam]")
+
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=0.001,
+ help="learning rate for optimizer")
+
+ parser.add_argument(
+ "--num_layers",
+ type=int,
+ default=1,
+ help="layers number of encoder and decoder")
+ parser.add_argument(
+ "--hidden_size",
+ type=int,
+ default=100,
+ help="hidden size of encoder and decoder")
+ parser.add_argument("--src_vocab_size", type=int, help="source vocab size")
+ parser.add_argument("--tar_vocab_size", type=int, help="target vocab size")
+
+ parser.add_argument(
+ "--batch_size", type=int, help="batch size of each step")
+
+ parser.add_argument(
+ "--max_epoch", type=int, default=12, help="max epoch for the training")
+
+ parser.add_argument(
+ "--max_len",
+ type=int,
+ default=50,
+ help="max length for source and target sentence")
+ parser.add_argument(
+ "--dropout", type=float, default=0.0, help="drop probability")
+ parser.add_argument(
+ "--init_scale",
+ type=float,
+ default=0.0,
+ help="init scale for parameter")
+ parser.add_argument(
+ "--max_grad_norm",
+ type=float,
+ default=5.0,
+ help="max grad norm for global norm clip")
+
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default='model',
+ help="model path for model to save")
+
+ parser.add_argument(
+ "--reload_model", type=str, help="reload model to inference")
+
+ parser.add_argument(
+ "--infer_file", type=str, help="file name for inference")
+ parser.add_argument(
+ "--infer_output_file",
+ type=str,
+ default='infer_output',
+ help="file name for inference output")
+ parser.add_argument(
+ "--beam_size", type=int, default=10, help="file name for inference")
+
+ parser.add_argument(
+ '--use_gpu',
+ type=eval,
+ default=False,
+        help='Whether to use gpu [True|False]')
+
+ parser.add_argument(
+ "--enable_ce",
+ action='store_true',
+ help="The flag indicating whether to run the task "
+ "for continuous evaluation.")
+
+ parser.add_argument(
+ "--profile", action='store_true', help="Whether enable the profile.")
+
+ args = parser.parse_args()
+ return args
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/attention_model.py b/PaddleNLP/PaddleTextGEN/rnn_search/attention_model.py
new file mode 100644
index 00000000..4f53aa97
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/attention_model.py
@@ -0,0 +1,156 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid.layers as layers
+import paddle.fluid as fluid
+import numpy as np
+from paddle.fluid import ParamAttr
+from paddle.fluid.contrib.layers import basic_lstm, BasicLSTMUnit
+from base_model import BaseModel, DecoderCell
+from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
+
+INF = 1. * 1e5
+alpha = 0.6
+
+
+class AttentionDecoderCell(DecoderCell):
+ def __init__(self, num_layers, hidden_size, dropout_prob=0.,
+ init_scale=0.1):
+ super(AttentionDecoderCell, self).__init__(num_layers, hidden_size,
+ dropout_prob, init_scale)
+
+ def attention(self, query, enc_output, mask=None):
+ query = layers.unsqueeze(query, [1])
+ memory = layers.fc(enc_output,
+ self.hidden_size,
+ num_flatten_dims=2,
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.UniformInitializer(
+ low=-self.init_scale, high=self.init_scale)),
+ bias_attr=False)
+ attn = layers.matmul(query, memory, transpose_y=True)
+
+        if mask:
+            # mask is 0.0 at real tokens and -1.0 at padding, so adding
+            # mask * 1e9 pushes the logits of padded positions towards -inf
+            attn = layers.transpose(attn, [1, 0, 2])
+            attn = layers.elementwise_add(attn, mask * 1000000000, -1)
+            attn = layers.transpose(attn, [1, 0, 2])
+ weight = layers.softmax(attn)
+ weight_memory = layers.matmul(weight, memory)
+
+ return weight_memory
+
+ def call(self, step_input, states, enc_output, enc_padding_mask=None):
+ lstm_states, input_feed = states
+ new_lstm_states = []
+ step_input = layers.concat([step_input, input_feed], 1)
+ for i in range(self.num_layers):
+ out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
+ step_input = layers.dropout(
+ out,
+ self.dropout_prob,
+ dropout_implementation='upscale_in_train'
+ ) if self.dropout_prob > 0 else out
+ new_lstm_states.append(new_lstm_state)
+ dec_att = self.attention(step_input, enc_output, enc_padding_mask)
+ dec_att = layers.squeeze(dec_att, [1])
+ concat_att_out = layers.concat([dec_att, step_input], 1)
+ out = layers.fc(concat_att_out,
+ self.hidden_size,
+ param_attr=ParamAttr(
+ initializer=fluid.initializer.UniformInitializer(
+ low=-self.init_scale, high=self.init_scale)),
+ bias_attr=False)
+ return out, [new_lstm_states, out]
+
+
+class AttentionModel(BaseModel):
+ def __init__(self,
+ hidden_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=1,
+ init_scale=0.1,
+ dropout=None,
+ beam_start_token=1,
+ beam_end_token=2,
+ beam_max_step_num=100):
+        super(AttentionModel, self).__init__(
+            hidden_size,
+            src_vocab_size,
+            tar_vocab_size,
+            batch_size,
+            num_layers=num_layers,
+            init_scale=init_scale,
+            dropout=dropout,
+            beam_start_token=beam_start_token,
+            beam_end_token=beam_end_token,
+            beam_max_step_num=beam_max_step_num)
+
+ def _build_decoder(self, enc_final_state, mode='train', beam_size=10):
+ output_layer = lambda x: layers.fc(x,
+ size=self.tar_vocab_size,
+ num_flatten_dims=len(x.shape) - 1,
+ param_attr=fluid.ParamAttr(name="output_w",
+ initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale)),
+ bias_attr=False)
+
+ dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size,
+ self.dropout, self.init_scale)
+ dec_initial_states = [
+ enc_final_state, dec_cell.get_initial_states(
+ batch_ref=self.enc_output, shape=[self.hidden_size])
+ ]
+ max_src_seq_len = layers.shape(self.src)[1]
+        src_mask = layers.sequence_mask(
+            self.src_sequence_length, maxlen=max_src_seq_len, dtype='float32')
+        # shift the 0/1 mask to -1/0 so it can be added to attention logits
+        enc_padding_mask = (src_mask - 1.0)
+ if mode == 'train':
+ dec_output, _ = rnn(cell=dec_cell,
+ inputs=self.tar_emb,
+ initial_states=dec_initial_states,
+ sequence_length=None,
+ enc_output=self.enc_output,
+ enc_padding_mask=enc_padding_mask)
+
+ dec_output = output_layer(dec_output)
+
+ elif mode == 'beam_search':
+ output_layer = lambda x: layers.fc(x,
+ size=self.tar_vocab_size,
+ num_flatten_dims=len(x.shape) - 1,
+ param_attr=fluid.ParamAttr(name="output_w"),
+ bias_attr=False)
+ beam_search_decoder = BeamSearchDecoder(
+ dec_cell,
+ self.beam_start_token,
+ self.beam_end_token,
+ beam_size,
+ embedding_fn=self.tar_embeder,
+ output_fn=output_layer)
+ enc_output = beam_search_decoder.tile_beam_merge_with_batch(
+ self.enc_output, beam_size)
+ enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch(
+ enc_padding_mask, beam_size)
+ outputs, _ = dynamic_decode(
+ beam_search_decoder,
+ inits=dec_initial_states,
+ max_step_num=self.beam_max_step_num,
+ enc_output=enc_output,
+ enc_padding_mask=enc_padding_mask)
+ return outputs
+
+ return dec_output
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/base_model.py b/PaddleNLP/PaddleTextGEN/rnn_search/base_model.py
new file mode 100644
index 00000000..f4ee1b95
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/base_model.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid.layers as layers
+import paddle.fluid as fluid
+import numpy as np
+from paddle.fluid import ParamAttr
+from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
+
+INF = 1. * 1e5
+alpha = 0.6
+uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x)
+zero_constant = fluid.initializer.Constant(0.0)
+
+
+class EncoderCell(RNNCell):
+ def __init__(
+ self,
+ num_layers,
+ hidden_size,
+ dropout_prob=0.,
+ init_scale=0.1, ):
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size
+ self.dropout_prob = dropout_prob
+ self.lstm_cells = []
+
+ param_attr = ParamAttr(initializer=uniform_initializer(init_scale))
+ bias_attr = ParamAttr(initializer=zero_constant)
+ for i in range(num_layers):
+ self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
+
+ def call(self, step_input, states):
+ new_states = []
+ for i in range(self.num_layers):
+ out, new_state = self.lstm_cells[i](step_input, states[i])
+ step_input = layers.dropout(
+ out,
+ self.dropout_prob,
+ dropout_implementation="upscale_in_train"
+ ) if self.dropout_prob > 0. else out
+ new_states.append(new_state)
+ return step_input, new_states
+
+ @property
+ def state_shape(self):
+ return [cell.state_shape for cell in self.lstm_cells]
+
+
+class DecoderCell(RNNCell):
+ def __init__(self, num_layers, hidden_size, dropout_prob=0.,
+ init_scale=0.1):
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size
+ self.dropout_prob = dropout_prob
+ self.lstm_cells = []
+ self.init_scale = init_scale
+ param_attr = ParamAttr(initializer=uniform_initializer(init_scale))
+ bias_attr = ParamAttr(initializer=zero_constant)
+ for i in range(num_layers):
+ self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
+
+ def call(self, step_input, states):
+ new_lstm_states = []
+ for i in range(self.num_layers):
+ out, new_lstm_state = self.lstm_cells[i](step_input, states[i])
+ step_input = layers.dropout(
+ out,
+ self.dropout_prob,
+ dropout_implementation="upscale_in_train"
+ ) if self.dropout_prob > 0. else out
+ new_lstm_states.append(new_lstm_state)
+ return step_input, new_lstm_states
+
+
+class BaseModel(object):
+ def __init__(self,
+ hidden_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=1,
+ init_scale=0.1,
+ dropout=None,
+ beam_start_token=1,
+ beam_end_token=2,
+ beam_max_step_num=100):
+
+ self.hidden_size = hidden_size
+ self.src_vocab_size = src_vocab_size
+ self.tar_vocab_size = tar_vocab_size
+ self.batch_size = batch_size
+ self.num_layers = num_layers
+ self.init_scale = init_scale
+ self.dropout = dropout
+ self.beam_start_token = beam_start_token
+ self.beam_end_token = beam_end_token
+ self.beam_max_step_num = beam_max_step_num
+ self.src_embeder = lambda x: fluid.embedding(
+ input=x,
+ size=[self.src_vocab_size, self.hidden_size],
+ dtype='float32',
+ is_sparse=False,
+ param_attr=fluid.ParamAttr(
+ name='source_embedding',
+ initializer=uniform_initializer(init_scale)))
+
+ self.tar_embeder = lambda x: fluid.embedding(
+ input=x,
+ size=[self.tar_vocab_size, self.hidden_size],
+ dtype='float32',
+ is_sparse=False,
+ param_attr=fluid.ParamAttr(
+ name='target_embedding',
+ initializer=uniform_initializer(init_scale)))
+
+ def _build_data(self):
+ self.src = fluid.data(name="src", shape=[None, None], dtype='int64')
+ self.src_sequence_length = fluid.data(
+ name="src_sequence_length", shape=[None], dtype='int32')
+
+ self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64')
+ self.tar_sequence_length = fluid.data(
+ name="tar_sequence_length", shape=[None], dtype='int32')
+ self.label = fluid.data(
+ name="label", shape=[None, None, 1], dtype='int64')
+
+    def _embedding(self):
+ self.src_emb = self.src_embeder(self.src)
+ self.tar_emb = self.tar_embeder(self.tar)
+
+ def _build_encoder(self):
+ enc_cell = EncoderCell(self.num_layers, self.hidden_size, self.dropout,
+ self.init_scale)
+ self.enc_output, enc_final_state = rnn(
+ cell=enc_cell,
+ inputs=self.src_emb,
+ sequence_length=self.src_sequence_length)
+ return self.enc_output, enc_final_state
+
+ def _build_decoder(self, enc_final_state, mode='train', beam_size=10):
+
+ dec_cell = DecoderCell(self.num_layers, self.hidden_size, self.dropout,
+ self.init_scale)
+ output_layer = lambda x: layers.fc(x,
+ size=self.tar_vocab_size,
+ num_flatten_dims=len(x.shape) - 1,
+ param_attr=fluid.ParamAttr(name="output_w",
+ initializer=uniform_initializer(self.init_scale)),
+ bias_attr=False)
+
+ if mode == 'train':
+ dec_output, dec_final_state = rnn(cell=dec_cell,
+ inputs=self.tar_emb,
+ initial_states=enc_final_state)
+
+ dec_output = output_layer(dec_output)
+
+ return dec_output
+ elif mode == 'beam_search':
+ beam_search_decoder = BeamSearchDecoder(
+ dec_cell,
+ self.beam_start_token,
+ self.beam_end_token,
+ beam_size,
+ embedding_fn=self.tar_embeder,
+ output_fn=output_layer)
+
+ outputs, _ = dynamic_decode(
+ beam_search_decoder,
+ inits=enc_final_state,
+ max_step_num=self.beam_max_step_num)
+ return outputs
+
+ def _compute_loss(self, dec_output):
+ loss = layers.softmax_with_cross_entropy(
+ logits=dec_output, label=self.label, soft_label=False)
+        loss = layers.squeeze(loss, axes=[2])
+
+ max_tar_seq_len = layers.shape(self.tar)[1]
+ tar_mask = layers.sequence_mask(
+ self.tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
+ loss = loss * tar_mask
+ loss = layers.reduce_mean(loss, dim=[0])
+ loss = layers.reduce_sum(loss)
+ return loss
+
+ def _beam_search(self, enc_last_hidden, enc_last_cell):
+ pass
+
+ def build_graph(self, mode='train', beam_size=10):
+ if mode == 'train' or mode == 'eval':
+ self._build_data()
+            self._embedding()
+ enc_output, enc_final_state = self._build_encoder()
+ dec_output = self._build_decoder(enc_final_state)
+
+ loss = self._compute_loss(dec_output)
+ return loss
+ elif mode == "beam_search" or mode == 'greedy_search':
+ self._build_data()
+            self._embedding()
+ enc_output, enc_final_state = self._build_encoder()
+ dec_output = self._build_decoder(
+ enc_final_state, mode=mode, beam_size=beam_size)
+
+ return dec_output
+        else:
+            raise ValueError("unsupported mode: " + mode)
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/download.py b/PaddleNLP/PaddleTextGEN/rnn_search/download.py
new file mode 100644
index 00000000..4dd1466d
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/download.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+'''
+Script for downloading training data.
+'''
+import os
+import sys
+import urllib
+
+if sys.version_info >= (3, 0):
+    import urllib.request
+
+URLLIB = urllib
+if sys.version_info >= (3, 0):
+    URLLIB = urllib.request
+
+remote_path = 'https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi'
+base_path = 'data'
+tar_path = os.path.join(base_path, 'en-vi')
+filenames = [
+ 'train.en', 'train.vi', 'tst2012.en', 'tst2012.vi', 'tst2013.en',
+ 'tst2013.vi', 'vocab.en', 'vocab.vi'
+]
+
+
+def main(arguments):
+ print("Downloading data......")
+
+ if not os.path.exists(tar_path):
+ if not os.path.exists(base_path):
+ os.mkdir(base_path)
+ os.mkdir(tar_path)
+
+ for filename in filenames:
+ url = remote_path + '/' + filename
+ tar_file = os.path.join(tar_path, filename)
+ URLLIB.urlretrieve(url, tar_file)
+ print("Downloaded sucess......")
+
+
+if __name__ == '__main__':
+ sys.exit(main(sys.argv[1:]))
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/infer.py b/PaddleNLP/PaddleTextGEN/rnn_search/infer.py
new file mode 100644
index 00000000..47240429
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/infer.py
@@ -0,0 +1,184 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import time
+import os
+import random
+import logging
+import math
+import io
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+from paddle.fluid.executor import Executor
+
+import reader
+
+import sys
+line_tok = '\n'
+space_tok = ' '
+if sys.version[0] == '2':
+ reload(sys)
+ sys.setdefaultencoding("utf-8")
+ line_tok = u'\n'
+ space_tok = u' '
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+
+from args import *
+
+from attention_model import AttentionModel
+from base_model import BaseModel
+
+
+def infer():
+ args = parse_args()
+
+ num_layers = args.num_layers
+ src_vocab_size = args.src_vocab_size
+ tar_vocab_size = args.tar_vocab_size
+ batch_size = args.batch_size
+ dropout = args.dropout
+ init_scale = args.init_scale
+ max_grad_norm = args.max_grad_norm
+ hidden_size = args.hidden_size
+ # inference process
+
+ print("src", src_vocab_size)
+
+    # dropout uses upscale_in_train, so it can simply be disabled at
+    # inference time by setting the probability to 0
+ if args.attention:
+ model = AttentionModel(
+ hidden_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=num_layers,
+ init_scale=init_scale,
+ dropout=0.0)
+ else:
+ model = BaseModel(
+ hidden_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=num_layers,
+ init_scale=init_scale,
+ dropout=0.0)
+
+ beam_size = args.beam_size
+ trans_res = model.build_graph(mode='beam_search', beam_size=beam_size)
+ # clone from default main program and use it as the validation program
+ main_program = fluid.default_main_program()
+ main_program = main_program.clone(for_test=True)
+ print([param.name for param in main_program.blocks[0].all_parameters()])
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ exe = Executor(place)
+ exe.run(framework.default_startup_program())
+
+ source_vocab_file = args.vocab_prefix + "." + args.src_lang
+ infer_file = args.infer_file
+
+ infer_data = reader.raw_mono_data(source_vocab_file, infer_file)
+
+ def prepare_input(batch, epoch_id=0, with_lr=True):
+ src_ids, src_mask, tar_ids, tar_mask = batch
+ res = {}
+ src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
+ in_tar = tar_ids[:, :-1]
+ label_tar = tar_ids[:, 1:]
+
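+        # only the source side matters for decoding; the target feeds are
+        # zeroed placeholders kept to match the graph's input definitions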
+ in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
+ in_tar = np.zeros_like(in_tar, dtype='int64')
+ label_tar = label_tar.reshape(
+ (label_tar.shape[0], label_tar.shape[1], 1))
+ label_tar = np.zeros_like(label_tar, dtype='int64')
+
+ res['src'] = src_ids
+ res['tar'] = in_tar
+ res['label'] = label_tar
+ res['src_sequence_length'] = src_mask
+ res['tar_sequence_length'] = tar_mask
+
+ return res, np.sum(tar_mask)
+
+ dir_name = args.reload_model
+ print("dir name", dir_name)
+ fluid.io.load_params(exe, dir_name)
+
+    infer_data_iter = reader.get_data_iter(infer_data, 1, mode='eval')
+
+ tar_id2vocab = []
+ tar_vocab_file = args.vocab_prefix + "." + args.tar_lang
+ with io.open(tar_vocab_file, "r", encoding='utf-8') as f:
+ for line in f.readlines():
+ tar_id2vocab.append(line.strip())
+
+ infer_output_file = args.infer_output_file
+    infer_output_dir = os.path.dirname(infer_output_file)
+    if infer_output_dir and not os.path.exists(infer_output_dir):
+        os.mkdir(infer_output_dir)
+
+ with io.open(infer_output_file, 'w', encoding='utf-8') as out_file:
+
+        for batch_id, batch in enumerate(infer_data_iter):
+ input_data_feed, word_num = prepare_input(batch, epoch_id=0)
+ fetch_outs = exe.run(program=main_program,
+ feed=input_data_feed,
+ fetch_list=[trans_res.name],
+ use_program_cache=False)
+
+ for ins in fetch_outs[0]:
+ res = [tar_id2vocab[e] for e in ins[:, 0].reshape(-1)]
+ new_res = []
+ for ele in res:
+ if ele == "":
+ break
+ new_res.append(ele)
+
+ out_file.write(space_tok.join(new_res))
+ out_file.write(line_tok)
+
+
+def check_version():
+ """
+ Log error and exit when the installed version of paddlepaddle is
+ not satisfied.
+ """
+ err = "PaddlePaddle version 1.6 or higher is required, " \
+ "or a suitable develop version is satisfied as well. \n" \
+ "Please make sure the version is good with your code." \
+
+ try:
+ fluid.require_version('1.6.0')
+ except Exception as e:
+ logger.error(err)
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ check_version()
+ infer()
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/infer.sh b/PaddleNLP/PaddleTextGEN/rnn_search/infer.sh
new file mode 100644
index 00000000..6b62b013
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/infer.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0
+
+python infer.py \
+ --attention True \
+ --src_lang en --tar_lang vi \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --vocab_prefix data/en-vi/vocab \
+ --infer_file data/en-vi/tst2013.en \
+ --reload_model attention_models/epoch_10/ \
+ --infer_output_file attention_infer_output/infer_output.txt \
+ --beam_size 10 \
+ --use_gpu True
+
+
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/reader.py b/PaddleNLP/PaddleTextGEN/rnn_search/reader.py
new file mode 100644
index 00000000..661546f7
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/reader.py
@@ -0,0 +1,211 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for parsing PTB text files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import io
+import sys
+import numpy as np
+
+Py3 = sys.version_info[0] == 3
+
+UNK_ID = 0
+
+
+def _read_words(filename):
+    with io.open(filename, "r", encoding='utf-8') as f:
+        if Py3:
+            return f.read().replace("\n", "").split()
+        else:
+            return f.read().decode("utf-8").replace(u"\n", u"").split()
+
+
+def read_all_line(filename):
+    data = []
+    with io.open(filename, "r", encoding='utf-8') as f:
+        for line in f.readlines():
+            data.append(line.strip())
+    return data
+
+
+def _build_vocab(filename):
+
+ vocab_dict = {}
+ ids = 0
+ with io.open(filename, "r", encoding='utf-8') as f:
+ for line in f.readlines():
+ vocab_dict[line.strip()] = ids
+ ids += 1
+
+ print("vocab word num", ids)
+
+ return vocab_dict
+
+
+def _para_file_to_ids(src_file, tar_file, src_vocab, tar_vocab):
+
+ src_data = []
+ with io.open(src_file, "r", encoding='utf-8') as f_src:
+ for line in f_src.readlines():
+ arra = line.strip().split()
+            ids = [src_vocab[w] if w in src_vocab else UNK_ID for w in arra]
+            src_data.append(ids)
+
+ tar_data = []
+ with io.open(tar_file, "r", encoding='utf-8') as f_tar:
+ for line in f_tar.readlines():
+ arra = line.strip().split()
+ ids = [tar_vocab[w] if w in tar_vocab else UNK_ID for w in arra]
+
+            ids = [1] + ids + [2]  # wrap with <s> (id 1) and </s> (id 2)
+
+ tar_data.append(ids)
+
+ return src_data, tar_data
+
+
+def filter_len(src, tar, max_sequence_len=50):
+ new_src = []
+ new_tar = []
+
+ for id1, id2 in zip(src, tar):
+ if len(id1) > max_sequence_len:
+ id1 = id1[:max_sequence_len]
+ if len(id2) > max_sequence_len + 2:
+ id2 = id2[:max_sequence_len + 2]
+
+ new_src.append(id1)
+ new_tar.append(id2)
+
+ return new_src, new_tar
+
+
+def raw_data(src_lang,
+ tar_lang,
+ vocab_prefix,
+ train_prefix,
+ eval_prefix,
+ test_prefix,
+ max_sequence_len=50):
+
+ src_vocab_file = vocab_prefix + "." + src_lang
+ tar_vocab_file = vocab_prefix + "." + tar_lang
+
+ src_train_file = train_prefix + "." + src_lang
+ tar_train_file = train_prefix + "." + tar_lang
+
+ src_eval_file = eval_prefix + "." + src_lang
+ tar_eval_file = eval_prefix + "." + tar_lang
+
+ src_test_file = test_prefix + "." + src_lang
+ tar_test_file = test_prefix + "." + tar_lang
+
+ src_vocab = _build_vocab(src_vocab_file)
+ tar_vocab = _build_vocab(tar_vocab_file)
+
+ train_src, train_tar = _para_file_to_ids( src_train_file, tar_train_file, \
+ src_vocab, tar_vocab )
+ train_src, train_tar = filter_len(
+ train_src, train_tar, max_sequence_len=max_sequence_len)
+ eval_src, eval_tar = _para_file_to_ids( src_eval_file, tar_eval_file, \
+ src_vocab, tar_vocab )
+
+ test_src, test_tar = _para_file_to_ids( src_test_file, tar_test_file, \
+ src_vocab, tar_vocab )
+
+ return ( train_src, train_tar), (eval_src, eval_tar), (test_src, test_tar),\
+ (src_vocab, tar_vocab)
+
+
+def raw_mono_data(vocab_file, file_path):
+
+ src_vocab = _build_vocab(vocab_file)
+
+ test_src, test_tar = _para_file_to_ids( file_path, file_path, \
+ src_vocab, src_vocab )
+
+ return (test_src, test_tar)
+
+
+def get_data_iter(raw_data,
+ batch_size,
+ mode='train',
+ enable_ce=False,
+ cache_num=20):
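+    """Yield (src_ids, src_mask, tar_ids, tar_mask) batches. `cache_num`
+    batches are buffered and sorted by source length before being split
+    into batches, which reduces padding; the data is shuffled in train
+    mode unless `enable_ce` is set."""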
+
+ src_data, tar_data = raw_data
+
+ data_len = len(src_data)
+
+ index = np.arange(data_len)
+ if mode == "train" and not enable_ce:
+ np.random.shuffle(index)
+
+ def to_pad_np(data, source=False):
+ max_len = 0
+ for ele in data:
+ if len(ele) > max_len:
+ max_len = len(ele)
+
+        ids = np.ones((batch_size, max_len), dtype='int64') * 2  # pad with </s> (id 2)
+ mask = np.zeros((batch_size), dtype='int32')
+
+ for i, ele in enumerate(data):
+ ids[i, :len(ele)] = ele
+ if not source:
+ mask[i] = len(ele) - 1
+ else:
+ mask[i] = len(ele)
+
+ return ids, mask
+
+ b_src = []
+
+ if mode != "train":
+ cache_num = 1
+ for j in range(data_len):
+        if len(b_src) == batch_size * cache_num:
+            # buffer is full: sort by source length so each batch holds
+            # sequences of similar length, then emit cache_num batches
+            new_cache = sorted(b_src, key=lambda k: len(k[0]))
+
+ for i in range(cache_num):
+ batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
+ src_cache = [w[0] for w in batch_data]
+ tar_cache = [w[1] for w in batch_data]
+ src_ids, src_mask = to_pad_np(src_cache, source=True)
+ tar_ids, tar_mask = to_pad_np(tar_cache)
+ yield (src_ids, src_mask, tar_ids, tar_mask)
+
+ b_src = []
+
+ b_src.append((src_data[index[j]], tar_data[index[j]]))
+    if len(b_src) == batch_size * cache_num:
+        # flush the final full buffer; a partial tail smaller than a full
+        # cache is dropped
+        new_cache = sorted(b_src, key=lambda k: len(k[0]))
+
+ for i in range(cache_num):
+ batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
+ src_cache = [w[0] for w in batch_data]
+ tar_cache = [w[1] for w in batch_data]
+ src_ids, src_mask = to_pad_np(src_cache, source=True)
+ tar_ids, tar_mask = to_pad_np(tar_cache)
+ yield (src_ids, src_mask, tar_ids, tar_mask)
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/run.sh b/PaddleNLP/PaddleTextGEN/rnn_search/run.sh
new file mode 100644
index 00000000..25bc78a3
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/run.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+export CUDA_VISIBLE_DEVICES=0
+
+python train.py \
+ --src_lang en --tar_lang vi \
+ --attention True \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --train_data_prefix data/en-vi/train \
+ --eval_data_prefix data/en-vi/tst2012 \
+ --test_data_prefix data/en-vi/tst2013 \
+ --vocab_prefix data/en-vi/vocab \
+ --use_gpu True \
+ --model_path attention_models
diff --git a/PaddleNLP/PaddleTextGEN/rnn_search/train.py b/PaddleNLP/PaddleTextGEN/rnn_search/train.py
new file mode 100644
index 00000000..51d4d29e
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/rnn_search/train.py
@@ -0,0 +1,277 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import time
+import os
+import logging
+import random
+import math
+import contextlib
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+import paddle.fluid.profiler as profiler
+from paddle.fluid.executor import Executor
+
+import reader
+
+import sys
+if sys.version[0] == '2':
+ reload(sys)
+ sys.setdefaultencoding("utf-8")
+
+from args import *
+from base_model import BaseModel
+from attention_model import AttentionModel
+import logging
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+
+
+@contextlib.contextmanager
+def profile_context(profile=True):
+ if profile:
+ with profiler.profiler('All', 'total', 'seq2seq.profile'):
+ yield
+ else:
+ yield
+
+
+def main():
+ args = parse_args()
+ print(args)
+ num_layers = args.num_layers
+ src_vocab_size = args.src_vocab_size
+ tar_vocab_size = args.tar_vocab_size
+ batch_size = args.batch_size
+ dropout = args.dropout
+ init_scale = args.init_scale
+ max_grad_norm = args.max_grad_norm
+ hidden_size = args.hidden_size
+
+ if args.enable_ce:
+ fluid.default_main_program().random_seed = 102
+ framework.default_startup_program().random_seed = 102
+
+ train_program = fluid.Program()
+ startup_program = fluid.Program()
+
+ with fluid.program_guard(train_program, startup_program):
+ # Training process
+
+ if args.attention:
+ model = AttentionModel(
+ hidden_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=num_layers,
+ init_scale=init_scale,
+ dropout=dropout)
+ else:
+ model = BaseModel(
+ hidden_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=num_layers,
+ init_scale=init_scale,
+ dropout=dropout)
+ loss = model.build_graph()
+ inference_program = train_program.clone(for_test=True)
+ fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
+ clip_norm=max_grad_norm))
+ lr = args.learning_rate
+ opt_type = args.optimizer
+ if opt_type == "sgd":
+ optimizer = fluid.optimizer.SGD(lr)
+ elif opt_type == "adam":
+ optimizer = fluid.optimizer.Adam(lr)
+ else:
+ print("only support [sgd|adam]")
+ raise Exception("opt type not support")
+
+ optimizer.minimize(loss)
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ exe = Executor(place)
+ exe.run(startup_program)
+
+ device_count = len(fluid.cuda_places()) if args.use_gpu else len(
+ fluid.cpu_places())
+
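+    # compile with data parallelism so one process can train on all
+    # devices visible through CUDA_VISIBLE_DEVICES (or all CPU cores)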
+    compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
+        loss_name=loss.name)
+
+ train_data_prefix = args.train_data_prefix
+ eval_data_prefix = args.eval_data_prefix
+ test_data_prefix = args.test_data_prefix
+ vocab_prefix = args.vocab_prefix
+ src_lang = args.src_lang
+ tar_lang = args.tar_lang
+ print("begin to load data")
+ raw_data = reader.raw_data(src_lang, tar_lang, vocab_prefix,
+ train_data_prefix, eval_data_prefix,
+ test_data_prefix, args.max_len)
+ print("finished load data")
+ train_data, valid_data, test_data, _ = raw_data
+
+ def prepare_input(batch, epoch_id=0, with_lr=True):
+ src_ids, src_mask, tar_ids, tar_mask = batch
+ res = {}
+ src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
+ in_tar = tar_ids[:, :-1]
+ label_tar = tar_ids[:, 1:]
+
+ in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
+ label_tar = label_tar.reshape(
+ (label_tar.shape[0], label_tar.shape[1], 1))
+
+ res['src'] = src_ids
+ res['tar'] = in_tar
+ res['label'] = label_tar
+ res['src_sequence_length'] = src_mask
+ res['tar_sequence_length'] = tar_mask
+
+ return res, np.sum(tar_mask)
+
+ # get train epoch size
+ def eval(data, epoch_id=0):
+ eval_data_iter = reader.get_data_iter(data, batch_size, mode='eval')
+ total_loss = 0.0
+ word_count = 0.0
+ for batch_id, batch in enumerate(eval_data_iter):
+ input_data_feed, word_num = prepare_input(
+ batch, epoch_id, with_lr=False)
+ fetch_outs = exe.run(inference_program,
+ feed=input_data_feed,
+ fetch_list=[loss.name],
+ use_program_cache=False)
+
+ cost_train = np.array(fetch_outs[0])
+
+ total_loss += cost_train * batch_size
+ word_count += word_num
+
+ ppl = np.exp(total_loss / word_count)
+
+ return ppl
+
+ def train():
+ ce_time = []
+ ce_ppl = []
+ max_epoch = args.max_epoch
+ for epoch_id in range(max_epoch):
+ start_time = time.time()
+ if args.enable_ce:
+ train_data_iter = reader.get_data_iter(
+ train_data, batch_size, enable_ce=True)
+ else:
+ train_data_iter = reader.get_data_iter(train_data, batch_size)
+
+ total_loss = 0
+ word_count = 0.0
+ batch_times = []
+ for batch_id, batch in enumerate(train_data_iter):
+ batch_start_time = time.time()
+ input_data_feed, word_num = prepare_input(
+ batch, epoch_id=epoch_id)
+ word_count += word_num
+                fetch_outs = exe.run(program=compiled_program,
+ feed=input_data_feed,
+ fetch_list=[loss.name],
+ use_program_cache=True)
+
+ cost_train = np.mean(fetch_outs[0])
+ # print(cost_train)
+ total_loss += cost_train * batch_size
+ batch_end_time = time.time()
+ batch_time = batch_end_time - batch_start_time
+ batch_times.append(batch_time)
+
+ if batch_id > 0 and batch_id % 100 == 0:
+ print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" %
+ (epoch_id, batch_id, batch_time,
+ np.exp(total_loss / word_count)))
+ ce_ppl.append(np.exp(total_loss / word_count))
+ total_loss = 0.0
+ word_count = 0.0
+
+ end_time = time.time()
+ epoch_time = end_time - start_time
+ ce_time.append(epoch_time)
+ print(
+ "\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n"
+ % (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
+
+ if not args.profile:
+ dir_name = os.path.join(args.model_path,
+ "epoch_" + str(epoch_id))
+ print("begin to save", dir_name)
+ fluid.io.save_params(exe, dir_name, main_program=train_program)
+ print("save finished")
+ dev_ppl = eval(valid_data)
+ print("dev ppl", dev_ppl)
+ test_ppl = eval(test_data)
+ print("test ppl", test_ppl)
+
+ if args.enable_ce:
+ card_num = get_cards()
+ _ppl = 0
+ _time = 0
+ try:
+ _time = ce_time[-1]
+ _ppl = ce_ppl[-1]
+        except IndexError:
+ print("ce info error")
+ print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
+ print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl))
+
+ with profile_context(args.profile):
+ train()
+
+
+def get_cards():
+ num = 0
+ cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
+ if cards != '':
+ num = len(cards.split(","))
+ return num
+
+
+def check_version():
+ """
+ Log error and exit when the installed version of paddlepaddle is
+ not satisfied.
+ """
+ err = "PaddlePaddle version 1.6 or higher is required, " \
+ "or a suitable develop version is satisfied as well. \n" \
+ "Please make sure the version is good with your code." \
+
+ try:
+ fluid.require_version('1.6.0')
+ except Exception as e:
+ logger.error(err)
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ check_version()
+ main()
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/README.md b/PaddleNLP/PaddleTextGEN/vae_text/README.md
new file mode 100644
index 00000000..e367496c
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/README.md
@@ -0,0 +1,107 @@
+The example model in this directory requires PaddlePaddle Fluid 1.6. If your installed version of PaddlePaddle is older than this, please follow the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start) to upgrade.
+
+# Variational Autoencoder (VAE) for Text Generation
+
+The directory structure of this example model is as follows:
+
+```text
+.
+├── README.md   # Documentation (this file)
+├── args.py     # Training, inference, and model hyper-parameter configuration
+├── reader.py   # Data reading utilities
+├── download.py # Data download script
+├── train.py    # Main training program
+├── infer.py    # Main inference program
+├── run.sh      # Launch script with the default training configuration
+├── infer.sh    # Decoding script with the default configuration
+└── model.py    # VAE model definition
+
+```
+
+## Introduction
+
+The implementation in this directory demonstrates how to use the **new Seq2Seq API** of Paddle Fluid to build a VAE example for text generation, with LSTMs as the encoder and the decoder. The model is trained separately on the official PTB data and on the SWDA data.
+
+For a detailed description of the VAE, see: [(Bowman et al., 2015) Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf)
+
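+As a reminder of the objective being optimized, here is a numpy sketch of the standard reparameterized VAE loss from Bowman et al.; it illustrates the formulation rather than quoting code from `model.py`, and the shapes are illustrative:
+
+```python
+import numpy as np
+
+def vae_loss(rec_nll, mu, logvar, kl_weight):
+    """rec_nll: [batch] reconstruction negative log-likelihood;
+    mu, logvar: [batch, latent] posterior parameters from the encoder."""
+    # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I)
+    eps = np.random.standard_normal(mu.shape)
+    z = mu + np.exp(0.5 * logvar) * eps
+    # KL(q(z|x) || N(0, I)) in closed form, summed over latent dims
+    kl = -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=1)
+    # kl_weight is annealed from --kl_start up to 1.0 over --warm_up epochs
+    return np.mean(rec_nll + kl_weight * kl), z
+```
+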
+## Data
+
+This tutorial uses two text datasets:
+
+The PTB dataset, originally available at: http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz.
+
+The SWDA dataset, from [Knowledge-Guided CVAE for dialog generation](https://arxiv.org/pdf/1703.10960.pdf); the original dataset can be downloaded at https://github.com/snakeztc/NeuralDialog-CVAE (thanks to the author @[snakeztc](https://github.com/snakeztc)). We filtered out sentences shorter than 5 words.
+
+### Downloading the data
+
+```
+python download.py --task ptb/swda
+```
+
+## Training
+
+`run.sh` wraps the training entry point. To start training with the default configuration, simply run:
+
+```
+sh run.sh ptb/swda
+```
+
+To change the model configuration, invoke the training program directly:
+
+```
+python train.py \
+ --vocab_size 10003 \
+ --batch_size 32 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --dataset_prefix data/${dataset}/${dataset} \
+    --model_path ${dataset}_model \
+    --use_gpu True \
+    --max_epoch 50
+```
+
+The training program uses early stopping: at the end of each epoch it decides, based on the validation ppl, whether to save the model.
+
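+The decision rule can be sketched as follows; `should_save_and_continue` is an illustrative helper, not a function from `train.py`, and `max_decay` corresponds to the `--max_decay` flag:
+
+```python
+def should_save_and_continue(ppl, state, max_decay=5):
+    """state holds {'best': float, 'decay': int}; returns (save, keep_going)."""
+    if ppl < state["best"]:
+        state["best"], state["decay"] = ppl, 0
+        return True, True                      # improved: save this checkpoint
+    state["decay"] += 1                        # no improvement: count a try
+    return False, state["decay"] <= max_decay  # stop after max_decay misses
+```
+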
+## Inference
+
+Once training is complete, run the `infer.sh` script for inference. It loads the model saved at the k-th epoch under the model directory and generates batch_size short sentences:
+
+```
+sh infer.sh ptb/swda k
+```
+
+To change the inference configuration, invoke the program directly:
+
+```
+python infer.py \
+ --vocab_size 10003 \
+ --batch_size 32 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --dataset_prefix data/${dataset}/${dataset} \
+ --use_gpu True \
+    --reload_model ${dataset}_model/epoch_${k}
+```
+
+## Evaluation
+
+```text
+PTB dataset:
+Test PPL: 102.24
+Test NLL: 108.22
+
+SWDA dataset:
+Test PPL: 64.21
+Test NLL: 81.92
+```
+
+## Sample Generations
+
+the movie are discovered in the u.s. industry that on aircraft variations for a aircraft that was repaired
+
+the percentage of treasury bonds rose to N N at N N up N N from the two days N N and a premium over N
+
+he could n't plunge as a factor that attention now has n't picked up for the state according to mexico
+
+take the remark we need to do then support for the market to tell it
+
+i think that it believes the core of the company in its first quarter of heavy mid-october to fuel earnings after prices on friday
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/__init__.py b/PaddleNLP/PaddleTextGEN/vae_text/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/args.py b/PaddleNLP/PaddleTextGEN/vae_text/args.py
new file mode 100644
index 00000000..1d856577
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/args.py
@@ -0,0 +1,163 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import distutils.util
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--dataset_prefix", type=str, help="file prefix for train data")
+
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default='adam',
+ help="optimizer to use, only supprt[sgd|adam]")
+
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=0.001,
+ help="learning rate for optimizer")
+
+ parser.add_argument(
+ "--num_layers",
+ type=int,
+ default=1,
+ help="layers number of encoder and decoder")
+ parser.add_argument(
+ "--hidden_size",
+ type=int,
+ default=256,
+ help="hidden size of encoder and decoder")
+ parser.add_argument("--vocab_size", type=int, help="source vocab size")
+
+ parser.add_argument(
+ "--batch_size", type=int, help="batch size of each step")
+
+ parser.add_argument(
+ "--max_epoch", type=int, default=20, help="max epoch for the training")
+
+ parser.add_argument(
+ "--max_len",
+ type=int,
+ default=1280,
+ help="max length for source and target sentence")
+ parser.add_argument(
+ "--dec_dropout_in",
+ type=float,
+ default=0.5,
+ help="decoder input drop probability")
+ parser.add_argument(
+ "--dec_dropout_out",
+ type=float,
+ default=0.5,
+ help="decoder output drop probability")
+ parser.add_argument(
+ "--enc_dropout_in",
+ type=float,
+ default=0.,
+ help="encoder input drop probability")
+ parser.add_argument(
+ "--enc_dropout_out",
+ type=float,
+ default=0.,
+ help="encoder output drop probability")
+ parser.add_argument(
+ "--word_keep_prob",
+ type=float,
+ default=0.5,
+ help="word keep probability")
+ parser.add_argument(
+ "--init_scale",
+ type=float,
+ default=0.0,
+ help="init scale for parameter")
+ parser.add_argument(
+ "--max_grad_norm",
+ type=float,
+ default=5.0,
+ help="max grad norm for global norm clip")
+
+ parser.add_argument(
+ "--model_path",
+ type=str,
+ default='model',
+ help="model path for model to save")
+
+ parser.add_argument(
+ "--reload_model", type=str, help="reload model to inference")
+
+ parser.add_argument(
+ "--infer_output_file",
+ type=str,
+ default='infer_output.txt',
+ help="file name for inference output")
+
+ parser.add_argument(
+ "--beam_size", type=int, default=10, help="file name for inference")
+
+ parser.add_argument(
+ '--use_gpu',
+        type=eval,
+        default=False,
+        help='Whether to use gpu [True|False]')
+
+ parser.add_argument(
+ "--enable_ce",
+ action='store_true',
+ help="The flag indicating whether to run the task "
+ "for continuous evaluation.")
+
+ parser.add_argument(
+ "--profile", action='store_true', help="Whether enable the profile.")
+
+ parser.add_argument(
+ "--warm_up",
+ type=int,
+ default=10,
+ help='number of warm up epochs for KL')
+
+ parser.add_argument(
+ "--kl_start", type=float, default=0.1, help='KL start value, upto 1.0')
+
+ parser.add_argument(
+ "--attr_init",
+ type=str,
+ default='normal_initializer',
+ help="initializer for paramters")
+
+ parser.add_argument(
+ "--cache_num", type=int, default=1, help='cache num for reader')
+
+ parser.add_argument(
+ "--max_decay",
+ type=int,
+ default=5,
+ help='max decay tries (if exceeds, early stop)')
+
+ parser.add_argument(
+ "--sort_cache",
+ action='store_true',
+ help='sort cache before batch to accelerate training')
+
+ args = parser.parse_args()
+ return args
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/download.py b/PaddleNLP/PaddleTextGEN/vae_text/download.py
new file mode 100644
index 00000000..ba4acacd
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/download.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the 'License');
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an 'AS IS' BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+'''
+Script for downloading training data.
+'''
+import os
+import sys
+import shutil
+import argparse
+import urllib
+import tarfile
+
+if sys.version_info >= (3, 0):
+    import urllib.request
+
+URLLIB = urllib
+if sys.version_info >= (3, 0):
+ URLLIB = urllib.request
+
+TASKS = ['ptb', 'swda']
+TASK2PATH = {
+ 'ptb': 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz',
+ 'swda': 'https://baidu-nlp.bj.bcebos.com/TextGen/swda.tar.gz'
+}
+
+
+def un_tar(tar_name, dir_name):
+ try:
+ t = tarfile.open(tar_name)
+ t.extractall(path=dir_name)
+ return True
+ except Exception as e:
+ print(e)
+ return False
+
+
+def download_and_extract(task, data_path):
+ print('Downloading and extracting %s...' % task)
+ data_file = os.path.join(data_path, TASK2PATH[task].split('/')[-1])
+ URLLIB.urlretrieve(TASK2PATH[task], data_file)
+ un_tar(data_file, data_path)
+ os.remove(data_file)
+ if task == 'ptb':
+ src_dir = os.path.join(data_path, 'simple-examples')
+ dst_dir = os.path.join(data_path, 'ptb')
+ if not os.path.exists(dst_dir):
+ os.mkdir(dst_dir)
+ shutil.copy(os.path.join(src_dir, 'data/ptb.train.txt'), dst_dir)
+ shutil.copy(os.path.join(src_dir, 'data/ptb.valid.txt'), dst_dir)
+ shutil.copy(os.path.join(src_dir, 'data/ptb.test.txt'), dst_dir)
+ shutil.rmtree(src_dir)
+ print('\tCompleted!')
+
+
+def main(arguments):
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '-d',
+ '--data_dir',
+ help='directory to save data to',
+ type=str,
+ default='data')
+ parser.add_argument(
+ '-t',
+ '--task',
+        help='task to download data for [ptb|swda]',
+ type=str,
+ default='ptb')
+ args = parser.parse_args(arguments)
+
+ if not os.path.isdir(args.data_dir):
+ os.mkdir(args.data_dir)
+
+ download_and_extract(args.task, args.data_dir)
+
+
+if __name__ == '__main__':
+ sys.exit(main(sys.argv[1:]))
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/infer.py b/PaddleNLP/PaddleTextGEN/vae_text/infer.py
new file mode 100644
index 00000000..c21fff3a
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/infer.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import time
+import os
+import random
+import io
+import math
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+from paddle.fluid.executor import Executor
+
+import reader
+
+import sys
+line_tok = '\n'
+space_tok = ' '
+if sys.version_info[0] == 2:
+    reload(sys)
+    sys.setdefaultencoding("utf-8")
+    line_tok = u'\n'
+    space_tok = u' '
+
+from args import *
+import logging
+import pickle
+
+from model import VAE
+from reader import BOS_ID, EOS_ID, get_vocab
+
+
+def infer():
+ args = parse_args()
+
+ num_layers = args.num_layers
+ src_vocab_size = args.vocab_size
+ tar_vocab_size = args.vocab_size
+ batch_size = args.batch_size
+ init_scale = args.init_scale
+ max_grad_norm = args.max_grad_norm
+ hidden_size = args.hidden_size
+ attr_init = args.attr_init
+ latent_size = 32
+
+ if args.enable_ce:
+ fluid.default_main_program().random_seed = 102
+ framework.default_startup_program().random_seed = 102
+
+ model = VAE(hidden_size,
+ latent_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=num_layers,
+ init_scale=init_scale,
+ attr_init=attr_init)
+
+ beam_size = args.beam_size
+ trans_res = model.build_graph(mode='sampling', beam_size=beam_size)
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ exe = Executor(place)
+ exe.run(framework.default_startup_program())
+
+ dir_name = args.reload_model
+ print("dir name", dir_name)
+ fluid.io.load_params(exe, dir_name)
+ vocab, tar_id2vocab = get_vocab(args.dataset_prefix)
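+    # in sampling mode the graph draws z from the standard normal prior, so
+    # the only feed needed is a batch of BOS tokens to start decoding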
+ infer_output = np.ones((batch_size, 1), dtype='int64') * BOS_ID
+
+ fetch_outs = exe.run(feed={'tar': infer_output},
+ fetch_list=[trans_res.name],
+ use_program_cache=False)
+
+ with io.open(args.infer_output_file, 'w', encoding='utf-8') as out_file:
+
+ for line in fetch_outs[0]:
+            # keep the tokens after BOS up to the first EOS; keep the whole
+            # line when no EOS was generated
+            end_id = len(line)
+            if EOS_ID in line:
+                end_id = np.where(line == EOS_ID)[0][0]
+ new_line = [tar_id2vocab[e[0]] for e in line[1:end_id]]
+ out_file.write(space_tok.join(new_line))
+ out_file.write(line_tok)
+
+
+def check_version():
+    """
+    Log an error and exit when the installed version of PaddlePaddle
+    does not satisfy the requirement.
+    """
+    err = "PaddlePaddle version 1.6 or higher is required, " \
+          "or a suitable develop version is satisfied as well.\n" \
+          "Please make sure your installed version matches the code."
+
+    try:
+        fluid.require_version('1.6.0')
+    except Exception:
+        logging.error(err)
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+ check_version()
+ infer()
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/infer.sh b/PaddleNLP/PaddleTextGEN/vae_text/infer.sh
new file mode 100644
index 00000000..17940a04
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/infer.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+set -x
+export CUDA_VISIBLE_DEVICES=0
+dataset=$1
+k=$2
+python infer.py \
+ --vocab_size 10003 \
+ --batch_size 32 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --dataset_prefix data/${dataset}/${dataset} \
+ --use_gpu True \
+    --reload_model ${dataset}_model/epoch_${k}
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/model.py b/PaddleNLP/PaddleTextGEN/vae_text/model.py
new file mode 100644
index 00000000..a582ff9e
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/model.py
@@ -0,0 +1,369 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid.layers as layers
+import paddle.fluid as fluid
+from paddle.fluid import ParamAttr
+from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
+
+from reader import BOS_ID, EOS_ID
+
+normal_initializer = lambda x: fluid.initializer.NormalInitializer(loc=0., scale=x**-0.5)
+uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x)
+
+
+class EncoderCell(RNNCell):
+ def __init__(self,
+ num_layers,
+ hidden_size,
+ param_attr_initializer,
+ param_attr_scale,
+ dropout_prob=0.):
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size
+ self.dropout_prob = dropout_prob
+        self.lstm_cells = []
+
+        # apply the configured initializer, consistent with DecoderCell
+        param_attr = ParamAttr(
+            initializer=param_attr_initializer(param_attr_scale))
+
+        for i in range(num_layers):
+            lstm_name = "enc_layers_" + str(i)
+            self.lstm_cells.append(
+                LSTMCell(
+                    hidden_size, param_attr, forget_bias=0., name=lstm_name))
+
+ def call(self, step_input, states):
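+        # feed the step input through each stacked LSTM layer, applying
+        # dropout to every layer's output when dropout_prob > 0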
+ new_states = []
+ for i in range(self.num_layers):
+ out, new_state = self.lstm_cells[i](step_input, states[i])
+ step_input = layers.dropout(
+ out,
+ self.dropout_prob,
+ dropout_implementation="upscale_in_train"
+ ) if self.dropout_prob > 0 else out
+ new_states.append(new_state)
+ return step_input, new_states
+
+ @property
+ def state_shape(self):
+ return [cell.state_shape for cell in self.lstm_cells]
+
+
+class DecoderCell(RNNCell):
+ def __init__(self,
+ num_layers,
+ hidden_size,
+ latent_z,
+ param_attr_initializer,
+ param_attr_scale,
+ dropout_prob=0.):
+ self.num_layers = num_layers
+ self.hidden_size = hidden_size
+ self.dropout_prob = dropout_prob
+ self.latent_z = latent_z
+ self.lstm_cells = []
+
+ param_attr = ParamAttr(
+ initializer=param_attr_initializer(param_attr_scale))
+
+ for i in range(num_layers):
+ lstm_name = "dec_layers_" + str(i)
+ self.lstm_cells.append(
+ LSTMCell(
+ hidden_size, param_attr, forget_bias=0., name=lstm_name))
+
+ def call(self, step_input, states):
+ lstm_states = states
+ new_lstm_states = []
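+        # condition every decoding step on the latent code by concatenating
+        # z to the step input before the stacked LSTM layers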
+ step_input = layers.concat([step_input, self.latent_z], 1)
+ for i in range(self.num_layers):
+ out, lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
+ step_input = layers.dropout(
+ out,
+ self.dropout_prob,
+ dropout_implementation="upscale_in_train"
+ ) if self.dropout_prob > 0 else out
+ new_lstm_states.append(lstm_state)
+ return step_input, new_lstm_states
+
+
+class VAE(object):
+ def __init__(self,
+ hidden_size,
+ latent_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=1,
+ init_scale=0.1,
+ dec_dropout_in=0.5,
+ dec_dropout_out=0.5,
+ enc_dropout_in=0.,
+ enc_dropout_out=0.,
+ word_keep_prob=0.5,
+ batch_first=True,
+ attr_init="normal_initializer"):
+
+ self.hidden_size = hidden_size
+ self.latent_size = latent_size
+ self.src_vocab_size = src_vocab_size
+ self.tar_vocab_size = tar_vocab_size
+ self.batch_size = batch_size
+ self.num_layers = num_layers
+ self.init_scale = init_scale
+ self.dec_dropout_in = dec_dropout_in
+ self.dec_dropout_out = dec_dropout_out
+ self.enc_dropout_in = enc_dropout_in
+ self.enc_dropout_out = enc_dropout_out
+ self.word_keep_prob = word_keep_prob
+ self.batch_first = batch_first
+
+ if attr_init == "normal_initializer":
+ self.param_attr_initializer = normal_initializer
+ self.param_attr_scale = hidden_size
+ elif attr_init == "uniform_initializer":
+ self.param_attr_initializer = uniform_initializer
+ self.param_attr_scale = init_scale
+ else:
+ raise TypeError("The type of 'attr_initializer' is not supported")
+
+ self.src_embeder = lambda x: fluid.embedding(
+ input=x,
+ size=[self.src_vocab_size, self.hidden_size],
+ dtype='float32',
+ is_sparse=False,
+ param_attr=fluid.ParamAttr(
+ name='source_embedding',
+ initializer=self.param_attr_initializer(self.param_attr_scale)))
+
+ self.tar_embeder = lambda x: fluid.embedding(
+ input=x,
+ size=[self.tar_vocab_size, self.hidden_size],
+ dtype='float32',
+ is_sparse=False,
+ param_attr=fluid.ParamAttr(
+ name='target_embedding',
+ initializer=self.param_attr_initializer(self.param_attr_scale)))
+
+ def _build_data(self, mode='train'):
+ if mode == 'train':
+ self.src = fluid.data(name="src", shape=[None, None], dtype='int64')
+ self.src_sequence_length = fluid.data(
+ name="src_sequence_length", shape=[None], dtype='int32')
+ self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64')
+ self.tar_sequence_length = fluid.data(
+ name="tar_sequence_length", shape=[None], dtype='int32')
+ self.label = fluid.data(
+ name="label", shape=[None, None, 1], dtype='int64')
+ self.kl_weight = fluid.data(
+ name='kl_weight', shape=[1], dtype='float32')
+ else:
+ self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64')
+
+    def _embedding(self, mode='train'):
+ if mode == 'train':
+ self.src_emb = self.src_embeder(self.src)
+ self.tar_emb = self.tar_embeder(self.tar)
+
+ def _build_encoder(self):
+ self.enc_input = layers.dropout(
+ self.src_emb,
+ dropout_prob=self.enc_dropout_in,
+ dropout_implementation="upscale_in_train")
+ enc_cell = EncoderCell(self.num_layers, self.hidden_size,
+ self.param_attr_initializer,
+ self.param_attr_scale, self.enc_dropout_out)
+ enc_output, enc_final_state = rnn(
+ cell=enc_cell,
+ inputs=self.enc_input,
+ sequence_length=self.src_sequence_length)
+ return enc_output, enc_final_state
+
+ def _build_distribution(self, enc_final_state=None):
+ enc_hidden = [
+ layers.concat(
+ state, axis=-1) for state in enc_final_state
+ ]
+ enc_hidden = layers.concat(enc_hidden, axis=-1)
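+        # one fc projects the flattened encoder state to 2 * latent_size
+        # units: the first half is z_mean, the second half is z_log_var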
+ z_mean_log_var = layers.fc(input=enc_hidden,
+ size=self.latent_size * 2,
+ name='fc_dist')
+ z_mean, z_log_var = layers.split(z_mean_log_var, 2, -1)
+ return z_mean, z_log_var
+
+ def _build_decoder(self,
+ z_mean=None,
+ z_log_var=None,
+ enc_output=None,
+ mode='train',
+ beam_size=10):
+ dec_input = layers.dropout(
+ self.tar_emb,
+ dropout_prob=self.dec_dropout_in,
+ dropout_implementation="upscale_in_train")
+
+ # `output_layer` will be used within BeamSearchDecoder
+ output_layer = lambda x: layers.fc(x,
+ size=self.tar_vocab_size,
+ num_flatten_dims=len(x.shape) - 1,
+ name="output_w")
+
+ # `sample_output_layer` samples an id from the logits distribution instead of argmax(logits)
+ # it will be used within BeamSearchDecoder
+ sample_output_layer = lambda x: layers.unsqueeze(layers.one_hot(
+ layers.unsqueeze(
+ layers.sampling_id(
+ layers.softmax(
+ layers.squeeze(output_layer(x),[1])
+ ),dtype='int'), [1]),
+ depth=self.tar_vocab_size), [1])
+
+ if mode == 'train':
+ latent_z = self._sampling(z_mean, z_log_var)
+ else:
+ latent_z = layers.gaussian_random_batch_size_like(
+ self.tar, shape=[-1, self.latent_size])
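+        # project the latent code to the initial hidden and cell states of
+        # every decoder layer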
+ dec_first_hidden_cell = layers.fc(latent_z,
+ 2 * self.hidden_size *
+ self.num_layers,
+ name='fc_hc')
+ dec_first_hidden, dec_first_cell = layers.split(dec_first_hidden_cell,
+ 2)
+ if self.num_layers > 1:
+ dec_first_hidden = layers.split(dec_first_hidden, self.num_layers)
+ dec_first_cell = layers.split(dec_first_cell, self.num_layers)
+ else:
+ dec_first_hidden = [dec_first_hidden]
+ dec_first_cell = [dec_first_cell]
+ dec_initial_states = [[h, c]
+ for h, c in zip(dec_first_hidden, dec_first_cell)]
+ dec_cell = DecoderCell(self.num_layers, self.hidden_size, latent_z,
+ self.param_attr_initializer,
+ self.param_attr_scale, self.dec_dropout_out)
+
+ if mode == 'train':
+ dec_output, _ = rnn(cell=dec_cell,
+ inputs=dec_input,
+ initial_states=dec_initial_states,
+ sequence_length=self.tar_sequence_length)
+ dec_output = output_layer(dec_output)
+
+ return dec_output
+ elif mode == 'greedy':
+ start_token = 1
+ end_token = 2
+ max_length = 100
+ beam_search_decoder = BeamSearchDecoder(
+ dec_cell,
+ start_token,
+ end_token,
+ beam_size=1,
+ embedding_fn=self.tar_embeder,
+ output_fn=output_layer)
+ outputs, _ = dynamic_decode(
+ beam_search_decoder,
+ inits=dec_initial_states,
+ max_step_num=max_length)
+ return outputs
+
+ elif mode == 'sampling':
+ start_token = 1
+ end_token = 2
+ max_length = 100
+ beam_search_decoder = BeamSearchDecoder(
+ dec_cell,
+ start_token,
+ end_token,
+ beam_size=1,
+ embedding_fn=self.tar_embeder,
+ output_fn=sample_output_layer)
+
+ outputs, _ = dynamic_decode(
+ beam_search_decoder,
+ inits=dec_initial_states,
+ max_step_num=max_length)
+ return outputs
+        else:
+            raise ValueError("unsupported mode: " + mode)
+
+ def _sampling(self, z_mean, z_log_var):
+ """reparameterization trick
+ """
+ # by default, random_normal has mean=0 and std=1.0
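+        # z = z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon ~ N(0, I)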
+ epsilon = layers.gaussian_random_batch_size_like(
+ self.tar, shape=[-1, self.latent_size])
+ epsilon.stop_gradient = True
+ return z_mean + layers.exp(0.5 * z_log_var) * epsilon
+
+ def _kl_dvg(self, means, logvars):
+ """compute the KL divergence between Gaussian distribution
+ """
+ kl_cost = -0.5 * (logvars - fluid.layers.square(means) -
+ fluid.layers.exp(logvars) + 1.0)
+ kl_cost = fluid.layers.reduce_mean(kl_cost, 0)
+
+ return fluid.layers.reduce_sum(kl_cost)
+
+ def _compute_loss(self, mean, logvars, dec_output):
+
+ kl_loss = self._kl_dvg(mean, logvars)
+
+ rec_loss = layers.softmax_with_cross_entropy(
+ logits=dec_output, label=self.label, soft_label=False)
+
+ rec_loss = layers.reshape(rec_loss, shape=[self.batch_size, -1])
+
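+        # zero out the loss at padded positions, average it over the batch,
+        # then sum over time steps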
+ max_tar_seq_len = layers.shape(self.tar)[1]
+ tar_mask = layers.sequence_mask(
+ self.tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
+ rec_loss = rec_loss * tar_mask
+ rec_loss = layers.reduce_mean(rec_loss, dim=[0])
+ rec_loss = layers.reduce_sum(rec_loss)
+
+ loss = kl_loss * self.kl_weight + rec_loss
+
+ return loss, kl_loss, rec_loss
+
+ def _beam_search(self, enc_last_hidden, enc_last_cell):
+ pass
+
+ def build_graph(self, mode='train', beam_size=10):
+ if mode == 'train' or mode == 'eval':
+ self._build_data()
+            self._embedding()
+ enc_output, enc_final_state = self._build_encoder()
+ z_mean, z_log_var = self._build_distribution(enc_final_state)
+ dec_output = self._build_decoder(z_mean, z_log_var, enc_output)
+
+ loss, kl_loss, rec_loss = self._compute_loss(z_mean, z_log_var,
+ dec_output)
+ return loss, kl_loss, rec_loss
+
+ elif mode == "sampling" or mode == 'greedy':
+ self._build_data(mode)
+            self._embedding(mode)
+ dec_output = self._build_decoder(mode=mode, beam_size=1)
+
+ return dec_output
+        else:
+            raise ValueError("unsupported mode: " + mode)
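+
+
+# A minimal usage sketch (the hyper-parameters are illustrative; see args.py
+# and train.py for the real wiring):
+#
+#     model = VAE(hidden_size=256, latent_size=32, src_vocab_size=10003,
+#                 tar_vocab_size=10003, batch_size=32, num_layers=1)
+#     loss, kl_loss, rec_loss = model.build_graph(mode='train')
+#     # or, for generation:
+#     samples = model.build_graph(mode='sampling', beam_size=1)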
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/reader.py b/PaddleNLP/PaddleTextGEN/vae_text/reader.py
new file mode 100644
index 00000000..86bb6afd
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/reader.py
@@ -0,0 +1,206 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utilities for parsing PTB text files."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import os
+import sys
+import numpy as np
+import io
+from collections import Counter
+Py3 = sys.version_info[0] == 3
+
+if Py3:
+ line_tok = '\n'
+else:
+ line_tok = u'\n'
+
+PAD_ID = 0
+BOS_ID = 1
+EOS_ID = 2
+UNK_ID = 3
+
+
+def _read_words(filename):
+    # io.open already decodes to unicode on both Python 2 and 3
+    with io.open(filename, "r", encoding='utf-8') as f:
+        return f.read().replace("\n", "").split()
+
+
+def read_all_line(filename):
+ data = []
+ with io.open(filename, "r", encoding='utf-8') as f:
+ for line in f.readlines():
+ data.append(line.strip())
+ return data
+
+
+def _vocab(vocab_file, train_file, max_vocab_cnt):
+    lines = read_all_line(train_file)
+
+    all_words = []
+    for line in lines:
+        all_words.extend(line.split())
+    vocab_count = Counter(all_words).most_common()
+    # a non-positive max_vocab_cnt keeps the full vocabulary
+    if max_vocab_cnt > 0:
+        vocab_count = vocab_count[:max_vocab_cnt]
+    with io.open(vocab_file, "w", encoding='utf-8') as f:
+        for voc, fre in vocab_count:
+            f.write(voc)
+            f.write(line_tok)
+
+
+def _build_vocab(vocab_file, train_file=None, max_vocab_cnt=-1):
+ if not os.path.exists(vocab_file):
+ _vocab(vocab_file, train_file, max_vocab_cnt)
+ vocab_dict = {"": 0, "": 1, "": 2, "": 3}
+ ids = 4
+ with io.open(vocab_file, "r", encoding='utf-8') as f:
+ for line in f.readlines():
+ vocab_dict[line.strip()] = ids
+ ids += 1
+ # rev_vocab = {value:key for key, value in vocab_dict.items()}
+ print("vocab word num", ids)
+
+ return vocab_dict
+
+
+def _para_file_to_ids(src_file, src_vocab):
+
+ src_data = []
+ with io.open(src_file, "r", encoding='utf-8') as f_src:
+ for line in f_src.readlines():
+ arra = line.strip().split()
+ ids = [BOS_ID]
+ ids.extend(
+ [src_vocab[w] if w in src_vocab else UNK_ID for w in arra])
+ ids.append(EOS_ID)
+ src_data.append(ids)
+
+ return src_data
+
+
+def filter_len(src, max_sequence_len=128):
+ new_src = []
+
+ for id1 in src:
+ if len(id1) > max_sequence_len:
+ id1 = id1[:max_sequence_len]
+
+ new_src.append(id1)
+
+ return new_src
+
+
+def raw_data(dataset_prefix, max_sequence_len=50, max_vocab_cnt=-1):
+
+ src_vocab_file = dataset_prefix + ".vocab.txt"
+ src_train_file = dataset_prefix + ".train.txt"
+ src_eval_file = dataset_prefix + ".valid.txt"
+ src_test_file = dataset_prefix + ".test.txt"
+
+ src_vocab = _build_vocab(src_vocab_file, src_train_file, max_vocab_cnt)
+
+ train_src = _para_file_to_ids(src_train_file, src_vocab)
+ train_src = filter_len(train_src, max_sequence_len=max_sequence_len)
+ eval_src = _para_file_to_ids(src_eval_file, src_vocab)
+
+ test_src = _para_file_to_ids(src_test_file, src_vocab)
+
+ return train_src, eval_src, test_src, src_vocab
+
+
+def get_vocab(dataset_prefix, max_sequence_len=50):
+ src_vocab_file = dataset_prefix + ".vocab.txt"
+ src_vocab = _build_vocab(src_vocab_file)
+ rev_vocab = {}
+ for key, value in src_vocab.items():
+ rev_vocab[value] = key
+
+ return src_vocab, rev_vocab
+
+
+def raw_mono_data(vocab_file, file_path):
+
+    src_vocab = _build_vocab(vocab_file)
+    # monolingual data: source and target come from the same file
+    test_src = _para_file_to_ids(file_path, src_vocab)
+
+    return (test_src, test_src)
+
+
+def get_data_iter(raw_data,
+ batch_size,
+ sort_cache=False,
+ cache_num=1,
+ mode='train',
+ enable_ce=False):
+
+ src_data = raw_data
+
+ data_len = len(src_data)
+
+ index = np.arange(data_len)
+ if mode == "train" and not enable_ce:
+ np.random.shuffle(index)
+
+ def to_pad_np(data):
+ max_len = 0
+ for ele in data:
+ if len(ele) > max_len:
+ max_len = len(ele)
+
+ ids = np.ones(
+ (batch_size, max_len), dtype='int64') * PAD_ID # PAD_ID = 0
+ mask = np.zeros((batch_size), dtype='int32')
+
+ for i, ele in enumerate(data):
+ ids[i, :len(ele)] = ele
+ mask[i] = len(ele)
+
+ return ids, mask
+
+ b_src = []
+
+ if mode != "train":
+ cache_num = 1
+ for j in range(data_len):
+ if len(b_src) == batch_size * cache_num:
+            # sort the cached samples by length so per-batch padding is small
+            new_cache = sorted(
+                b_src, key=lambda k: len(k)) if sort_cache else b_src
+            for i in range(cache_num):
+                batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
+                src_ids, src_mask = to_pad_np(batch_data)
+                yield (src_ids, src_mask)
+
+ b_src = []
+
+ b_src.append(src_data[index[j]])
+
+ if len(b_src) > 0:
+        new_cache = sorted(
+            b_src, key=lambda k: len(k)) if sort_cache else b_src
+        for i in range(0, len(b_src), batch_size):
+            # i already steps by batch_size, so slice directly
+            batch_data = new_cache[i:i + batch_size]
+            src_ids, src_mask = to_pad_np(batch_data)
+            yield (src_ids, src_mask)
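+
+
+# A minimal usage sketch (the dataset prefix is illustrative; download.py
+# places the PTB files under data/ptb/):
+#
+#     train, valid, test, vocab = raw_data('data/ptb/ptb', max_sequence_len=50)
+#     for src_ids, src_mask in get_data_iter(train, batch_size=32):
+#         pass  # src_ids: [batch, max_len] token ids, src_mask: [batch] lengths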
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/run.sh b/PaddleNLP/PaddleTextGEN/vae_text/run.sh
new file mode 100644
index 00000000..af0db8a1
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/run.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+set -x
+export CUDA_VISIBLE_DEVICES=0
+dataset=$1
+python train.py \
+ --vocab_size 10003 \
+ --batch_size 32 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --dataset_prefix data/${dataset}/${dataset} \
+    --model_path ${dataset}_model \
+    --use_gpu True \
+    --max_epoch 200
diff --git a/PaddleNLP/PaddleTextGEN/vae_text/train.py b/PaddleNLP/PaddleTextGEN/vae_text/train.py
new file mode 100644
index 00000000..08b9e1ae
--- /dev/null
+++ b/PaddleNLP/PaddleTextGEN/vae_text/train.py
@@ -0,0 +1,313 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import time
+import os
+import random
+import math
+import contextlib
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+import paddle.fluid.profiler as profiler
+from paddle.fluid.executor import Executor
+import reader
+
+import sys
+if sys.version_info[0] == 2:
+    reload(sys)
+    sys.setdefaultencoding("utf-8")
+
+from args import *
+from model import VAE
+import logging
+import pickle
+
+
+@contextlib.contextmanager
+def profile_context(profile=True):
+ if profile:
+ with profiler.profiler('All', 'total', 'seq2seq.profile'):
+ yield
+ else:
+ yield
+
+
+def main():
+ args = parse_args()
+ print(args)
+ num_layers = args.num_layers
+ src_vocab_size = args.vocab_size
+ tar_vocab_size = args.vocab_size
+ batch_size = args.batch_size
+ init_scale = args.init_scale
+ max_grad_norm = args.max_grad_norm
+ hidden_size = args.hidden_size
+ attr_init = args.attr_init
+ latent_size = 32
+
+ main_program = fluid.Program()
+ startup_program = fluid.Program()
+ if args.enable_ce:
+        # seed the programs that are actually used under program_guard below
+        main_program.random_seed = 123
+        startup_program.random_seed = 123
+
+ # Training process
+ with fluid.program_guard(main_program, startup_program):
+ with fluid.unique_name.guard():
+ model = VAE(hidden_size,
+ latent_size,
+ src_vocab_size,
+ tar_vocab_size,
+ batch_size,
+ num_layers=num_layers,
+ init_scale=init_scale,
+ attr_init=attr_init)
+
+ loss, kl_loss, rec_loss = model.build_graph()
+            # under program_guard the default main program is main_program;
+            # clone it to build the validation program
+            inference_program = main_program.clone(for_test=True)
+
+ fluid.clip.set_gradient_clip(
+ clip=fluid.clip.GradientClipByGlobalNorm(
+ clip_norm=max_grad_norm))
+
+ learning_rate = fluid.layers.create_global_var(
+ name="learning_rate",
+ shape=[1],
+ value=float(args.learning_rate),
+ dtype="float32",
+ persistable=True)
+
+ opt_type = args.optimizer
+ if opt_type == "sgd":
+ optimizer = fluid.optimizer.SGD(learning_rate)
+ elif opt_type == "adam":
+ optimizer = fluid.optimizer.Adam(learning_rate)
+ else:
+ print("only support [sgd|adam]")
+ raise Exception("opt type not support")
+
+ optimizer.minimize(loss)
+
+ place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
+ exe = Executor(place)
+ exe.run(startup_program)
+
+ train_program = fluid.compiler.CompiledProgram(main_program)
+
+ dataset_prefix = args.dataset_prefix
+ print("begin to load data")
+ raw_data = reader.raw_data(dataset_prefix, args.max_len)
+ print("finished load data")
+ train_data, valid_data, test_data, _ = raw_data
+
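+    # KL annealing: per-batch increment that moves the KL weight linearly
+    # from kl_start toward 1.0 over the first warm_up epochs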
+ anneal_r = 1.0 / (args.warm_up * len(train_data) / args.batch_size)
+
+ def prepare_input(batch, kl_weight=1.0, lr=None):
+ src_ids, src_mask = batch
+ res = {}
+ src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
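+        # teacher forcing: the decoder input drops the last token and the
+        # label drops the first, shifting the sequence by one step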
+ in_tar = src_ids[:, :-1]
+ label_tar = src_ids[:, 1:]
+
+ in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
+ label_tar = label_tar.reshape(
+ (label_tar.shape[0], label_tar.shape[1], 1))
+
+ res['src'] = src_ids
+ res['tar'] = in_tar
+ res['label'] = label_tar
+ res['src_sequence_length'] = src_mask
+ res['tar_sequence_length'] = src_mask - 1
+ res['kl_weight'] = np.array([kl_weight]).astype(np.float32)
+ if lr is not None:
+ res['learning_rate'] = np.array([lr]).astype(np.float32)
+
+ return res, np.sum(src_mask), np.sum(src_mask - 1)
+
+    # evaluate average nll and ppl over a dataset
+ def eval(data):
+ eval_data_iter = reader.get_data_iter(data, batch_size, mode='eval')
+ total_loss = 0.0
+ word_count = 0.0
+ batch_count = 0.0
+ for batch_id, batch in enumerate(eval_data_iter):
+ input_data_feed, src_word_num, dec_word_sum = prepare_input(batch)
+ fetch_outs = exe.run(inference_program,
+ feed=input_data_feed,
+ fetch_list=[loss.name],
+ use_program_cache=False)
+
+ cost_train = np.array(fetch_outs[0])
+
+ total_loss += cost_train * batch_size
+ word_count += dec_word_sum
+ batch_count += batch_size
+
+ nll = total_loss / batch_count
+ ppl = np.exp(total_loss / word_count)
+
+ return nll, ppl
+
+ def train():
+ ce_time = []
+ ce_ppl = []
+ max_epoch = args.max_epoch
+ kl_w = args.kl_start
+ lr_w = args.learning_rate
+        best_valid_nll = float('inf')
+        best_nll, best_ppl = float('inf'), float('inf')
+        best_epoch_id = -1
+ decay_cnt = 0
+ max_decay = args.max_decay
+ decay_factor = 0.5
+ decay_ts = 2
+ steps_not_improved = 0
+ for epoch_id in range(max_epoch):
+ start_time = time.time()
+ if args.enable_ce:
+ train_data_iter = reader.get_data_iter(
+ train_data,
+ batch_size,
+ args.sort_cache,
+ args.cache_num,
+ enable_ce=True)
+ else:
+ train_data_iter = reader.get_data_iter(
+ train_data, batch_size, args.sort_cache, args.cache_num)
+
+ total_loss = 0
+ total_rec_loss = 0
+ total_kl_loss = 0
+ word_count = 0.0
+ batch_count = 0.0
+ batch_times = []
+ batch_start_time = time.time()
+ for batch_id, batch in enumerate(train_data_iter):
+ kl_w = min(1.0, kl_w + anneal_r)
+ kl_weight = kl_w
+ input_data_feed, src_word_num, dec_word_sum = prepare_input(
+ batch, kl_weight, lr_w)
+ fetch_outs = exe.run(
+ program=train_program,
+ feed=input_data_feed,
+ fetch_list=[loss.name, kl_loss.name, rec_loss.name],
+ use_program_cache=False)
+
+ cost_train = np.array(fetch_outs[0])
+ kl_cost_train = np.array(fetch_outs[1])
+ rec_cost_train = np.array(fetch_outs[2])
+
+ total_loss += cost_train * batch_size
+ total_rec_loss += rec_cost_train * batch_size
+ total_kl_loss += kl_cost_train * batch_size
+ word_count += dec_word_sum
+ batch_count += batch_size
+ batch_end_time = time.time()
+ batch_time = batch_end_time - batch_start_time
+ batch_times.append(batch_time)
+
+ if batch_id > 0 and batch_id % 200 == 0:
+ print("-- Epoch:[%d]; Batch:[%d]; Time: %.4f s; "
+ "kl_weight: %.4f; kl_loss: %.4f; rec_loss: %.4f; "
+ "nll: %.4f; ppl: %.4f" %
+ (epoch_id, batch_id, batch_time, kl_w, total_kl_loss /
+ batch_count, total_rec_loss / batch_count, total_loss
+ / batch_count, np.exp(total_loss / word_count)))
+ ce_ppl.append(np.exp(total_loss / word_count))
+
+ end_time = time.time()
+ epoch_time = end_time - start_time
+ ce_time.append(epoch_time)
+ print(
+ "\nTrain epoch:[%d]; Epoch Time: %.4f; avg_time: %.4f s/step\n"
+ % (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
+
+ val_nll, val_ppl = eval(valid_data)
+ print("dev ppl", val_ppl)
+ test_nll, test_ppl = eval(test_data)
+ print("test ppl", test_ppl)
+
+ if val_nll < best_valid_nll:
+ best_valid_nll = val_nll
+ steps_not_improved = 0
+ best_nll = test_nll
+ best_ppl = test_ppl
+ best_epoch_id = epoch_id
+ dir_name = os.path.join(args.model_path,
+ "epoch_" + str(best_epoch_id))
+ print("save model {}".format(dir_name))
+ fluid.io.save_params(exe, dir_name, main_program)
+ else:
+ steps_not_improved += 1
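+                # after decay_ts epochs without improvement, halve the
+                # learning rate and roll back to the best checkpoint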
+ if steps_not_improved == decay_ts:
+ old_lr = lr_w
+ lr_w *= decay_factor
+ steps_not_improved = 0
+ new_lr = lr_w
+
+ print('-----\nchange lr, old lr: %f, new lr: %f\n-----' %
+ (old_lr, new_lr))
+
+ dir_name = args.model_path + "/epoch_" + str(best_epoch_id)
+ fluid.io.load_params(exe, dir_name)
+
+ decay_cnt += 1
+ if decay_cnt == max_decay:
+ break
+
+ print('\nbest testing nll: %.4f, best testing ppl %.4f\n' %
+ (best_nll, best_ppl))
+
+ with profile_context(args.profile):
+ train()
+
+
+def get_cards():
+ num = 0
+ cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
+ if cards != '':
+ num = len(cards.split(","))
+ return num
+
+
+def check_version():
+    """
+    Log an error and exit when the installed version of PaddlePaddle
+    does not satisfy the requirement.
+    """
+    err = "PaddlePaddle version 1.6 or higher is required, " \
+          "or a suitable develop version is satisfied as well.\n" \
+          "Please make sure your installed version matches the code."
+
+    try:
+        fluid.require_version('1.6.0')
+    except Exception:
+        logging.error(err)
+        sys.exit(1)
+
+
+if __name__ == '__main__':
+ check_version()
+ main()
--
GitLab