Commit ae47e2a8 authored by: G guosheng

Refine seq2seq

Parent 8aca373d
Running the example models in this directory requires PaddlePaddle Fluid 1.7. If your installed PaddlePaddle version is lower than this, please follow the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start) to update it.
# Sequence to Sequence (Seq2Seq)
The directory structure of this example is as follows:
```
.
├── README.md          # this document
├── args.py            # training, inference, and model configuration
├── reader.py          # data reading
├── download.py        # data downloading
├── train.py           # main training script
├── infer.py           # main inference script
├── run.sh             # launch script with default settings
├── infer.sh           # decoding script with default settings
├── attention_model.py # translation model with attention
└── base_model.py      # translation model without attention
```
## Introduction
Sequence to Sequence (Seq2Seq) uses an encoder-decoder architecture: the encoder encodes the source sequence into a vector, and the decoder decodes that vector into the target sequence. Seq2Seq is widely used in tasks such as machine translation, dialogue systems, automatic document summarization, and image captioning.

This directory contains a classic Seq2Seq example, machine translation, with two implementations: a base model (without attention) and a model with attention. A Seq2Seq translation model mimics how a human translates: first parse the source sentence and understand its meaning, then produce a sentence in the target language that expresses that meaning. For more on the principles and mathematical formulation of machine translation, we recommend the [machine translation case study](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html) on the PaddlePaddle website.
## Model overview
In this model, the encoder is a multi-layer LSTM-based RNN encoder; the decoder is an RNN decoder with attention, and a decoder without attention is also provided for comparison. At inference time, beam search is used to generate the target translation.
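For intuition, the beam search used at inference time can be sketched in a few lines of plain Python. The `step_logprobs` callback, the token ids, and all names below are illustrative stand-ins for the model's per-step output distribution, not the actual decoding API of this example:

```python
import math

def beam_search(step_logprobs, bos, eos, beam_size, max_len):
    """step_logprobs(prefix) -> {next_token: log_prob}; returns the best sequence."""
    beams = [([bos], 0.0)]  # (sequence, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos:          # already terminated: keep as-is
                finished.append((seq, score))
                continue
            for tok, lp in step_logprobs(seq).items():
                candidates.append((seq + [tok], score + lp))
        if not candidates:
            break
        # keep the beam_size highest-scoring extensions
        beams = sorted(candidates, key=lambda b: -b[1])[:beam_size]
    finished.extend(b for b in beams if b[0][-1] == eos)
    if not finished:
        finished = beams
    return max(finished, key=lambda b: b[1])[0]
```

A toy transition table is enough to see it prefer the higher-probability path over a greedy dead end.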
## Data
This tutorial uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) as the training corpus, tst2012 as the development set, and tst2013 as the test set.
### Downloading the data
```sh
python download.py
```
## Training
`run.sh` wraps the main training entry point. To start training with the default parameters, simply run:
```sh
sh run.sh
```
By default the RNN model with attention is trained; set the `attention` argument to False to train the model without attention.
```sh
export CUDA_VISIBLE_DEVICES=0
python train.py \
--src_lang en --tar_lang vi \
--attention True \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--train_data_prefix data/en-vi/train \
--eval_data_prefix data/en-vi/tst2012 \
--test_data_prefix data/en-vi/tst2013 \
--vocab_prefix data/en-vi/vocab \
--use_gpu True \
--model_path ./attention_models
```
The training script saves a model checkpoint at the end of every epoch.
Training runs in dynamic-graph (eager) mode by default; set the `eager_run` argument to False to train in static-graph mode, as follows:
```sh
export CUDA_VISIBLE_DEVICES=0
python train.py \
--src_lang en --tar_lang vi \
--attention True \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--train_data_prefix data/en-vi/train \
--eval_data_prefix data/en-vi/tst2012 \
--test_data_prefix data/en-vi/tst2013 \
--vocab_prefix data/en-vi/vocab \
--use_gpu True \
--model_path ./attention_models \
--eager_run False
```
## Inference
Once training is complete, you can run inference with the infer.sh script. By default it decodes the test set with beam search, loading the checkpoint saved at epoch 10:
```sh
sh infer.sh
```
To run inference on a different data file, just change the `--infer_file` argument.
```sh
export CUDA_VISIBLE_DEVICES=0
python infer.py \
--attention True \
--src_lang en --tar_lang vi \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--vocab_prefix data/en-vi/vocab \
--infer_file data/en-vi/tst2013.en \
--reload_model attention_models/epoch_10 \
--infer_output_file attention_infer_output/infer_output.txt \
--beam_size 10 \
--use_gpu True
```
As with training, inference can also be run in static-graph mode:
```sh
export CUDA_VISIBLE_DEVICES=0
python infer.py \
--attention True \
--src_lang en --tar_lang vi \
--num_layers 2 \
--hidden_size 512 \
--src_vocab_size 17191 \
--tar_vocab_size 7709 \
--batch_size 128 \
--dropout 0.2 \
--init_scale 0.1 \
--max_grad_norm 5.0 \
--vocab_prefix data/en-vi/vocab \
--infer_file data/en-vi/tst2013.en \
--reload_model attention_models/epoch_10 \
--infer_output_file attention_infer_output/infer_output.txt \
--beam_size 10 \
    --use_gpu True \
    --eager_run False
```
## Evaluation
Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) script to evaluate the translation quality of the model's predictions, as follows:
```sh
mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
```
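For intuition about what multi-bleu.perl computes, here is a deliberately simplified single-sentence, single-reference BLEU sketch in plain Python (the real script is corpus-level, supports multiple references, and applies its own tokenization, so the numbers will differ):

```python
import math
from collections import Counter

def ngram_counts(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def bleu(candidate, reference, max_n=4):
    """Geometric mean of modified n-gram precisions, times a brevity penalty."""
    precisions = []
    for n in range(1, max_n + 1):
        cand = ngram_counts(candidate, n)
        ref = ngram_counts(reference, n)
        overlap = sum(min(c, ref[g]) for g, c in cand.items())  # clipped counts
        total = max(sum(cand.values()), 1)
        precisions.append(overlap / total)
    if min(precisions) == 0:
        return 0.0  # no smoothing in this sketch
    log_p = sum(math.log(p) for p in precisions) / max_n
    bp = min(1.0, math.exp(1 - len(reference) / len(candidate)))
    return bp * math.exp(log_p)
```

A candidate identical to its reference scores 1.0, and any candidate with no matching 4-gram scores 0.0 under this unsmoothed variant.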
Each model was trained 10 times; each run used the checkpoint saved at epoch 10 for inference, with beam_size=10. The results are shown below (the 10 results are sorted in ascending order for readability):
```
> no attention
tst2012 BLEU:
[10.75 10.85 10.9 10.94 10.97 11.01 11.01 11.04 11.13 11.4]
tst2013 BLEU:
[10.71 10.71 10.74 10.76 10.91 10.94 11.02 11.16 11.21 11.44]
> with attention
tst2012 BLEU:
[21.14 22.34 22.54 22.65 22.71 22.71 23.08 23.15 23.3 23.4]
tst2013 BLEU:
[23.41 24.79 25.11 25.12 25.19 25.24 25.39 25.61 25.61 25.63]
```
@@ -17,13 +17,58 @@ from __future__ import division
from __future__ import print_function

import glob
import six
import os
import io
import itertools
from functools import partial

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
def create_data_loader(args, device, for_train=True):
    data_loaders = [None, None]
    data_prefixes = [args.train_data_prefix, args.eval_data_prefix
                     ] if args.eval_data_prefix else [args.train_data_prefix]
    for i, data_prefix in enumerate(data_prefixes):
        dataset = Seq2SeqDataset(
            fpattern=data_prefix + "." + args.src_lang,
            trg_fpattern=data_prefix + "." + args.tar_lang,
            src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
            trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
            token_delimiter=None,
            start_mark="<s>",
            end_mark="</s>",
            unk_mark="<unk>",
            max_length=args.max_len if i == 0 else None,
            truncate=True)
        (args.src_vocab_size, args.tar_vocab_size, bos_id, eos_id,
         unk_id) = dataset.get_vocab_summary()
        batch_sampler = Seq2SeqBatchSampler(
            dataset=dataset,
            use_token_batch=False,
            batch_size=args.batch_size,
            pool_size=args.batch_size * 20,
            sort_type=SortType.POOL,
            shuffle=False if args.enable_ce else True)
        data_loader = DataLoader(
            dataset=dataset,
            batch_sampler=batch_sampler,
            places=device,
            collate_fn=partial(
                prepare_train_input,
                bos_id=bos_id,
                eos_id=eos_id,
                pad_id=eos_id),
            num_workers=0,
            return_list=True)
        data_loaders[i] = data_loader
    return data_loaders
def prepare_train_input(insts, bos_id, eos_id, pad_id):
    src, src_length = pad_batch_data(
        [inst[0] for inst in insts], pad_id=pad_id)
@@ -118,10 +163,11 @@ class TokenBatchCreator(object):
class SampleInfo(object):
    def __init__(self, i, lens):
        self.i = i
        # to be consistent with the original reader implementation
        self.min_len = lens[0]
        self.max_len = lens[0]
class MinMaxFilter(object):
@@ -131,9 +177,8 @@ class MinMaxFilter(object):
        self._creator = underlying_creator

    def append(self, info):
        if (self._min_len is None or info.min_len >= self._min_len) and (
                self._max_len is None or info.max_len <= self._max_len):
            return self._creator.append(info)

    @property
@@ -151,22 +196,30 @@ class Seq2SeqDataset(Dataset):
                 start_mark="<s>",
                 end_mark="<e>",
                 unk_mark="<unk>",
                 trg_fpattern=None,
                 byte_data=False,
                 min_length=None,
                 max_length=None,
                 truncate=False):
        if byte_data:
            # The WMT16 bpe data used here seems to include bytes that cannot
            # be decoded as utf8. Thus convert str to bytes and use byte data.
            field_delimiter = field_delimiter.encode("utf8")
            token_delimiter = token_delimiter.encode("utf8")
            start_mark = start_mark.encode("utf8")
            end_mark = end_mark.encode("utf8")
            unk_mark = unk_mark.encode("utf8")
        self._byte_data = byte_data
        self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
        self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
        self._bos_idx = self._src_vocab[start_mark]
        self._eos_idx = self._src_vocab[end_mark]
        self._unk_idx = self._src_vocab[unk_mark]
        self._field_delimiter = field_delimiter
        self._token_delimiter = token_delimiter
        self._min_length = min_length
        self._max_length = max_length
        self._truncate = truncate
        self.load_src_trg_ids(fpattern, trg_fpattern)
    def load_src_trg_ids(self, fpattern, trg_fpattern=None):
@@ -195,26 +248,32 @@ class Seq2SeqDataset(Dataset):
        self._sample_infos = []
        slots = [self._src_seq_ids, self._trg_seq_ids]
        for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
            fields = converters(line)
            lens = [len(field) for field in fields]
            sample = SampleInfo(i, lens)
            if (self._min_length is None or
                    sample.min_len >= self._min_length) and (
                        self._max_length is None or
                        sample.max_len <= self._max_length or self._truncate):
                for field, slot in zip(fields, slots):
                    slot.append(field[:self._max_length] if self._truncate and
                                self._max_length is not None else field)
                self._sample_infos.append(sample)
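The filter-or-truncate rule in `load_src_trg_ids` can be isolated as a small pure-Python helper. This is a simplified sketch (the helper name is hypothetical, and it uses the true min/max of the field lengths, whereas the reader above intentionally uses only the first field's length for both bounds):

```python
def keep_and_truncate(fields, min_length=None, max_length=None, truncate=False):
    """Return the (possibly truncated) fields of one sample, or None to drop it."""
    lens = [len(f) for f in fields]
    ok_min = min_length is None or min(lens) >= min_length
    # an over-long sample survives when truncation is enabled
    ok_max = max_length is None or max(lens) <= max_length or truncate
    if not (ok_min and ok_max):
        return None
    if truncate and max_length is not None:
        return [f[:max_length] for f in fields]
    return list(fields)
```

So with `truncate=True` over-long samples are clipped instead of dropped, while under-length samples are always dropped.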
    def _load_lines(self, fpattern, trg_fpattern=None):
        fpaths = glob.glob(fpattern)
        fpaths = sorted(fpaths)  # TODO: Add custom sort
        assert len(fpaths) > 0, "no matching file to the provided data path"

        (f_mode, f_encoding,
         endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
                                                              "\n")
        if trg_fpattern is None:
            for fpath in fpaths:
                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                    for line in f:
                        fields = line.strip(endl).split(self._field_delimiter)
                        yield fields
        else:
            # separated source and target language data files
@@ -228,24 +287,24 @@ class Seq2SeqDataset(Dataset):
                with that of source language"
            for fpath, trg_fpath in zip(fpaths, trg_fpaths):
                with io.open(fpath, f_mode, encoding=f_encoding) as f:
                    with io.open(
                            trg_fpath, f_mode, encoding=f_encoding) as trg_f:
                        for line in zip(f, trg_f):
                            fields = [field.strip(endl) for field in line]
                            yield fields

    @staticmethod
    def load_dict(dict_path, reverse=False, byte_data=False):
        word_dict = {}
        (f_mode, f_encoding,
         endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
        with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
            for idx, line in enumerate(fdict):
                if reverse:
                    word_dict[idx] = line.strip(endl)
                else:
                    word_dict[line.strip(endl)] = idx
        return word_dict
    def get_vocab_summary(self):
@@ -266,19 +325,21 @@ class Seq2SeqBatchSampler(BatchSampler):
                 batch_size,
                 pool_size=10000,
                 sort_type=SortType.NONE,
                 min_length=None,
                 max_length=None,
                 shuffle=False,
                 shuffle_batch=False,
                 use_token_batch=False,
                 clip_last_batch=False,
                 distribute_mode=True,
                 seed=0):
        for arg, value in locals().items():
            if arg != "self":
                setattr(self, "_" + arg, value)
        self._random = np.random
        self._random.seed(seed)
        # for multi-devices
        self._distribute_mode = distribute_mode
        self._nranks = ParallelEnv().nranks
        self._local_rank = ParallelEnv().local_rank
        self._device_id = ParallelEnv().dev_id
@@ -337,11 +398,14 @@ class Seq2SeqBatchSampler(BatchSampler):
        # for multi-device
        for batch_id, batch in enumerate(batches):
            if not self._distribute_mode or (
                    batch_id % self._nranks == self._local_rank):
                batch_indices = [info.i for info in batch]
                yield batch_indices
        if self._distribute_mode and len(batches) % self._nranks != 0:
            if self._local_rank >= len(batches) % self._nranks:
                # use previous data to pad
                yield batch_indices
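The multi-device distribution above assigns whole batches round-robin to ranks and, when the batch count does not divide evenly, re-yields a rank's own last batch so every rank sees the same number of batches. A standalone sketch (the helper name is hypothetical):

```python
def batches_for_rank(batches, nranks, local_rank):
    """Round-robin batch assignment with tail padding, as in the sampler above."""
    mine = [b for i, b in enumerate(batches) if i % nranks == local_rank]
    remainder = len(batches) % nranks
    # ranks beyond the remainder got one batch fewer: repeat their last batch
    if remainder and local_rank >= remainder and mine:
        mine.append(mine[-1])
    return mine
```

With 5 batches on 2 ranks, rank 0 gets batches 0, 2, 4 and rank 1 gets 1, 3 plus a repeat of 3, so both ranks step the same number of times.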
    def __len__(self):
        if not self._use_token_batch:
@@ -349,5 +413,6 @@ class Seq2SeqBatchSampler(BatchSampler):
                len(self._dataset) + self._batch_size * self._nranks - 1) // (
                    self._batch_size * self._nranks)
        else:
            # TODO(guosheng): fix the uncertain length
            batch_number = 1
        return batch_number
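The non-token-batch count in `__len__` is a ceiling division over `batch_size * nranks`. As a quick standalone check (hypothetical helper name):

```python
def num_batches(dataset_len, batch_size, nranks):
    """Ceiling division: batches per rank when samples are split across nranks."""
    return (dataset_len + batch_size * nranks - 1) // (batch_size * nranks)
```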
export CUDA_VISIBLE_DEVICES=0

python train.py \
    --src_lang en --tar_lang vi \
    --attention True \
......
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
from text import DynamicDecode, RNN, RNNCell
from model import Model, Loss
class ConvBNPool(fluid.dygraph.Layer):
    def __init__(self,
                 out_ch,
                 channels,
                 act="relu",
                 is_test=False,
                 pool=True,
                 use_cudnn=True):
        super(ConvBNPool, self).__init__()
        self.pool = pool

        filter_size = 3
        conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
        conv_param_0 = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, conv_std_0))

        conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
        conv_param_1 = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, conv_std_1))

        self.conv_0_layer = Conv2D(
            channels[0],
            out_ch[0],
            3,
            padding=1,
            param_attr=conv_param_0,
            bias_attr=False,
            act=None,
            use_cudnn=use_cudnn)
        self.bn_0_layer = BatchNorm(out_ch[0], act=act, is_test=is_test)
        self.conv_1_layer = Conv2D(
            out_ch[0],
            num_filters=out_ch[1],
            filter_size=3,
            padding=1,
            param_attr=conv_param_1,
            bias_attr=False,
            act=None,
            use_cudnn=use_cudnn)
        self.bn_1_layer = BatchNorm(out_ch[1], act=act, is_test=is_test)

        if self.pool:
            self.pool_layer = Pool2D(
                pool_size=2,
                pool_type='max',
                pool_stride=2,
                use_cudnn=use_cudnn,
                ceil_mode=True)

    def forward(self, inputs):
        conv_0 = self.conv_0_layer(inputs)
        bn_0 = self.bn_0_layer(conv_0)
        conv_1 = self.conv_1_layer(bn_0)
        bn_1 = self.bn_1_layer(conv_1)
        if self.pool:
            bn_pool = self.pool_layer(bn_1)
            return bn_pool
        return bn_1
class OCRConv(fluid.dygraph.Layer):
    def __init__(self, is_test=False, use_cudnn=True):
        super(OCRConv, self).__init__()
        self.conv_bn_pool_1 = ConvBNPool(
            [16, 16], [1, 16], is_test=is_test, use_cudnn=use_cudnn)
        self.conv_bn_pool_2 = ConvBNPool(
            [32, 32], [16, 32], is_test=is_test, use_cudnn=use_cudnn)
        self.conv_bn_pool_3 = ConvBNPool(
            [64, 64], [32, 64], is_test=is_test, use_cudnn=use_cudnn)
        self.conv_bn_pool_4 = ConvBNPool(
            [128, 128], [64, 128],
            is_test=is_test,
            pool=False,
            use_cudnn=use_cudnn)

    def forward(self, inputs):
        inputs_1 = self.conv_bn_pool_1(inputs)
        inputs_2 = self.conv_bn_pool_2(inputs_1)
        inputs_3 = self.conv_bn_pool_3(inputs_2)
        inputs_4 = self.conv_bn_pool_4(inputs_3)
        return inputs_4
class SimpleAttention(fluid.dygraph.Layer):
    def __init__(self, decoder_size):
        super(SimpleAttention, self).__init__()
        self.fc1 = Linear(decoder_size, decoder_size, bias_attr=False)
        self.fc2 = Linear(decoder_size, 1, bias_attr=False)

    def forward(self, encoder_vec, encoder_proj, decoder_state):
        decoder_state = self.fc1(decoder_state)
        decoder_state = fluid.layers.unsqueeze(decoder_state, [1])

        mix = fluid.layers.elementwise_add(encoder_proj, decoder_state)
        mix = fluid.layers.tanh(x=mix)
        attn_score = self.fc2(mix)
        attn_scores = layers.squeeze(attn_score, [2])
        attn_scores = fluid.layers.softmax(attn_scores)

        scaled = fluid.layers.elementwise_mul(
            x=encoder_vec, y=attn_scores, axis=0)
        context = fluid.layers.reduce_sum(scaled, dim=1)
        return context
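The additive (Bahdanau-style) attention computed by `SimpleAttention` can be reproduced in plain NumPy for clarity. This is a simplified sketch with equal encoder/decoder widths and explicit weight matrices standing in for `fc1`/`fc2` (in the layer above `encoder_vec` is wider, being a bidirectional concatenation):

```python
import numpy as np

def additive_attention(encoder_vec, encoder_proj, decoder_state, w1, w2):
    """encoder_vec, encoder_proj: [batch, T, d]; decoder_state: [batch, d];
    w1: [d, d] (fc1, no bias); w2: [d, 1] (fc2, no bias)."""
    query = decoder_state @ w1                       # project decoder state
    mix = np.tanh(encoder_proj + query[:, None, :])  # broadcast over time T
    scores = (mix @ w2)[..., 0]                      # [batch, T]
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn = scores / scores.sum(axis=-1, keepdims=True)   # softmax over T
    return (encoder_vec * attn[..., None]).sum(axis=1)   # weighted sum
```

With zero weights the scores are uniform, so the context is simply the time-average of the encoder states.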
class GRUCell(RNNCell):
    def __init__(self,
                 input_size,
                 hidden_size,
                 param_attr=None,
                 bias_attr=None,
                 gate_activation='sigmoid',
                 candidate_activation='tanh',
                 origin_mode=False):
        super(GRUCell, self).__init__()
        self.hidden_size = hidden_size
        self.fc_layer = Linear(
            input_size,
            hidden_size * 3,
            param_attr=param_attr,
            bias_attr=False)
        self.gru_unit = GRUUnit(
            hidden_size * 3,
            param_attr=param_attr,
            bias_attr=bias_attr,
            activation=candidate_activation,
            gate_activation=gate_activation,
            origin_mode=origin_mode)

    def forward(self, inputs, states):
        # step_outputs, new_states = cell(step_inputs, states)
        # for GRUCell, `step_outputs` and `new_states` both are hidden
        x = self.fc_layer(inputs)
        hidden, _, _ = self.gru_unit(x, states)
        return hidden, hidden

    @property
    def state_shape(self):
        return [self.hidden_size]
class EncoderNet(fluid.dygraph.Layer):
    def __init__(self,
                 decoder_size,
                 rnn_hidden_size=200,
                 is_test=False,
                 use_cudnn=True):
        super(EncoderNet, self).__init__()
        self.rnn_hidden_size = rnn_hidden_size
        para_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, 0.02))
        bias_attr = fluid.ParamAttr(
            initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
        self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn)

        self.gru_forward_layer = RNN(
            cell=GRUCell(
                input_size=128 * 6,  # channel * h
                hidden_size=rnn_hidden_size,
                param_attr=para_attr,
                bias_attr=bias_attr,
                candidate_activation='relu'),
            is_reverse=False,
            time_major=False)
        self.gru_backward_layer = RNN(
            cell=GRUCell(
                input_size=128 * 6,  # channel * h
                hidden_size=rnn_hidden_size,
                param_attr=para_attr,
                bias_attr=bias_attr,
                candidate_activation='relu'),
            is_reverse=True,
            time_major=False)

        self.encoded_proj_fc = Linear(
            rnn_hidden_size * 2, decoder_size, bias_attr=False)

    def forward(self, inputs):
        conv_features = self.ocr_convs(inputs)
        transpose_conv_features = fluid.layers.transpose(
            conv_features, perm=[0, 3, 1, 2])
        sliced_feature = fluid.layers.reshape(
            transpose_conv_features, [
                -1, transpose_conv_features.shape[1],
                transpose_conv_features.shape[2] *
                transpose_conv_features.shape[3]
            ],
            inplace=False)
        gru_forward, _ = self.gru_forward_layer(sliced_feature)
        gru_backward, _ = self.gru_backward_layer(sliced_feature)
        encoded_vector = fluid.layers.concat(
            input=[gru_forward, gru_backward], axis=2)
        encoded_proj = self.encoded_proj_fc(encoded_vector)
        return gru_backward, encoded_vector, encoded_proj
class DecoderCell(RNNCell):
    def __init__(self, encoder_size, decoder_size):
        super(DecoderCell, self).__init__()
        self.attention = SimpleAttention(decoder_size)
        self.gru_cell = GRUCell(
            input_size=encoder_size * 2 +
            decoder_size,  # encoded_vector.shape[-1] + embed_size
            hidden_size=decoder_size)

    def forward(self, current_word, states, encoder_vec, encoder_proj):
        context = self.attention(encoder_vec, encoder_proj, states)
        decoder_inputs = layers.concat([current_word, context], axis=1)
        hidden, _ = self.gru_cell(decoder_inputs, states)
        return hidden, hidden
class GRUDecoderWithAttention(fluid.dygraph.Layer):
    def __init__(self, encoder_size, decoder_size, num_classes):
        super(GRUDecoderWithAttention, self).__init__()
        self.gru_attention = RNN(
            DecoderCell(encoder_size, decoder_size),
            is_reverse=False,
            time_major=False)
        self.out_layer = Linear(
            input_dim=decoder_size,
            output_dim=num_classes + 2,
            bias_attr=None,
            act='softmax')

    def forward(self, inputs, decoder_initial_states, encoder_vec,
                encoder_proj):
        out, _ = self.gru_attention(
            inputs,
            initial_states=decoder_initial_states,
            encoder_vec=encoder_vec,
            encoder_proj=encoder_proj)
        predict = self.out_layer(out)
        return predict
class OCRAttention(Model):
    def __init__(self, num_classes, encoder_size, decoder_size,
                 word_vector_dim):
        super(OCRAttention, self).__init__()
        self.encoder_net = EncoderNet(decoder_size)
        self.fc = Linear(
            input_dim=encoder_size,
            output_dim=decoder_size,
            bias_attr=False,
            act='relu')
        self.embedding = Embedding(
            [num_classes + 2, word_vector_dim], dtype='float32')
        self.gru_decoder_with_attention = GRUDecoderWithAttention(
            encoder_size, decoder_size, num_classes)

    def forward(self, inputs, label_in):
        gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
        decoder_boot = self.fc(gru_backward[:, 0])
        trg_embedding = self.embedding(label_in)
        prediction = self.gru_decoder_with_attention(
            trg_embedding, decoder_boot, encoded_vector, encoded_proj)
        return prediction
class CrossEntropyCriterion(Loss):
    def __init__(self):
        super(CrossEntropyCriterion, self).__init__()

    def forward(self, outputs, labels):
        predict, (label, mask) = outputs[0], labels
        loss = layers.cross_entropy(predict, label=label, soft_label=False)
        loss = layers.elementwise_mul(loss, mask, axis=0)
        loss = layers.reduce_sum(loss)
        return loss
...@@ -28,7 +28,7 @@ from callbacks import ProgBarLogger ...@@ -28,7 +28,7 @@ from callbacks import ProgBarLogger
from args import parse_args from args import parse_args
from seq2seq_base import BaseModel, CrossEntropyCriterion from seq2seq_base import BaseModel, CrossEntropyCriterion
from seq2seq_attn import AttentionModel from seq2seq_attn import AttentionModel
from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_train_input from reader import create_data_loader
def do_train(args): def do_train(args):
...@@ -38,7 +38,6 @@ def do_train(args): ...@@ -38,7 +38,6 @@ def do_train(args):
if args.enable_ce: if args.enable_ce:
fluid.default_main_program().random_seed = 102 fluid.default_main_program().random_seed = 102
fluid.default_startup_program().random_seed = 102 fluid.default_startup_program().random_seed = 102
args.shuffle = False
# define model # define model
inputs = [ inputs = [
...@@ -54,64 +53,25 @@ def do_train(args): ...@@ -54,64 +53,25 @@ def do_train(args):
labels = [Input([None, None, 1], "int64", name="label"), ] labels = [Input([None, None, 1], "int64", name="label"), ]
# def dataloader # def dataloader
data_loaders = [None, None] train_loader, eval_loader = create_data_loader(args, device)
data_prefixes = [args.train_data_prefix, args.eval_data_prefix
] if args.eval_data_prefix else [args.train_data_prefix]
for i, data_prefix in enumerate(data_prefixes):
dataset = Seq2SeqDataset(
fpattern=data_prefix + "." + args.src_lang,
trg_fpattern=data_prefix + "." + args.tar_lang,
src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
token_delimiter=None,
start_mark="<s>",
end_mark="</s>",
unk_mark="<unk>")
(args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
unk_id) = dataset.get_vocab_summary()
batch_sampler = Seq2SeqBatchSampler(
dataset=dataset,
use_token_batch=False,
batch_size=args.batch_size,
pool_size=args.batch_size * 20,
sort_type=SortType.POOL,
shuffle=args.shuffle)
data_loader = DataLoader(
dataset=dataset,
batch_sampler=batch_sampler,
places=device,
feed_list=None if fluid.in_dygraph_mode() else
[x.forward() for x in inputs + labels],
collate_fn=partial(
prepare_train_input,
bos_id=bos_id,
eos_id=eos_id,
pad_id=eos_id),
num_workers=0,
return_list=True)
data_loaders[i] = data_loader
train_loader, eval_loader = data_loaders
model_maker = AttentionModel if args.attention else BaseModel model_maker = AttentionModel if args.attention else BaseModel
model = model_maker(args.src_vocab_size, args.tar_vocab_size, model = model_maker(args.src_vocab_size, args.tar_vocab_size,
args.hidden_size, args.hidden_size, args.num_layers, args.hidden_size, args.hidden_size, args.num_layers,
args.dropout) args.dropout)
optimizer = fluid.optimizer.Adam(
learning_rate=args.learning_rate, parameter_list=model.parameters())
optimizer._grad_clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=args.max_grad_norm)
model.prepare( model.prepare(
fluid.optimizer.Adam( optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)
learning_rate=args.learning_rate,
parameter_list=model.parameters()),
CrossEntropyCriterion(),
inputs=inputs,
labels=labels)
model.fit(train_data=train_loader, model.fit(train_data=train_loader,
eval_data=eval_loader, eval_data=eval_loader,
epochs=args.max_epoch, epochs=args.max_epoch,
eval_freq=1, eval_freq=1,
save_freq=1, save_freq=1,
save_dir=args.model_path, save_dir=args.model_path,
log_freq=1, log_freq=1)
verbose=2)
if __name__ == "__main__": if __name__ == "__main__":
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import paddle.fluid.profiler as profiler
import paddle.fluid as fluid
import data_reader
from paddle.fluid.dygraph.base import to_variable
import argparse
import functools
from utility import add_arguments, print_arguments, get_attention_feeder_data
from model import Input, set_device
from nets import OCRAttention, CrossEntropyCriterion
from eval import evaluate
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('epoch_num', int, 30, "Epoch number.")
add_arg('lr', float, 0.001, "Learning rate.")
add_arg('lr_decay_strategy', str, "", "Learning rate decay strategy.")
add_arg('log_period', int, 200, "Log period.")
add_arg('save_model_period', int, 2000, "Save model period. '-1' means never saving the model.")
add_arg('eval_period', int, 2000, "Evaluate period. '-1' means never evaluating the model.")
add_arg('save_model_dir', str, "./output", "The directory the model to be saved to.")
add_arg('train_images', str, None, "The directory of images to be used for training.")
add_arg('train_list', str, None, "The list file of images to be used for training.")
add_arg('test_images', str, None, "The directory of images to be used for test.")
add_arg('test_list', str, None, "The list file of images to be used for training.")
add_arg('init_model', str, None, "The init model file of directory.")
add_arg('use_gpu', bool, True, "Whether use GPU to train.")
add_arg('parallel', bool, False, "Whether use parallel training.")
add_arg('profile', bool, False, "Whether to use profiling.")
add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
add_arg('skip_test', bool, False, "Whether to skip test phase.")
# model hyper paramters
add_arg('encoder_size', int, 200, "Encoder size.")
add_arg('decoder_size', int, 128, "Decoder size.")
add_arg('word_vector_dim', int, 128, "Word vector dim.")
add_arg('num_classes', int, 95, "Number classes.")
add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
add_arg('dynamic', bool, False, "Whether to use dygraph.")
def train(args):
    device = set_device("gpu" if args.use_gpu else "cpu")
    fluid.enable_dygraph(device) if args.dynamic else None

    ocr_attention = OCRAttention(
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        num_classes=args.num_classes,
        word_vector_dim=args.word_vector_dim)
    LR = args.lr
    if args.lr_decay_strategy == "piecewise_decay":
        learning_rate = fluid.layers.piecewise_decay(
            [200000, 250000], [LR, LR * 0.1, LR * 0.01])
    else:
        learning_rate = LR
    optimizer = fluid.optimizer.Adam(
        learning_rate=learning_rate,
        parameter_list=ocr_attention.parameters())
    # grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.gradient_clip)

    inputs = [
        Input([None, 1, 48, 384], "float32", name="pixel"),
        Input([None, None], "int64", name="label_in"),
    ]
    labels = [
        Input([None, None], "int64", name="label_out"),
        Input([None, None], "float32", name="mask")
    ]

    ocr_attention.prepare(
        optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)

    train_reader = data_reader.data_reader(
        args.batch_size,
        shuffle=True,
        images_dir=args.train_images,
        list_file=args.train_list,
        data_type='train')
    # test_reader = data_reader.data_reader(
    #     args.batch_size,
    #     images_dir=args.test_images,
    #     list_file=args.test_list,
    #     data_type="test")
    # if not os.path.exists(args.save_model_dir):
    #     os.makedirs(args.save_model_dir)

    total_step = 0
    epoch_num = args.epoch_num
    for epoch in range(epoch_num):
        batch_id = 0
        total_loss = 0.0
        for data in train_reader():
            total_step += 1
            data_dict = get_attention_feeder_data(data)
            pixel = data_dict["pixel"]
            label_in = data_dict["label_in"].reshape([pixel.shape[0], -1])
            label_out = data_dict["label_out"].reshape([pixel.shape[0], -1])
            mask = data_dict["mask"].reshape(label_out.shape).astype("float32")

            avg_loss = ocr_attention.train(
                inputs=[pixel, label_in], labels=[label_out, mask])[0]
            total_loss += avg_loss
            if batch_id > 0 and batch_id % args.log_period == 0:
                print("epoch: {}, batch_id: {}, loss {}".format(
                    epoch, batch_id,
                    total_loss / args.batch_size / args.log_period))
                total_loss = 0.0
            batch_id += 1
if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)
    if args.profile:
        if args.use_gpu:
            with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
                train(args)
        else:
            with profiler.profiler("CPU", sorted_key='total') as cpuprof:
                train(args)
    else:
        train(args)
...@@ -289,7 +289,6 @@ class Seq2SeqDataset(Dataset): ...@@ -289,7 +289,6 @@ class Seq2SeqDataset(Dataset):
start_mark="<s>", start_mark="<s>",
end_mark="<e>", end_mark="<e>",
unk_mark="<unk>", unk_mark="<unk>",
only_src=False,
trg_fpattern=None, trg_fpattern=None,
byte_data=False): byte_data=False):
if byte_data: if byte_data:
......