Commit b034390c authored by 1024的传说, committed by Guo Sheng

change rnn_search to seq2seq; change vae_text to variational seq2seq (#3734)

* re-commit seq2seq example: rnn search

* add new_seq2seq examples: rnn_search and vae_text

* modify dataset description in vae_text/README.md

* modify vae_text/README.md and vae_text/download.py

* update vae_text/infer.sh and vae_text/README.md

* add multi-gpu support for rnn_search; fix windows support for rnn_search and vae_text

* remove parallel argument in args.py

* remove hard codes

* remove hard codes

* update download.py for windows download

* fix model path for windows in vae_text/infer.py

* fix python2 utf-encoding

* change rnn_search to seq2seq; change vae_text to variational seq2seq

* change rnn_search to seq2seq; change vae_text to variational seq2seq

* mv variational seq2seq to variational_seq2seq
Parent 224fe10d
Running the example model in this directory requires PaddlePaddle Fluid 1.6. If your installed PaddlePaddle is older than this, please update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start).
# Sequence to Sequence (Seq2Seq)
The following is a brief directory structure and description of this example:
```
.
├── README.md            # Documentation (this file)
├── args.py              # Configuration of training, inference and model parameters
├── reader.py            # Data reading utilities
├── download.py          # Data download script
├── train.py             # Main training program
├── infer.py             # Main inference program
├── run.sh               # Launch script with the default training configuration
├── infer.sh             # Decoding script with the default configuration
├── attention_model.py   # Translation model with attention
└── base_model.py        # Translation model without attention
```
## Introduction
Sequence to Sequence (Seq2Seq) uses an encoder-decoder architecture: the encoder encodes the source sequence into a vector, and the decoder then decodes that vector into the target sequence. Seq2Seq is widely used in machine translation, dialogue systems, automatic document summarization, image captioning, and similar tasks.
This directory contains a classic Seq2Seq example, machine translation, with both a base model (without attention) and a model with attention. A Seq2Seq translation model mimics how a human translates: first parse the source sentence and understand its meaning, then write out the target-language sentence according to that meaning. For more on the principles and mathematics behind machine translation, we recommend [Deep Learning 101](http://paddlepaddle.org/documentation/docs/zh/1.2/beginners_guide/basics/machine_translation/index.html).
**This directory demonstrates the Seq2Seq API newly added in Paddle Fluid 1.6.** The new Seq2Seq API makes network construction simpler, and the low-level API is no longer recommended as of 1.6. If you do need the low-level API to implement your own model, see the 1.5 example [RNN Search](https://github.com/PaddlePaddle/models/tree/release/1.5/PaddleNLP/unarchived/neural_machine_translation/rnn_search).
## Model overview
In this model, the encoder is a multi-layer LSTM-based RNN; the decoder is an RNN decoder with attention, and a decoder without attention is also provided for comparison. At inference time, beam search is used to generate the target sentence. These methods are each introduced below.
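As a rough illustration of how these pieces fit together with the Fluid 1.6 API used in this directory (`LSTMCell`, `rnn`, `BeamSearchDecoder`, `dynamic_decode`), here is a minimal sketch; the sizes, token ids, and variable names are illustrative placeholders, not the values used by the actual scripts:
```
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layers import LSTMCell, rnn, BeamSearchDecoder, dynamic_decode

hidden_size, vocab_size, beam_size = 512, 7709, 10  # illustrative sizes

# Encoder: unroll an LSTM cell over an (already embedded) source sequence.
src_emb = fluid.data(
    name="src_emb", shape=[None, None, hidden_size], dtype="float32")
enc_output, enc_final_state = rnn(cell=LSTMCell(hidden_size), inputs=src_emb)

# Decoder: wrap a cell in BeamSearchDecoder and unroll it with dynamic_decode,
# starting from the encoder's final state.
decoder = BeamSearchDecoder(
    LSTMCell(hidden_size),
    start_token=1,  # <s>
    end_token=2,    # </s>
    beam_size=beam_size,
    embedding_fn=lambda ids: fluid.embedding(
        ids, size=[vocab_size, hidden_size]),
    output_fn=lambda x: layers.fc(
        x, size=vocab_size, num_flatten_dims=len(x.shape) - 1))
outputs, _ = dynamic_decode(decoder, inits=enc_final_state, max_step_num=100)
```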
## Data
This tutorial uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) as the training corpus, with tst2012 as the dev set and tst2013 as the test set.
### Getting the data
```
python download.py
```
## Training
`run.sh` wraps the main training entry point. To start training with the default parameters, simply run:
```
sh run.sh
```
The RNN model with attention is used by default; to train the model without attention, set --attention to False:
```
python train.py \
--src_lang en --tar_lang vi \
--attention True \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--train_data_prefix data/en-vi/train \
--eval_data_prefix data/en-vi/tst2012 \
--test_data_prefix data/en-vi/tst2013 \
--vocab_prefix data/en-vi/vocab \
--use_gpu True \
--model_path ./attention_models
```
The training program saves the model once at the end of each epoch.
## Inference
After training completes, you can run inference with the infer.sh script. By default it decodes the test set with beam search, loading the model saved after epoch 10:
```
sh infer.sh
```
To run inference on another data file, just change the --infer_file argument:
```
python infer.py \
--attention True \
--src_lang en --tar_lang vi \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--vocab_prefix data/en-vi/vocab \
--infer_file data/en-vi/tst2013.en \
--reload_model attention_models/epoch_10/ \
--infer_output_file attention_infer_output/infer_output.txt \
--beam_size 10 \
--use_gpu True
```
## Evaluation
Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) tool to evaluate the translation quality of the model's predictions, as follows:
```sh
mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
```
Results for a single model with beam_size = 10:
```
> no attention
tst2012 BLEU: 10.99
tst2013 BLEU: 11.23
> with attention
tst2012 BLEU: 22.85
tst2013 BLEU: 25.68
```
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--train_data_prefix", type=str, help="file prefix for train data")
parser.add_argument(
"--eval_data_prefix", type=str, help="file prefix for eval data")
parser.add_argument(
"--test_data_prefix", type=str, help="file prefix for test data")
parser.add_argument(
"--vocab_prefix", type=str, help="file prefix for vocab")
parser.add_argument("--src_lang", type=str, help="source language suffix")
parser.add_argument("--tar_lang", type=str, help="target language suffix")
parser.add_argument(
"--attention",
type=eval,
default=False,
help="Whether use attention model")
parser.add_argument(
"--optimizer",
type=str,
default='adam',
help="optimizer to use, only supprt[sgd|adam]")
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="learning rate for optimizer")
parser.add_argument(
"--num_layers",
type=int,
default=1,
help="layers number of encoder and decoder")
parser.add_argument(
"--hidden_size",
type=int,
default=100,
help="hidden size of encoder and decoder")
parser.add_argument("--src_vocab_size", type=int, help="source vocab size")
parser.add_argument("--tar_vocab_size", type=int, help="target vocab size")
parser.add_argument(
"--batch_size", type=int, help="batch size of each step")
parser.add_argument(
"--max_epoch", type=int, default=12, help="max epoch for the training")
parser.add_argument(
"--max_len",
type=int,
default=50,
help="max length for source and target sentence")
parser.add_argument(
"--dropout", type=float, default=0.0, help="drop probability")
parser.add_argument(
"--init_scale",
type=float,
default=0.0,
help="init scale for parameter")
parser.add_argument(
"--max_grad_norm",
type=float,
default=5.0,
help="max grad norm for global norm clip")
parser.add_argument(
"--model_path",
type=str,
default='model',
help="model path for model to save")
parser.add_argument(
"--reload_model", type=str, help="reload model to inference")
parser.add_argument(
"--infer_file", type=str, help="file name for inference")
parser.add_argument(
"--infer_output_file",
type=str,
default='infer_output',
help="file name for inference output")
parser.add_argument(
"--beam_size", type=int, default=10, help="file name for inference")
parser.add_argument(
'--use_gpu',
type=eval,
default=False,
        help='Whether to use gpu [True|False]')
parser.add_argument(
"--enable_ce",
action='store_true',
help="The flag indicating whether to run the task "
"for continuous evaluation.")
parser.add_argument(
"--profile", action='store_true', help="Whether enable the profile.")
args = parser.parse_args()
return args
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
from paddle.fluid import ParamAttr
from paddle.fluid.contrib.layers import basic_lstm, BasicLSTMUnit
from base_model import BaseModel, DecoderCell
from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
INF = 1. * 1e5
alpha = 0.6
class AttentionDecoderCell(DecoderCell):
def __init__(self, num_layers, hidden_size, dropout_prob=0.,
init_scale=0.1):
super(AttentionDecoderCell, self).__init__(num_layers, hidden_size,
dropout_prob, init_scale)
def attention(self, query, enc_output, mask=None):
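        # Luong-style dot-product attention: project the encoder outputs, score
        # them against the current decoder query, mask out padding positions,
        # and return the attention-weighted memory.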
query = layers.unsqueeze(query, [1])
memory = layers.fc(enc_output,
self.hidden_size,
num_flatten_dims=2,
param_attr=ParamAttr(
initializer=fluid.initializer.UniformInitializer(
low=-self.init_scale, high=self.init_scale)),
bias_attr=False)
attn = layers.matmul(query, memory, transpose_y=True)
if mask:
attn = layers.transpose(attn, [1, 0, 2])
attn = layers.elementwise_add(attn, mask * 1000000000, -1)
attn = layers.transpose(attn, [1, 0, 2])
weight = layers.softmax(attn)
weight_memory = layers.matmul(weight, memory)
return weight_memory
def call(self, step_input, states, enc_output, enc_padding_mask=None):
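        # Input feeding: the previous attentional vector (input_feed) is
        # concatenated with the current step input before the stacked LSTMs run,
        # and the new attentional vector is carried along in the returned state.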
lstm_states, input_feed = states
new_lstm_states = []
step_input = layers.concat([step_input, input_feed], 1)
for i in range(self.num_layers):
out, new_lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
step_input = layers.dropout(
out,
self.dropout_prob,
dropout_implementation='upscale_in_train'
) if self.dropout_prob > 0 else out
new_lstm_states.append(new_lstm_state)
dec_att = self.attention(step_input, enc_output, enc_padding_mask)
dec_att = layers.squeeze(dec_att, [1])
concat_att_out = layers.concat([dec_att, step_input], 1)
out = layers.fc(concat_att_out,
self.hidden_size,
param_attr=ParamAttr(
initializer=fluid.initializer.UniformInitializer(
low=-self.init_scale, high=self.init_scale)),
bias_attr=False)
return out, [new_lstm_states, out]
class AttentionModel(BaseModel):
def __init__(self,
hidden_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=1,
init_scale=0.1,
dropout=None,
beam_start_token=1,
beam_end_token=2,
beam_max_step_num=100):
        super(AttentionModel, self).__init__(
            hidden_size,
            src_vocab_size,
            tar_vocab_size,
            batch_size,
            num_layers=num_layers,
            init_scale=init_scale,
            dropout=dropout,
            beam_start_token=beam_start_token,
            beam_end_token=beam_end_token,
            beam_max_step_num=beam_max_step_num)
def _build_decoder(self, enc_final_state, mode='train', beam_size=10):
output_layer = lambda x: layers.fc(x,
size=self.tar_vocab_size,
num_flatten_dims=len(x.shape) - 1,
param_attr=fluid.ParamAttr(name="output_w",
initializer=fluid.initializer.UniformInitializer(low=-self.init_scale, high=self.init_scale)),
bias_attr=False)
dec_cell = AttentionDecoderCell(self.num_layers, self.hidden_size,
self.dropout, self.init_scale)
dec_initial_states = [
enc_final_state, dec_cell.get_initial_states(
batch_ref=self.enc_output, shape=[self.hidden_size])
]
max_src_seq_len = layers.shape(self.src)[1]
src_mask = layers.sequence_mask(
self.src_sequence_length, maxlen=max_src_seq_len, dtype='float32')
enc_padding_mask = (src_mask - 1.0)
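        # src_mask is 1 for real tokens and 0 for padding, so enc_padding_mask
        # is 0/-1; scaled inside attention(), it pushes the scores of padded
        # positions to -1e9 before the softmax.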
if mode == 'train':
dec_output, _ = rnn(cell=dec_cell,
inputs=self.tar_emb,
initial_states=dec_initial_states,
sequence_length=None,
enc_output=self.enc_output,
enc_padding_mask=enc_padding_mask)
dec_output = output_layer(dec_output)
elif mode == 'beam_search':
output_layer = lambda x: layers.fc(x,
size=self.tar_vocab_size,
num_flatten_dims=len(x.shape) - 1,
param_attr=fluid.ParamAttr(name="output_w"),
bias_attr=False)
beam_search_decoder = BeamSearchDecoder(
dec_cell,
self.beam_start_token,
self.beam_end_token,
beam_size,
embedding_fn=self.tar_embeder,
output_fn=output_layer)
enc_output = beam_search_decoder.tile_beam_merge_with_batch(
self.enc_output, beam_size)
enc_padding_mask = beam_search_decoder.tile_beam_merge_with_batch(
enc_padding_mask, beam_size)
outputs, _ = dynamic_decode(
beam_search_decoder,
inits=dec_initial_states,
max_step_num=self.beam_max_step_num,
enc_output=enc_output,
enc_padding_mask=enc_padding_mask)
            return outputs
        else:
            raise ValueError("unsupported mode: " + mode)
        return dec_output
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
from paddle.fluid import ParamAttr
from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
INF = 1. * 1e5
alpha = 0.6
uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x)
zero_constant = fluid.initializer.Constant(0.0)
class EncoderCell(RNNCell):
def __init__(
self,
num_layers,
hidden_size,
dropout_prob=0.,
init_scale=0.1, ):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.lstm_cells = []
param_attr = ParamAttr(initializer=uniform_initializer(init_scale))
bias_attr = ParamAttr(initializer=zero_constant)
for i in range(num_layers):
self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
def call(self, step_input, states):
new_states = []
for i in range(self.num_layers):
out, new_state = self.lstm_cells[i](step_input, states[i])
step_input = layers.dropout(
out,
self.dropout_prob,
dropout_implementation="upscale_in_train"
) if self.dropout_prob > 0. else out
new_states.append(new_state)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.lstm_cells]
class DecoderCell(RNNCell):
def __init__(self, num_layers, hidden_size, dropout_prob=0.,
init_scale=0.1):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.lstm_cells = []
self.init_scale = init_scale
param_attr = ParamAttr(initializer=uniform_initializer(init_scale))
bias_attr = ParamAttr(initializer=zero_constant)
for i in range(num_layers):
self.lstm_cells.append(LSTMCell(hidden_size, param_attr, bias_attr))
def call(self, step_input, states):
new_lstm_states = []
for i in range(self.num_layers):
out, new_lstm_state = self.lstm_cells[i](step_input, states[i])
step_input = layers.dropout(
out,
self.dropout_prob,
dropout_implementation="upscale_in_train"
) if self.dropout_prob > 0. else out
new_lstm_states.append(new_lstm_state)
return step_input, new_lstm_states
class BaseModel(object):
def __init__(self,
hidden_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=1,
init_scale=0.1,
dropout=None,
beam_start_token=1,
beam_end_token=2,
beam_max_step_num=100):
self.hidden_size = hidden_size
self.src_vocab_size = src_vocab_size
self.tar_vocab_size = tar_vocab_size
self.batch_size = batch_size
self.num_layers = num_layers
self.init_scale = init_scale
self.dropout = dropout
self.beam_start_token = beam_start_token
self.beam_end_token = beam_end_token
self.beam_max_step_num = beam_max_step_num
self.src_embeder = lambda x: fluid.embedding(
input=x,
size=[self.src_vocab_size, self.hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
name='source_embedding',
initializer=uniform_initializer(init_scale)))
self.tar_embeder = lambda x: fluid.embedding(
input=x,
size=[self.tar_vocab_size, self.hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
name='target_embedding',
initializer=uniform_initializer(init_scale)))
def _build_data(self):
self.src = fluid.data(name="src", shape=[None, None], dtype='int64')
self.src_sequence_length = fluid.data(
name="src_sequence_length", shape=[None], dtype='int32')
self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64')
self.tar_sequence_length = fluid.data(
name="tar_sequence_length", shape=[None], dtype='int32')
self.label = fluid.data(
name="label", shape=[None, None, 1], dtype='int64')
    def _embedding(self):
self.src_emb = self.src_embeder(self.src)
self.tar_emb = self.tar_embeder(self.tar)
def _build_encoder(self):
enc_cell = EncoderCell(self.num_layers, self.hidden_size, self.dropout,
self.init_scale)
self.enc_output, enc_final_state = rnn(
cell=enc_cell,
inputs=self.src_emb,
sequence_length=self.src_sequence_length)
return self.enc_output, enc_final_state
def _build_decoder(self, enc_final_state, mode='train', beam_size=10):
dec_cell = DecoderCell(self.num_layers, self.hidden_size, self.dropout,
self.init_scale)
output_layer = lambda x: layers.fc(x,
size=self.tar_vocab_size,
num_flatten_dims=len(x.shape) - 1,
param_attr=fluid.ParamAttr(name="output_w",
initializer=uniform_initializer(self.init_scale)),
bias_attr=False)
if mode == 'train':
dec_output, dec_final_state = rnn(cell=dec_cell,
inputs=self.tar_emb,
initial_states=enc_final_state)
dec_output = output_layer(dec_output)
return dec_output
elif mode == 'beam_search':
beam_search_decoder = BeamSearchDecoder(
dec_cell,
self.beam_start_token,
self.beam_end_token,
beam_size,
embedding_fn=self.tar_embeder,
output_fn=output_layer)
outputs, _ = dynamic_decode(
beam_search_decoder,
inits=enc_final_state,
max_step_num=self.beam_max_step_num)
return outputs
def _compute_loss(self, dec_output):
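        # Cross-entropy is computed per token, masked by the target lengths,
        # averaged over the batch dimension and summed over time, so the
        # training loss is a per-sentence negative log-likelihood.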
loss = layers.softmax_with_cross_entropy(
logits=dec_output, label=self.label, soft_label=False)
        loss = layers.squeeze(loss, axes=[2])
max_tar_seq_len = layers.shape(self.tar)[1]
tar_mask = layers.sequence_mask(
self.tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
loss = loss * tar_mask
loss = layers.reduce_mean(loss, dim=[0])
loss = layers.reduce_sum(loss)
return loss
def _beam_search(self, enc_last_hidden, enc_last_cell):
pass
def build_graph(self, mode='train', beam_size=10):
if mode == 'train' or mode == 'eval':
self._build_data()
            self._embedding()
enc_output, enc_final_state = self._build_encoder()
dec_output = self._build_decoder(enc_final_state)
loss = self._compute_loss(dec_output)
return loss
elif mode == "beam_search" or mode == 'greedy_search':
self._build_data()
            self._embedding()
enc_output, enc_final_state = self._build_encoder()
dec_output = self._build_decoder(
enc_final_state, mode=mode, beam_size=beam_size)
return dec_output
        else:
            raise ValueError("unsupported mode: " + mode)
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Script for downloading training data.
'''
import os
import sys
import urllib

if sys.version_info >= (3, 0):
    import urllib.request

URLLIB = urllib
if sys.version_info >= (3, 0):
    URLLIB = urllib.request
remote_path = 'https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi'
base_path = 'data'
tar_path = os.path.join(base_path, 'en-vi')
filenames = [
'train.en', 'train.vi', 'tst2012.en', 'tst2012.vi', 'tst2013.en',
'tst2013.vi', 'vocab.en', 'vocab.vi'
]
def main(arguments):
print("Downloading data......")
if not os.path.exists(tar_path):
if not os.path.exists(base_path):
os.mkdir(base_path)
os.mkdir(tar_path)
for filename in filenames:
url = remote_path + '/' + filename
tar_file = os.path.join(tar_path, filename)
URLLIB.urlretrieve(url, tar_file)
print("Downloaded sucess......")
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import time
import os
import random
import logging
import math
import io
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.executor import Executor
import reader
import sys
line_tok = '\n'
space_tok = ' '
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
line_tok = u'\n'
space_tok = u' '
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
from args import *
import pickle
from attention_model import AttentionModel
from base_model import BaseModel
def infer():
args = parse_args()
num_layers = args.num_layers
src_vocab_size = args.src_vocab_size
tar_vocab_size = args.tar_vocab_size
batch_size = args.batch_size
dropout = args.dropout
init_scale = args.init_scale
max_grad_norm = args.max_grad_norm
hidden_size = args.hidden_size
    # Inference process
    print("src vocab size", src_vocab_size)
    # Dropout uses the upscale_in_train implementation, so it can be removed
    # at inference time by simply setting it to 0.
if args.attention:
model = AttentionModel(
hidden_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=num_layers,
init_scale=init_scale,
dropout=0.0)
else:
model = BaseModel(
hidden_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=num_layers,
init_scale=init_scale,
dropout=0.0)
beam_size = args.beam_size
trans_res = model.build_graph(mode='beam_search', beam_size=beam_size)
# clone from default main program and use it as the validation program
main_program = fluid.default_main_program()
main_program = main_program.clone(for_test=True)
print([param.name for param in main_program.blocks[0].all_parameters()])
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
source_vocab_file = args.vocab_prefix + "." + args.src_lang
infer_file = args.infer_file
infer_data = reader.raw_mono_data(source_vocab_file, infer_file)
def prepare_input(batch, epoch_id=0, with_lr=True):
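        # Decoding is driven by beam search, so the target-side feeds are
        # zeroed placeholders that only satisfy the program's feed list.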
src_ids, src_mask, tar_ids, tar_mask = batch
res = {}
src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
in_tar = tar_ids[:, :-1]
label_tar = tar_ids[:, 1:]
in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
in_tar = np.zeros_like(in_tar, dtype='int64')
label_tar = label_tar.reshape(
(label_tar.shape[0], label_tar.shape[1], 1))
label_tar = np.zeros_like(label_tar, dtype='int64')
res['src'] = src_ids
res['tar'] = in_tar
res['label'] = label_tar
res['src_sequence_length'] = src_mask
res['tar_sequence_length'] = tar_mask
return res, np.sum(tar_mask)
dir_name = args.reload_model
print("dir name", dir_name)
fluid.io.load_params(exe, dir_name)
train_data_iter = reader.get_data_iter(infer_data, 1, mode='eval')
tar_id2vocab = []
tar_vocab_file = args.vocab_prefix + "." + args.tar_lang
with io.open(tar_vocab_file, "r", encoding='utf-8') as f:
for line in f.readlines():
tar_id2vocab.append(line.strip())
    infer_output_file = args.infer_output_file
    infer_output_dir = os.path.dirname(infer_output_file)
    if infer_output_dir and not os.path.exists(infer_output_dir):
        os.makedirs(infer_output_dir)
with io.open(infer_output_file, 'w', encoding='utf-8') as out_file:
for batch_id, batch in enumerate(train_data_iter):
input_data_feed, word_num = prepare_input(batch, epoch_id=0)
fetch_outs = exe.run(program=main_program,
feed=input_data_feed,
fetch_list=[trans_res.name],
use_program_cache=False)
for ins in fetch_outs[0]:
res = [tar_id2vocab[e] for e in ins[:, 0].reshape(-1)]
new_res = []
for ele in res:
if ele == "</s>":
break
new_res.append(ele)
out_file.write(space_tok.join(new_res))
out_file.write(line_tok)
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
if __name__ == '__main__':
check_version()
infer()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python infer.py \
--attention True \
--src_lang en --tar_lang vi \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--vocab_prefix data/en-vi/vocab \
--infer_file data/en-vi/tst2013.en \
--reload_model attention_models/epoch_10/ \
--infer_output_file attention_infer_output/infer_output.txt \
--beam_size 10 \
--use_gpu True
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import io
import sys
import numpy as np
Py3 = sys.version_info[0] == 3
UNK_ID = 0
def _read_words(filename):
    with io.open(filename, "r", encoding='utf-8') as f:
if Py3:
return f.read().replace("\n", "<eos>").split()
else:
return f.read().decode("utf-8").replace(u"\n", u"<eos>").split()
def read_all_line(filename):
    data = []
    with io.open(filename, "r", encoding='utf-8') as f:
        for line in f.readlines():
            data.append(line.strip())
    return data
def _build_vocab(filename):
vocab_dict = {}
ids = 0
with io.open(filename, "r", encoding='utf-8') as f:
for line in f.readlines():
vocab_dict[line.strip()] = ids
ids += 1
print("vocab word num", ids)
return vocab_dict
def _para_file_to_ids(src_file, tar_file, src_vocab, tar_vocab):
src_data = []
with io.open(src_file, "r", encoding='utf-8') as f_src:
for line in f_src.readlines():
arra = line.strip().split()
ids = [src_vocab[w] if w in src_vocab else UNK_ID for w in arra]
src_data.append(ids)
tar_data = []
with io.open(tar_file, "r", encoding='utf-8') as f_tar:
for line in f_tar.readlines():
arra = line.strip().split()
ids = [tar_vocab[w] if w in tar_vocab else UNK_ID for w in arra]
ids = [1] + ids + [2]
tar_data.append(ids)
return src_data, tar_data
def filter_len(src, tar, max_sequence_len=50):
new_src = []
new_tar = []
for id1, id2 in zip(src, tar):
if len(id1) > max_sequence_len:
id1 = id1[:max_sequence_len]
if len(id2) > max_sequence_len + 2:
id2 = id2[:max_sequence_len + 2]
new_src.append(id1)
new_tar.append(id2)
return new_src, new_tar
def raw_data(src_lang,
tar_lang,
vocab_prefix,
train_prefix,
eval_prefix,
test_prefix,
max_sequence_len=50):
src_vocab_file = vocab_prefix + "." + src_lang
tar_vocab_file = vocab_prefix + "." + tar_lang
src_train_file = train_prefix + "." + src_lang
tar_train_file = train_prefix + "." + tar_lang
src_eval_file = eval_prefix + "." + src_lang
tar_eval_file = eval_prefix + "." + tar_lang
src_test_file = test_prefix + "." + src_lang
tar_test_file = test_prefix + "." + tar_lang
src_vocab = _build_vocab(src_vocab_file)
tar_vocab = _build_vocab(tar_vocab_file)
train_src, train_tar = _para_file_to_ids( src_train_file, tar_train_file, \
src_vocab, tar_vocab )
train_src, train_tar = filter_len(
train_src, train_tar, max_sequence_len=max_sequence_len)
eval_src, eval_tar = _para_file_to_ids( src_eval_file, tar_eval_file, \
src_vocab, tar_vocab )
test_src, test_tar = _para_file_to_ids( src_test_file, tar_test_file, \
src_vocab, tar_vocab )
return ( train_src, train_tar), (eval_src, eval_tar), (test_src, test_tar),\
(src_vocab, tar_vocab)
def raw_mono_data(vocab_file, file_path):
src_vocab = _build_vocab(vocab_file)
test_src, test_tar = _para_file_to_ids( file_path, file_path, \
src_vocab, src_vocab )
return (test_src, test_tar)
def get_data_iter(raw_data,
batch_size,
mode='train',
enable_ce=False,
cache_num=20):
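    # Examples flow through a cache of cache_num * batch_size instances that
    # is sorted by source length, so each emitted batch groups sentences of
    # similar length and keeps padding to a minimum.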
src_data, tar_data = raw_data
data_len = len(src_data)
index = np.arange(data_len)
if mode == "train" and not enable_ce:
np.random.shuffle(index)
def to_pad_np(data, source=False):
max_len = 0
for ele in data:
if len(ele) > max_len:
max_len = len(ele)
ids = np.ones((batch_size, max_len), dtype='int64') * 2
mask = np.zeros((batch_size), dtype='int32')
for i, ele in enumerate(data):
ids[i, :len(ele)] = ele
if not source:
mask[i] = len(ele) - 1
else:
mask[i] = len(ele)
return ids, mask
b_src = []
if mode != "train":
cache_num = 1
for j in range(data_len):
if len(b_src) == batch_size * cache_num:
# build batch size
# sort
new_cache = sorted(b_src, key=lambda k: len(k[0]))
for i in range(cache_num):
batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
src_cache = [w[0] for w in batch_data]
tar_cache = [w[1] for w in batch_data]
src_ids, src_mask = to_pad_np(src_cache, source=True)
tar_ids, tar_mask = to_pad_np(tar_cache)
yield (src_ids, src_mask, tar_ids, tar_mask)
b_src = []
b_src.append((src_data[index[j]], tar_data[index[j]]))
if len(b_src) == batch_size * cache_num:
new_cache = sorted(b_src, key=lambda k: len(k[0]))
for i in range(cache_num):
batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
src_cache = [w[0] for w in batch_data]
tar_cache = [w[1] for w in batch_data]
src_ids, src_mask = to_pad_np(src_cache, source=True)
tar_ids, tar_mask = to_pad_np(tar_cache)
yield (src_ids, src_mask, tar_ids, tar_mask)
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
python train.py \
--src_lang en --tar_lang vi \
--attention True \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--train_data_prefix data/en-vi/train \
--eval_data_prefix data/en-vi/tst2012 \
--test_data_prefix data/en-vi/tst2013 \
--vocab_prefix data/en-vi/vocab \
--use_gpu True \
--model_path attention_models
# -*- coding: utf-8 -*-
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import time
import os
import logging
import random
import math
import contextlib
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.profiler as profiler
from paddle.fluid.executor import Executor
import reader
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
from args import *
from base_model import BaseModel
from attention_model import AttentionModel
import pickle

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
@contextlib.contextmanager
def profile_context(profile=True):
if profile:
with profiler.profiler('All', 'total', 'seq2seq.profile'):
yield
else:
yield
def main():
args = parse_args()
print(args)
num_layers = args.num_layers
src_vocab_size = args.src_vocab_size
tar_vocab_size = args.tar_vocab_size
batch_size = args.batch_size
dropout = args.dropout
init_scale = args.init_scale
max_grad_norm = args.max_grad_norm
hidden_size = args.hidden_size
if args.enable_ce:
fluid.default_main_program().random_seed = 102
framework.default_startup_program().random_seed = 102
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
# Training process
if args.attention:
model = AttentionModel(
hidden_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=num_layers,
init_scale=init_scale,
dropout=dropout)
else:
model = BaseModel(
hidden_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=num_layers,
init_scale=init_scale,
dropout=dropout)
loss = model.build_graph()
inference_program = train_program.clone(for_test=True)
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm))
lr = args.learning_rate
opt_type = args.optimizer
if opt_type == "sgd":
optimizer = fluid.optimizer.SGD(lr)
elif opt_type == "adam":
optimizer = fluid.optimizer.Adam(lr)
else:
print("only support [sgd|adam]")
raise Exception("opt type not support")
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = Executor(place)
exe.run(startup_program)
    device_count = len(fluid.cuda_places()) if args.use_gpu else len(
        fluid.cpu_places())
    compiled_program = fluid.CompiledProgram(train_program).with_data_parallel(
        loss_name=loss.name)
train_data_prefix = args.train_data_prefix
eval_data_prefix = args.eval_data_prefix
test_data_prefix = args.test_data_prefix
vocab_prefix = args.vocab_prefix
src_lang = args.src_lang
tar_lang = args.tar_lang
print("begin to load data")
raw_data = reader.raw_data(src_lang, tar_lang, vocab_prefix,
train_data_prefix, eval_data_prefix,
test_data_prefix, args.max_len)
print("finished load data")
train_data, valid_data, test_data, _ = raw_data
def prepare_input(batch, epoch_id=0, with_lr=True):
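        # Teacher forcing: the decoder input drops the last target token and
        # the label drops the first, so the model predicts token t+1 from
        # tokens up to t.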
src_ids, src_mask, tar_ids, tar_mask = batch
res = {}
src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
in_tar = tar_ids[:, :-1]
label_tar = tar_ids[:, 1:]
in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
label_tar = label_tar.reshape(
(label_tar.shape[0], label_tar.shape[1], 1))
res['src'] = src_ids
res['tar'] = in_tar
res['label'] = label_tar
res['src_sequence_length'] = src_mask
res['tar_sequence_length'] = tar_mask
return res, np.sum(tar_mask)
    # Evaluate a dataset and return its perplexity.
    def eval(data, epoch_id=0):
eval_data_iter = reader.get_data_iter(data, batch_size, mode='eval')
total_loss = 0.0
word_count = 0.0
for batch_id, batch in enumerate(eval_data_iter):
input_data_feed, word_num = prepare_input(
batch, epoch_id, with_lr=False)
fetch_outs = exe.run(inference_program,
feed=input_data_feed,
fetch_list=[loss.name],
use_program_cache=False)
cost_train = np.array(fetch_outs[0])
total_loss += cost_train * batch_size
word_count += word_num
ppl = np.exp(total_loss / word_count)
return ppl
def train():
ce_time = []
ce_ppl = []
max_epoch = args.max_epoch
for epoch_id in range(max_epoch):
start_time = time.time()
if args.enable_ce:
train_data_iter = reader.get_data_iter(
train_data, batch_size, enable_ce=True)
else:
train_data_iter = reader.get_data_iter(train_data, batch_size)
total_loss = 0
word_count = 0.0
batch_times = []
for batch_id, batch in enumerate(train_data_iter):
batch_start_time = time.time()
input_data_feed, word_num = prepare_input(
batch, epoch_id=epoch_id)
word_count += word_num
                fetch_outs = exe.run(program=compiled_program,
feed=input_data_feed,
fetch_list=[loss.name],
use_program_cache=True)
cost_train = np.mean(fetch_outs[0])
total_loss += cost_train * batch_size
batch_end_time = time.time()
batch_time = batch_end_time - batch_start_time
batch_times.append(batch_time)
if batch_id > 0 and batch_id % 100 == 0:
print("-- Epoch:[%d]; Batch:[%d]; Time: %.5f s; ppl: %.5f" %
(epoch_id, batch_id, batch_time,
np.exp(total_loss / word_count)))
ce_ppl.append(np.exp(total_loss / word_count))
total_loss = 0.0
word_count = 0.0
end_time = time.time()
epoch_time = end_time - start_time
ce_time.append(epoch_time)
print(
"\nTrain epoch:[%d]; Epoch Time: %.5f; avg_time: %.5f s/step\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
if not args.profile:
dir_name = os.path.join(args.model_path,
"epoch_" + str(epoch_id))
print("begin to save", dir_name)
fluid.io.save_params(exe, dir_name, main_program=train_program)
print("save finished")
dev_ppl = eval(valid_data)
print("dev ppl", dev_ppl)
test_ppl = eval(test_data)
print("test ppl", test_ppl)
if args.enable_ce:
card_num = get_cards()
_ppl = 0
_time = 0
try:
_time = ce_time[-1]
_ppl = ce_ppl[-1]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" % (card_num, _time))
print("kpis\ttrain_ppl_card%s\t%f" % (card_num, _ppl))
with profile_context(args.profile):
train()
def get_cards():
num = 0
cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if cards != '':
num = len(cards.split(","))
return num
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
if __name__ == '__main__':
check_version()
main()
Running the example model in this directory requires PaddlePaddle Fluid 1.6. If your installed PaddlePaddle is older than this, please update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start).
# Variational Autoencoder (VAE) for Text Generation
The following is a brief directory structure and description of this example:
```text
.
├── README.md      # Documentation (this file)
├── args.py        # Configuration of training, inference and model parameters
├── reader.py      # Data reading utilities
├── download.py    # Data download script
├── train.py       # Main training program
├── infer.py       # Main inference program
├── run.sh         # Launch script with the default training configuration
├── infer.sh       # Decoding script with the default configuration
└── model.py       # VAE model definition
```
## Introduction
The example in this directory demonstrates how to use Paddle Fluid's **<font color='red'>new Seq2Seq API</font>** to build a VAE for text generation, with LSTMs as both the encoder and the decoder. It is trained separately on the official PTB data and on the SWDA data.
For a detailed introduction to the VAE, see [(Bowman et al., 2015) Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf).
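As a brief summary of the objective implemented in model.py (using standard VAE notation rather than identifiers from the code): training minimizes the negative ELBO with KL annealing, i.e. a length-masked reconstruction cross-entropy plus a KL term between the approximate Gaussian posterior and a standard Gaussian prior, whose weight is annealed from kl_start up to 1.0 over the warm-up epochs:

$$
\mathcal{L} = \underbrace{\mathrm{CE}_{\text{masked}}(x, \hat{x})}_{\text{reconstruction}}
+ \lambda_{\mathrm{KL}}\cdot \mathrm{KL}\!\left(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,I)\right),
\qquad
\mathrm{KL} = -\tfrac{1}{2}\sum_i\left(1+\log\sigma_i^2-\mu_i^2-\sigma_i^2\right)
$$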
## Data
This tutorial uses two text datasets:
The PTB dataset, originally available at http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz.
The SWDA dataset, from [Knowledge-Guided CVAE for dialog generation](https://arxiv.org/pdf/1703.10960.pdf); the original dataset can be downloaded from https://github.com/snakeztc/NeuralDialog-CVAE , thanks to the author @[snakeztc](https://github.com/snakeztc). We filtered out sentences shorter than 5 tokens.
### Getting the data
```
python download.py --task ptb/swda
```
## Training
`run.sh` wraps the main training entry point. To start training with the default parameters, simply run:
```
sh run.sh ptb/swda
```
Model parameters can also be configured via the following command:
```
python train.py \
--vocab_size 10003 \
--batch_size 32 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--dataset_prefix data/${dataset}/${dataset} \
--model_path ${dataset}_model\
--use_gpu True \
--max_epoch 50 \
```
The training program uses early stopping: at the end of each epoch, it decides whether to save the model based on the perplexity.
## Inference
After training completes, you can run inference with the infer.sh script, loading the model saved at epoch k from the model directory to generate batch_size short sentences:
```
sh infer.sh ptb/swda k
```
The inference output parameters can also be configured via the following command:
```
python infer.py \
--vocab_size 10003 \
--batch_size 32 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--dataset_prefix data/${dataset}/${dataset} \
--use_gpu True \
    --reload_model ${dataset}_model/epoch_${k}
```
## Evaluation
```text
PTB dataset:
Test PPL: 102.24
Test NLL: 108.22
SWDA dataset:
Test PPL: 64.21
Test NLL: 81.92
```
## Generated samples
the movie are discovered in the u.s. industry that on aircraft variations for a <unk> aircraft that was repaired
the percentage of treasury bonds rose to N N at N N up N N from the two days N N and a premium over N
he could n't plunge as a factor that attention now has n't picked up for the state according to mexico
take the remark we need to do then support for the market to tell it
i think that it believes the core of the company in its first quarter of heavy mid-october to fuel earnings after prices on friday
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import distutils.util
def parse_args():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"--dataset_prefix", type=str, help="file prefix for train data")
parser.add_argument(
"--optimizer",
type=str,
default='adam',
help="optimizer to use, only supprt[sgd|adam]")
parser.add_argument(
"--learning_rate",
type=float,
default=0.001,
help="learning rate for optimizer")
parser.add_argument(
"--num_layers",
type=int,
default=1,
help="layers number of encoder and decoder")
parser.add_argument(
"--hidden_size",
type=int,
default=256,
help="hidden size of encoder and decoder")
parser.add_argument("--vocab_size", type=int, help="source vocab size")
parser.add_argument(
"--batch_size", type=int, help="batch size of each step")
parser.add_argument(
"--max_epoch", type=int, default=20, help="max epoch for the training")
parser.add_argument(
"--max_len",
type=int,
default=1280,
help="max length for source and target sentence")
parser.add_argument(
"--dec_dropout_in",
type=float,
default=0.5,
help="decoder input drop probability")
parser.add_argument(
"--dec_dropout_out",
type=float,
default=0.5,
help="decoder output drop probability")
parser.add_argument(
"--enc_dropout_in",
type=float,
default=0.,
help="encoder input drop probability")
parser.add_argument(
"--enc_dropout_out",
type=float,
default=0.,
help="encoder output drop probability")
parser.add_argument(
"--word_keep_prob",
type=float,
default=0.5,
help="word keep probability")
parser.add_argument(
"--init_scale",
type=float,
default=0.0,
help="init scale for parameter")
parser.add_argument(
"--max_grad_norm",
type=float,
default=5.0,
help="max grad norm for global norm clip")
parser.add_argument(
"--model_path",
type=str,
default='model',
help="model path for model to save")
parser.add_argument(
"--reload_model", type=str, help="reload model to inference")
parser.add_argument(
"--infer_output_file",
type=str,
default='infer_output.txt',
help="file name for inference output")
parser.add_argument(
"--beam_size", type=int, default=10, help="file name for inference")
    parser.add_argument(
        '--use_gpu',
        type=eval,
        default=False,
        help='Whether to use gpu [True|False]')
parser.add_argument(
"--enable_ce",
action='store_true',
help="The flag indicating whether to run the task "
"for continuous evaluation.")
parser.add_argument(
"--profile", action='store_true', help="Whether enable the profile.")
parser.add_argument(
"--warm_up",
type=int,
default=10,
help='number of warm up epochs for KL')
parser.add_argument(
"--kl_start", type=float, default=0.1, help='KL start value, upto 1.0')
parser.add_argument(
"--attr_init",
type=str,
default='normal_initializer',
help="initializer for paramters")
parser.add_argument(
"--cache_num", type=int, default=1, help='cache num for reader')
parser.add_argument(
"--max_decay",
type=int,
default=5,
help='max decay tries (if exceeds, early stop)')
parser.add_argument(
"--sort_cache",
action='store_true',
help='sort cache before batch to accelerate training')
args = parser.parse_args()
return args
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Script for downloading training data.
'''
import os
import sys
import shutil
import argparse
import urllib
import tarfile

if sys.version_info >= (3, 0):
    import urllib.request

URLLIB = urllib
if sys.version_info >= (3, 0):
    URLLIB = urllib.request
TASKS = ['ptb', 'swda']
TASK2PATH = {
'ptb': 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz',
'swda': 'https://baidu-nlp.bj.bcebos.com/TextGen/swda.tar.gz'
}
def un_tar(tar_name, dir_name):
try:
t = tarfile.open(tar_name)
t.extractall(path=dir_name)
return True
except Exception as e:
print(e)
return False
def download_and_extract(task, data_path):
print('Downloading and extracting %s...' % task)
data_file = os.path.join(data_path, TASK2PATH[task].split('/')[-1])
URLLIB.urlretrieve(TASK2PATH[task], data_file)
un_tar(data_file, data_path)
os.remove(data_file)
if task == 'ptb':
src_dir = os.path.join(data_path, 'simple-examples')
dst_dir = os.path.join(data_path, 'ptb')
if not os.path.exists(dst_dir):
os.mkdir(dst_dir)
shutil.copy(os.path.join(src_dir, 'data/ptb.train.txt'), dst_dir)
shutil.copy(os.path.join(src_dir, 'data/ptb.valid.txt'), dst_dir)
shutil.copy(os.path.join(src_dir, 'data/ptb.test.txt'), dst_dir)
shutil.rmtree(src_dir)
print('\tCompleted!')
def main(arguments):
parser = argparse.ArgumentParser()
parser.add_argument(
'-d',
'--data_dir',
help='directory to save data to',
type=str,
default='data')
parser.add_argument(
'-t',
'--task',
help='tasks to download data for as a comma separated string',
type=str,
default='ptb')
args = parser.parse_args(arguments)
if not os.path.isdir(args.data_dir):
os.mkdir(args.data_dir)
download_and_extract(args.task, args.data_dir)
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import time
import os
import random
import io
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.executor import Executor
import reader
import sys
line_tok = '\n'
space_tok = ' '
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
line_tok = u'\n'
space_tok = u' '
from args import *
import logging
import pickle

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
from model import VAE
from reader import BOS_ID, EOS_ID, get_vocab
def infer():
args = parse_args()
num_layers = args.num_layers
src_vocab_size = args.vocab_size
tar_vocab_size = args.vocab_size
batch_size = args.batch_size
init_scale = args.init_scale
max_grad_norm = args.max_grad_norm
hidden_size = args.hidden_size
attr_init = args.attr_init
latent_size = 32
if args.enable_ce:
fluid.default_main_program().random_seed = 102
framework.default_startup_program().random_seed = 102
model = VAE(hidden_size,
latent_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=num_layers,
init_scale=init_scale,
attr_init=attr_init)
beam_size = args.beam_size
trans_res = model.build_graph(mode='sampling', beam_size=beam_size)
    main_program = fluid.default_main_program()
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = Executor(place)
exe.run(framework.default_startup_program())
dir_name = args.reload_model
print("dir name", dir_name)
fluid.io.load_params(exe, dir_name)
vocab, tar_id2vocab = get_vocab(args.dataset_prefix)
infer_output = np.ones((batch_size, 1), dtype='int64') * BOS_ID
fetch_outs = exe.run(feed={'tar': infer_output},
fetch_list=[trans_res.name],
use_program_cache=False)
with io.open(args.infer_output_file, 'w', encoding='utf-8') as out_file:
for line in fetch_outs[0]:
end_id = -1
if EOS_ID in line:
end_id = np.where(line == EOS_ID)[0][0]
new_line = [tar_id2vocab[e[0]] for e in line[1:end_id]]
out_file.write(space_tok.join(new_line))
out_file.write(line_tok)
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
if __name__ == '__main__':
check_version()
infer()
#!/bin/bash
set -x
export CUDA_VISIBLE_DEVICES=0
dataset=$1
k=$2
python infer.py \
--vocab_size 10003 \
--batch_size 32 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--dataset_prefix data/${dataset}/${dataset} \
--use_gpu True \
    --reload_model ${dataset}_model/epoch_${k}
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.fluid.layers as layers
import paddle.fluid as fluid
import numpy as np
from paddle.fluid import ParamAttr
from paddle.fluid.layers import RNNCell, LSTMCell, rnn, BeamSearchDecoder, dynamic_decode
from paddle.fluid.contrib.layers import basic_lstm, BasicLSTMUnit
from reader import BOS_ID, EOS_ID
INF = 1. * 1e5
alpha = 0.6
normal_initializer = lambda x: fluid.initializer.NormalInitializer(loc=0., scale=x**-0.5)
uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x)
zero_constant = fluid.initializer.Constant(0.0)
class EncoderCell(RNNCell):
def __init__(self,
num_layers,
hidden_size,
param_attr_initializer,
param_attr_scale,
dropout_prob=0.):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.lstm_cells = []
for i in range(num_layers):
lstm_name = "enc_layers_" + str(i)
self.lstm_cells.append(
LSTMCell(
hidden_size, forget_bias=0., name=lstm_name))
def call(self, step_input, states):
new_states = []
for i in range(self.num_layers):
out, new_state = self.lstm_cells[i](step_input, states[i])
step_input = layers.dropout(
out,
self.dropout_prob,
dropout_implementation="upscale_in_train"
) if self.dropout_prob > 0 else out
new_states.append(new_state)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.lstm_cells]
class DecoderCell(RNNCell):
def __init__(self,
num_layers,
hidden_size,
latent_z,
param_attr_initializer,
param_attr_scale,
dropout_prob=0.):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.dropout_prob = dropout_prob
self.latent_z = latent_z
self.lstm_cells = []
param_attr = ParamAttr(
initializer=param_attr_initializer(param_attr_scale))
for i in range(num_layers):
lstm_name = "dec_layers_" + str(i)
self.lstm_cells.append(
LSTMCell(
hidden_size, param_attr, forget_bias=0., name=lstm_name))
def call(self, step_input, states):
lstm_states = states
new_lstm_states = []
step_input = layers.concat([step_input, self.latent_z], 1)
for i in range(self.num_layers):
out, lstm_state = self.lstm_cells[i](step_input, lstm_states[i])
step_input = layers.dropout(
out,
self.dropout_prob,
dropout_implementation="upscale_in_train"
) if self.dropout_prob > 0 else out
new_lstm_states.append(lstm_state)
return step_input, new_lstm_states
class VAE(object):
def __init__(self,
hidden_size,
latent_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=1,
init_scale=0.1,
dec_dropout_in=0.5,
dec_dropout_out=0.5,
enc_dropout_in=0.,
enc_dropout_out=0.,
word_keep_prob=0.5,
batch_first=True,
attr_init="normal_initializer"):
self.hidden_size = hidden_size
self.latent_size = latent_size
self.src_vocab_size = src_vocab_size
self.tar_vocab_size = tar_vocab_size
self.batch_size = batch_size
self.num_layers = num_layers
self.init_scale = init_scale
self.dec_dropout_in = dec_dropout_in
self.dec_dropout_out = dec_dropout_out
self.enc_dropout_in = enc_dropout_in
self.enc_dropout_out = enc_dropout_out
self.word_keep_prob = word_keep_prob
self.batch_first = batch_first
if attr_init == "normal_initializer":
self.param_attr_initializer = normal_initializer
self.param_attr_scale = hidden_size
elif attr_init == "uniform_initializer":
self.param_attr_initializer = uniform_initializer
self.param_attr_scale = init_scale
else:
raise TypeError("The type of 'attr_initializer' is not supported")
self.src_embeder = lambda x: fluid.embedding(
input=x,
size=[self.src_vocab_size, self.hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
name='source_embedding',
initializer=self.param_attr_initializer(self.param_attr_scale)))
self.tar_embeder = lambda x: fluid.embedding(
input=x,
size=[self.tar_vocab_size, self.hidden_size],
dtype='float32',
is_sparse=False,
param_attr=fluid.ParamAttr(
name='target_embedding',
initializer=self.param_attr_initializer(self.param_attr_scale)))
def _build_data(self, mode='train'):
if mode == 'train':
self.src = fluid.data(name="src", shape=[None, None], dtype='int64')
self.src_sequence_length = fluid.data(
name="src_sequence_length", shape=[None], dtype='int32')
self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64')
self.tar_sequence_length = fluid.data(
name="tar_sequence_length", shape=[None], dtype='int32')
self.label = fluid.data(
name="label", shape=[None, None, 1], dtype='int64')
self.kl_weight = fluid.data(
name='kl_weight', shape=[1], dtype='float32')
else:
self.tar = fluid.data(name="tar", shape=[None, None], dtype='int64')
    def _embedding(self, mode='train'):
if mode == 'train':
self.src_emb = self.src_embeder(self.src)
self.tar_emb = self.tar_embeder(self.tar)
def _build_encoder(self):
self.enc_input = layers.dropout(
self.src_emb,
dropout_prob=self.enc_dropout_in,
dropout_implementation="upscale_in_train")
enc_cell = EncoderCell(self.num_layers, self.hidden_size,
self.param_attr_initializer,
self.param_attr_scale, self.enc_dropout_out)
enc_output, enc_final_state = rnn(
cell=enc_cell,
inputs=self.enc_input,
sequence_length=self.src_sequence_length)
return enc_output, enc_final_state
def _build_distribution(self, enc_final_state=None):
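        # The approximate posterior q(z|x) is parameterized from the
        # concatenated final encoder states: one fc layer produces
        # [mean, log-variance], which is split in half along the last axis.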
enc_hidden = [
layers.concat(
state, axis=-1) for state in enc_final_state
]
enc_hidden = layers.concat(enc_hidden, axis=-1)
z_mean_log_var = layers.fc(input=enc_hidden,
size=self.latent_size * 2,
name='fc_dist')
z_mean, z_log_var = layers.split(z_mean_log_var, 2, -1)
return z_mean, z_log_var
def _build_decoder(self,
z_mean=None,
z_log_var=None,
enc_output=None,
mode='train',
beam_size=10):
dec_input = layers.dropout(
self.tar_emb,
dropout_prob=self.dec_dropout_in,
dropout_implementation="upscale_in_train")
# `output_layer` will be used within BeamSearchDecoder
output_layer = lambda x: layers.fc(x,
size=self.tar_vocab_size,
num_flatten_dims=len(x.shape) - 1,
name="output_w")
# `sample_output_layer` samples an id from the logits distribution instead of argmax(logits)
# it will be used within BeamSearchDecoder
        sample_output_layer = lambda x: layers.unsqueeze(
            layers.one_hot(
                layers.unsqueeze(
                    layers.sampling_id(
                        layers.softmax(layers.squeeze(output_layer(x), [1])),
                        dtype='int'), [1]),
                depth=self.tar_vocab_size), [1])
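        # In training, sample z from the approximate posterior via the
        # reparameterization trick; at inference time, draw z from the
        # standard normal prior instead.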
if mode == 'train':
latent_z = self._sampling(z_mean, z_log_var)
else:
latent_z = layers.gaussian_random_batch_size_like(
self.tar, shape=[-1, self.latent_size])
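        # Project the latent code into the initial (hidden, cell) states of
        # every decoder layer.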
dec_first_hidden_cell = layers.fc(latent_z,
2 * self.hidden_size *
self.num_layers,
name='fc_hc')
dec_first_hidden, dec_first_cell = layers.split(dec_first_hidden_cell,
2)
if self.num_layers > 1:
dec_first_hidden = layers.split(dec_first_hidden, self.num_layers)
dec_first_cell = layers.split(dec_first_cell, self.num_layers)
else:
dec_first_hidden = [dec_first_hidden]
dec_first_cell = [dec_first_cell]
dec_initial_states = [[h, c]
for h, c in zip(dec_first_hidden, dec_first_cell)]
dec_cell = DecoderCell(self.num_layers, self.hidden_size, latent_z,
self.param_attr_initializer,
self.param_attr_scale, self.dec_dropout_out)
if mode == 'train':
dec_output, _ = rnn(cell=dec_cell,
inputs=dec_input,
initial_states=dec_initial_states,
sequence_length=self.tar_sequence_length)
dec_output = output_layer(dec_output)
return dec_output
elif mode == 'greedy':
start_token = 1
end_token = 2
max_length = 100
beam_search_decoder = BeamSearchDecoder(
dec_cell,
start_token,
end_token,
beam_size=1,
embedding_fn=self.tar_embeder,
output_fn=output_layer)
outputs, _ = dynamic_decode(
beam_search_decoder,
inits=dec_initial_states,
max_step_num=max_length)
return outputs
elif mode == 'sampling':
start_token = 1
end_token = 2
max_length = 100
beam_search_decoder = BeamSearchDecoder(
dec_cell,
start_token,
end_token,
beam_size=1,
embedding_fn=self.tar_embeder,
output_fn=sample_output_layer)
outputs, _ = dynamic_decode(
beam_search_decoder,
inits=dec_initial_states,
max_step_num=max_length)
return outputs
        else:
            raise ValueError("unsupported decode mode: " + mode)
def _sampling(self, z_mean, z_log_var):
"""reparameterization trick
"""
# by default, random_normal has mean=0 and std=1.0
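        # z = mean + exp(0.5 * log_var) * eps with eps ~ N(0, I); sampling
        # through eps keeps the draw differentiable w.r.t. mean and log_var.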
epsilon = layers.gaussian_random_batch_size_like(
self.tar, shape=[-1, self.latent_size])
epsilon.stop_gradient = True
return z_mean + layers.exp(0.5 * z_log_var) * epsilon
    def _kl_dvg(self, means, logvars):
        """Compute the KL divergence between the approximate posterior
        N(means, exp(logvars)) and the standard normal prior.
        """
kl_cost = -0.5 * (logvars - fluid.layers.square(means) -
fluid.layers.exp(logvars) + 1.0)
kl_cost = fluid.layers.reduce_mean(kl_cost, 0)
return fluid.layers.reduce_sum(kl_cost)
def _compute_loss(self, mean, logvars, dec_output):
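        # The training objective is the negative ELBO: per-token reconstruction
        # cross-entropy plus the KL term, scaled by the annealed kl_weight.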
kl_loss = self._kl_dvg(mean, logvars)
rec_loss = layers.softmax_with_cross_entropy(
logits=dec_output, label=self.label, soft_label=False)
rec_loss = layers.reshape(rec_loss, shape=[self.batch_size, -1])
max_tar_seq_len = layers.shape(self.tar)[1]
tar_mask = layers.sequence_mask(
self.tar_sequence_length, maxlen=max_tar_seq_len, dtype='float32')
rec_loss = rec_loss * tar_mask
rec_loss = layers.reduce_mean(rec_loss, dim=[0])
rec_loss = layers.reduce_sum(rec_loss)
loss = kl_loss * self.kl_weight + rec_loss
return loss, kl_loss, rec_loss
def _beam_search(self, enc_last_hidden, enc_last_cell):
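        # Unused placeholder; decoding is handled by BeamSearchDecoder in
        # _build_decoder.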
pass
def build_graph(self, mode='train', beam_size=10):
if mode == 'train' or mode == 'eval':
self._build_data()
            self._embedding()
enc_output, enc_final_state = self._build_encoder()
z_mean, z_log_var = self._build_distribution(enc_final_state)
dec_output = self._build_decoder(z_mean, z_log_var, enc_output)
loss, kl_loss, rec_loss = self._compute_loss(z_mean, z_log_var,
dec_output)
return loss, kl_loss, rec_loss
elif mode == "sampling" or mode == 'greedy':
self._build_data(mode)
            self._embedding(mode)
dec_output = self._build_decoder(mode=mode, beam_size=1)
return dec_output
        else:
            raise ValueError("unsupported mode: " + mode)
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for parsing PTB text files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import sys
import numpy as np
import io
from collections import Counter
Py3 = sys.version_info[0] == 3
if Py3:
line_tok = '\n'
else:
line_tok = u'\n'
PAD_ID = 0
BOS_ID = 1
EOS_ID = 2
UNK_ID = 3
def _read_words(filename):
    # io.open already decodes to unicode under both Python 2 and 3, so no
    # extra decode step is needed.
    with io.open(filename, "r", encoding='utf-8') as f:
        return f.read().replace("\n", "<EOS>").split()
def read_all_line(filename):
data = []
with io.open(filename, "r", encoding='utf-8') as f:
for line in f.readlines():
data.append(line.strip())
return data
def _vocab(vocab_file, train_file, max_vocab_cnt):
lines = read_all_line(train_file)
all_words = []
for line in lines:
all_words.extend(line.split())
vocab_count = Counter(all_words).most_common()
    # max_vocab_cnt == -1 means the vocabulary size is unlimited
    vocab_size = len(vocab_count) if max_vocab_cnt < 0 else min(
        len(vocab_count), max_vocab_cnt)
    with io.open(vocab_file, "w", encoding='utf-8') as f:
        for voc, fre in vocab_count[:vocab_size]:
            f.write(voc)
            f.write(line_tok)
def _build_vocab(vocab_file, train_file=None, max_vocab_cnt=-1):
if not os.path.exists(vocab_file):
_vocab(vocab_file, train_file, max_vocab_cnt)
vocab_dict = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2, "<UNK>": 3}
ids = 4
with io.open(vocab_file, "r", encoding='utf-8') as f:
for line in f.readlines():
vocab_dict[line.strip()] = ids
ids += 1
# rev_vocab = {value:key for key, value in vocab_dict.items()}
print("vocab word num", ids)
return vocab_dict
def _para_file_to_ids(src_file, src_vocab):
src_data = []
with io.open(src_file, "r", encoding='utf-8') as f_src:
for line in f_src.readlines():
arra = line.strip().split()
ids = [BOS_ID]
ids.extend(
[src_vocab[w] if w in src_vocab else UNK_ID for w in arra])
ids.append(EOS_ID)
src_data.append(ids)
return src_data
def filter_len(src, max_sequence_len=128):
new_src = []
for id1 in src:
if len(id1) > max_sequence_len:
id1 = id1[:max_sequence_len]
new_src.append(id1)
return new_src
def raw_data(dataset_prefix, max_sequence_len=50, max_vocab_cnt=-1):
src_vocab_file = dataset_prefix + ".vocab.txt"
src_train_file = dataset_prefix + ".train.txt"
src_eval_file = dataset_prefix + ".valid.txt"
src_test_file = dataset_prefix + ".test.txt"
src_vocab = _build_vocab(src_vocab_file, src_train_file, max_vocab_cnt)
train_src = _para_file_to_ids(src_train_file, src_vocab)
train_src = filter_len(train_src, max_sequence_len=max_sequence_len)
eval_src = _para_file_to_ids(src_eval_file, src_vocab)
test_src = _para_file_to_ids(src_test_file, src_vocab)
return train_src, eval_src, test_src, src_vocab
def get_vocab(dataset_prefix, max_sequence_len=50):
src_vocab_file = dataset_prefix + ".vocab.txt"
src_vocab = _build_vocab(src_vocab_file)
rev_vocab = {}
for key, value in src_vocab.items():
rev_vocab[value] = key
return src_vocab, rev_vocab
def raw_mono_data(vocab_file, file_path):
    src_vocab = _build_vocab(vocab_file)
    test_src = _para_file_to_ids(file_path, src_vocab)
    # source and target are the same sequence for monolingual data
    return (test_src, test_src)
def get_data_iter(raw_data,
batch_size,
sort_cache=False,
cache_num=1,
mode='train',
enable_ce=False):
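    # When sort_cache is set, cache_num * batch_size examples are buffered and
    # sorted by length so that each batch needs less padding.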
src_data = raw_data
data_len = len(src_data)
index = np.arange(data_len)
if mode == "train" and not enable_ce:
np.random.shuffle(index)
def to_pad_np(data):
max_len = 0
for ele in data:
if len(ele) > max_len:
max_len = len(ele)
ids = np.ones(
(batch_size, max_len), dtype='int64') * PAD_ID # PAD_ID = 0
mask = np.zeros((batch_size), dtype='int32')
for i, ele in enumerate(data):
ids[i, :len(ele)] = ele
mask[i] = len(ele)
return ids, mask
b_src = []
if mode != "train":
cache_num = 1
for j in range(data_len):
if len(b_src) == batch_size * cache_num:
            if sort_cache:
                new_cache = sorted(b_src, key=lambda k: len(k))
            else:
                new_cache = b_src
for i in range(cache_num):
batch_data = new_cache[i * batch_size:(i + 1) * batch_size]
src_ids, src_mask = to_pad_np(batch_data)
yield (src_ids, src_mask)
b_src = []
b_src.append(src_data[index[j]])
if len(b_src) > 0:
        if sort_cache:
            new_cache = sorted(b_src, key=lambda k: len(k))
        else:
            new_cache = b_src
        for i in range(0, len(b_src), batch_size):
            batch_data = new_cache[i:i + batch_size]
            src_ids, src_mask = to_pad_np(batch_data)
            yield (src_ids, src_mask)
#!/bin/bash
set -x
export CUDA_VISIBLE_DEVICES=0
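# The first argument selects the dataset; files are read from
# data/<dataset>/<dataset>.{train,valid,test}.txt (see reader.raw_data).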
dataset=$1
python train.py \
--vocab_size 10003 \
--batch_size 32 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--dataset_prefix data/${dataset}/${dataset} \
        --model_path ${dataset}_model \
        --use_gpu True \
        --max_epoch 200
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import time
import os
import random
import math
import contextlib
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import paddle.fluid.profiler as profiler
from paddle.fluid.executor import Executor
import reader
import sys
if sys.version[0] == '2':
reload(sys)
sys.setdefaultencoding("utf-8")
from args import *
from model import VAE
import logging
import pickle

logger = logging.getLogger(__name__)
@contextlib.contextmanager
def profile_context(profile=True):
if profile:
with profiler.profiler('All', 'total', 'seq2seq.profile'):
yield
else:
yield
def main():
args = parse_args()
print(args)
num_layers = args.num_layers
src_vocab_size = args.vocab_size
tar_vocab_size = args.vocab_size
batch_size = args.batch_size
init_scale = args.init_scale
max_grad_norm = args.max_grad_norm
hidden_size = args.hidden_size
attr_init = args.attr_init
latent_size = 32
main_program = fluid.Program()
startup_program = fluid.Program()
if args.enable_ce:
fluid.default_main_program().random_seed = 123
framework.default_startup_program().random_seed = 123
# Training process
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
model = VAE(hidden_size,
latent_size,
src_vocab_size,
tar_vocab_size,
batch_size,
num_layers=num_layers,
init_scale=init_scale,
attr_init=attr_init)
loss, kl_loss, rec_loss = model.build_graph()
            # clone the main program in test mode and use it for validation
            inference_program = main_program.clone(for_test=True)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=max_grad_norm))
learning_rate = fluid.layers.create_global_var(
name="learning_rate",
shape=[1],
value=float(args.learning_rate),
dtype="float32",
persistable=True)
opt_type = args.optimizer
if opt_type == "sgd":
optimizer = fluid.optimizer.SGD(learning_rate)
elif opt_type == "adam":
optimizer = fluid.optimizer.Adam(learning_rate)
else:
print("only support [sgd|adam]")
raise Exception("opt type not support")
optimizer.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = Executor(place)
exe.run(startup_program)
train_program = fluid.compiler.CompiledProgram(main_program)
dataset_prefix = args.dataset_prefix
print("begin to load data")
raw_data = reader.raw_data(dataset_prefix, args.max_len)
print("finished load data")
train_data, valid_data, test_data, _ = raw_data
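    # KL annealing: the KL weight grows linearly from kl_start to 1.0 over
    # warm_up epochs; anneal_r is the per-batch increment.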
anneal_r = 1.0 / (args.warm_up * len(train_data) / args.batch_size)
def prepare_input(batch, kl_weight=1.0, lr=None):
src_ids, src_mask = batch
res = {}
src_ids = src_ids.reshape((src_ids.shape[0], src_ids.shape[1]))
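        # Build language-model style decoder inputs and labels: the input
        # drops the final token and the label drops the leading <BOS>.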
in_tar = src_ids[:, :-1]
label_tar = src_ids[:, 1:]
in_tar = in_tar.reshape((in_tar.shape[0], in_tar.shape[1]))
label_tar = label_tar.reshape(
(label_tar.shape[0], label_tar.shape[1], 1))
res['src'] = src_ids
res['tar'] = in_tar
res['label'] = label_tar
res['src_sequence_length'] = src_mask
res['tar_sequence_length'] = src_mask - 1
res['kl_weight'] = np.array([kl_weight]).astype(np.float32)
if lr is not None:
res['learning_rate'] = np.array([lr]).astype(np.float32)
return res, np.sum(src_mask), np.sum(src_mask - 1)
# get train epoch size
def eval(data):
eval_data_iter = reader.get_data_iter(data, batch_size, mode='eval')
total_loss = 0.0
word_count = 0.0
batch_count = 0.0
for batch_id, batch in enumerate(eval_data_iter):
input_data_feed, src_word_num, dec_word_sum = prepare_input(batch)
fetch_outs = exe.run(inference_program,
feed=input_data_feed,
fetch_list=[loss.name],
use_program_cache=False)
cost_train = np.array(fetch_outs[0])
total_loss += cost_train * batch_size
word_count += dec_word_sum
batch_count += batch_size
nll = total_loss / batch_count
ppl = np.exp(total_loss / word_count)
return nll, ppl
def train():
ce_time = []
ce_ppl = []
max_epoch = args.max_epoch
kl_w = args.kl_start
lr_w = args.learning_rate
best_valid_nll = 1e100 # +inf
best_epoch_id = -1
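        # LR schedule: if validation NLL fails to improve for decay_ts
        # consecutive epochs, halve the learning rate, reload the best
        # checkpoint, and stop after max_decay decays.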
decay_cnt = 0
max_decay = args.max_decay
decay_factor = 0.5
decay_ts = 2
steps_not_improved = 0
for epoch_id in range(max_epoch):
start_time = time.time()
if args.enable_ce:
train_data_iter = reader.get_data_iter(
train_data,
batch_size,
args.sort_cache,
args.cache_num,
enable_ce=True)
else:
train_data_iter = reader.get_data_iter(
train_data, batch_size, args.sort_cache, args.cache_num)
total_loss = 0
total_rec_loss = 0
total_kl_loss = 0
word_count = 0.0
batch_count = 0.0
batch_times = []
batch_start_time = time.time()
for batch_id, batch in enumerate(train_data_iter):
kl_w = min(1.0, kl_w + anneal_r)
kl_weight = kl_w
input_data_feed, src_word_num, dec_word_sum = prepare_input(
batch, kl_weight, lr_w)
fetch_outs = exe.run(
program=train_program,
feed=input_data_feed,
fetch_list=[loss.name, kl_loss.name, rec_loss.name],
use_program_cache=False)
cost_train = np.array(fetch_outs[0])
kl_cost_train = np.array(fetch_outs[1])
rec_cost_train = np.array(fetch_outs[2])
total_loss += cost_train * batch_size
total_rec_loss += rec_cost_train * batch_size
total_kl_loss += kl_cost_train * batch_size
word_count += dec_word_sum
batch_count += batch_size
                batch_end_time = time.time()
                batch_time = batch_end_time - batch_start_time
                batch_times.append(batch_time)
                # reset the timer so each entry measures a single batch
                batch_start_time = batch_end_time
if batch_id > 0 and batch_id % 200 == 0:
print("-- Epoch:[%d]; Batch:[%d]; Time: %.4f s; "
"kl_weight: %.4f; kl_loss: %.4f; rec_loss: %.4f; "
"nll: %.4f; ppl: %.4f" %
(epoch_id, batch_id, batch_time, kl_w, total_kl_loss /
batch_count, total_rec_loss / batch_count, total_loss
/ batch_count, np.exp(total_loss / word_count)))
ce_ppl.append(np.exp(total_loss / word_count))
end_time = time.time()
epoch_time = end_time - start_time
ce_time.append(epoch_time)
print(
"\nTrain epoch:[%d]; Epoch Time: %.4f; avg_time: %.4f s/step\n"
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
val_nll, val_ppl = eval(valid_data)
print("dev ppl", val_ppl)
test_nll, test_ppl = eval(test_data)
print("test ppl", test_ppl)
if val_nll < best_valid_nll:
best_valid_nll = val_nll
steps_not_improved = 0
best_nll = test_nll
best_ppl = test_ppl
best_epoch_id = epoch_id
dir_name = os.path.join(args.model_path,
"epoch_" + str(best_epoch_id))
print("save model {}".format(dir_name))
fluid.io.save_params(exe, dir_name, main_program)
else:
steps_not_improved += 1
if steps_not_improved == decay_ts:
old_lr = lr_w
lr_w *= decay_factor
steps_not_improved = 0
new_lr = lr_w
print('-----\nchange lr, old lr: %f, new lr: %f\n-----' %
(old_lr, new_lr))
dir_name = args.model_path + "/epoch_" + str(best_epoch_id)
fluid.io.load_params(exe, dir_name)
decay_cnt += 1
if decay_cnt == max_decay:
break
print('\nbest testing nll: %.4f, best testing ppl %.4f\n' %
(best_nll, best_ppl))
with profile_context(args.profile):
train()
def get_cards():
num = 0
cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if cards != '':
num = len(cards.split(","))
return num
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
logger.error(err)
sys.exit(1)
if __name__ == '__main__':
check_version()
main()