diff --git a/seq2seq/README.md b/seq2seq/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..ef8bfd17fffc60cbf5b8809e52af75a1bbceedb3
--- /dev/null
+++ b/seq2seq/README.md
@@ -0,0 +1,180 @@
+Running the example model in this directory requires PaddlePaddle Fluid 1.7. If your installed version of PaddlePaddle is older than this, please update it following the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start).
+
+# Sequence to Sequence (Seq2Seq)
+
+The brief directory structure of this example is as follows:
+
+```
+.
+├── README.md          # this document
+├── args.py            # training, inference and model hyperparameter configuration
+├── reader.py          # data reading
+├── download.py        # data downloading
+├── train.py           # main training script
+├── infer.py           # main inference script
+├── run.sh             # launch script with the default training configuration
+├── infer.sh           # decoding script with the default configuration
+├── attention_model.py # translation model with attention
+└── base_model.py      # translation model without attention
+```
+
+## Introduction
+
+Sequence to Sequence (Seq2Seq) uses an encoder-decoder structure: an encoder encodes the source sequence into a vector, and a decoder decodes that vector into the target sequence. Seq2Seq is widely used in tasks such as machine translation, dialogue systems, automatic document summarization, and image captioning.
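+
+As a minimal sketch of this encode-then-decode loop (illustrative Python with hypothetical `encoder_step`/`decoder_step` functions, not this example's actual implementation):
+
+```python
+def seq2seq_translate(src_ids, encoder_step, decoder_step, bos_id, eos_id, max_len=50):
+    # Encoder: fold the source sequence into a fixed-size state vector.
+    state = None
+    for token in src_ids:
+        state = encoder_step(token, state)
+    # Decoder: emit target tokens one at a time, feeding each back in.
+    outputs, token = [], bos_id
+    for _ in range(max_len):
+        token, state = decoder_step(token, state)
+        if token == eos_id:
+            break
+        outputs.append(token)
+    return outputs
+```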
+
+This directory contains a classic Seq2Seq application, machine translation, implemented as a base model (without attention) and a translation model with attention. A Seq2Seq translation model imitates how a human translates: first parse the source sentence and understand its meaning, then write the target-language sentence according to that meaning. For the detailed principles and mathematical formulation of machine translation, we recommend the [machine translation case study](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html) on the PaddlePaddle website.
+
+## Model Overview
+
+On the encoder side, this model uses a multi-layer LSTM-based RNN encoder; on the decoder side, it uses an RNN decoder with an attention mechanism, and also provides a decoder without attention for comparison. At inference time, beam search is used to generate the target sentence.
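+
+Intuitively, attention lets the decoder look back at all encoder outputs at every step, taking a weighted sum instead of relying on one fixed vector. Below is a minimal NumPy sketch of dot-product attention, for intuition only (the actual formulation lives in `attention_model.py` and may differ):
+
+```python
+import numpy as np
+
+def attention_context(decoder_state, encoder_outputs):
+    # decoder_state: [hidden]; encoder_outputs: [src_len, hidden]
+    scores = encoder_outputs @ decoder_state  # one score per source position
+    weights = np.exp(scores - scores.max())
+    weights /= weights.sum()                  # softmax over source positions
+    return weights @ encoder_outputs          # context vector, shape [hidden]
+```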
+
+## Data
+
+This tutorial uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) as the training corpus, the tst2012 data as the development set, and the tst2013 data as the test set.
+
+### Getting the data
+
+```sh
+python download.py
+```
+
+## 模型训练
+
+`run.sh` wraps the training entry point. To start training with the default settings, simply run:
+
+```sh
+sh run.sh
+```
+
+The RNN model with attention is used by default; to train the RNN model without attention, set the `attention` argument to False.
+
+```sh
+export CUDA_VISIBLE_DEVICES=0
+
+python train.py \
+ --src_lang en --tar_lang vi \
+ --attention True \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --train_data_prefix data/en-vi/train \
+ --eval_data_prefix data/en-vi/tst2012 \
+ --test_data_prefix data/en-vi/tst2013 \
+ --vocab_prefix data/en-vi/vocab \
+ --use_gpu True \
+ --model_path ./attention_models
+```
+
+The training program saves a model checkpoint at the end of every epoch.
+
+
+Training runs in dynamic graph (eager) mode by default; to train in static graph mode instead, set the `eager_run` argument to False, as follows:
+
+```sh
+export CUDA_VISIBLE_DEVICES=0
+
+python train.py \
+ --src_lang en --tar_lang vi \
+ --attention True \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --train_data_prefix data/en-vi/train \
+ --eval_data_prefix data/en-vi/tst2012 \
+ --test_data_prefix data/en-vi/tst2013 \
+ --vocab_prefix data/en-vi/vocab \
+ --use_gpu True \
+ --model_path ./attention_models \
+ --eager_run False
+```
+
+## Inference
+
+After training completes, you can run inference with the infer.sh script. By default it decodes the test set using beam search, loading the model saved at the 10th epoch:
+
+```sh
+sh infer.sh
+```
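+
+Beam search keeps the `beam_size` highest-scoring partial translations at every decoding step instead of greedily committing to a single token. A minimal sketch of the idea (plain Python with a hypothetical `step` function returning candidate next tokens with log-probabilities; not the actual implementation in `infer.py`):
+
+```python
+def beam_search(step, bos_id, eos_id, beam_size=10, max_len=50):
+    beams = [([bos_id], 0.0)]  # (token sequence, accumulated log-prob)
+    for _ in range(max_len):
+        candidates = []
+        for seq, score in beams:
+            if seq[-1] == eos_id:          # keep finished hypotheses as-is
+                candidates.append((seq, score))
+                continue
+            for token, logp in step(seq):  # expand with next-token scores
+                candidates.append((seq + [token], score + logp))
+        # prune to the beam_size best-scoring candidates
+        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
+    return beams[0][0]
+```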
+
+To run inference on a different data file, simply change the `--infer_file` argument.
+
+```sh
+export CUDA_VISIBLE_DEVICES=0
+
+python infer.py \
+ --attention True \
+ --src_lang en --tar_lang vi \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --vocab_prefix data/en-vi/vocab \
+ --infer_file data/en-vi/tst2013.en \
+ --reload_model attention_models/epoch_10 \
+ --infer_output_file attention_infer_output/infer_output.txt \
+ --beam_size 10 \
+ --use_gpu True
+```
+
+As with training, inference can also be run in static graph mode, as follows:
+
+```sh
+export CUDA_VISIBLE_DEVICES=0
+
+python infer.py \
+ --attention True \
+ --src_lang en --tar_lang vi \
+ --num_layers 2 \
+ --hidden_size 512 \
+ --src_vocab_size 17191 \
+ --tar_vocab_size 7709 \
+ --batch_size 128 \
+ --dropout 0.2 \
+ --init_scale 0.1 \
+ --max_grad_norm 5.0 \
+ --vocab_prefix data/en-vi/vocab \
+ --infer_file data/en-vi/tst2013.en \
+ --reload_model attention_models/epoch_10 \
+ --infer_output_file attention_infer_output/infer_output.txt \
+ --beam_size 10 \
+    --use_gpu True \
+    --eager_run False
+```
+
+## Evaluation
+
+Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) script to evaluate the translation quality of the model's output (this assumes the mosesdecoder repository has been cloned locally), as follows:
+
+```sh
+mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
+```
+
+Each model was trained 10 times; each run used the model saved at the 10th epoch for inference, with beam_size=10. The results are shown below (each set of 10 scores is sorted in ascending order for readability):
+
+```
+> no attention
+tst2012 BLEU:
+[10.75 10.85 10.9 10.94 10.97 11.01 11.01 11.04 11.13 11.4]
+tst2013 BLEU:
+[10.71 10.71 10.74 10.76 10.91 10.94 11.02 11.16 11.21 11.44]
+
+> with attention
+tst2012 BLEU:
+[21.14 22.34 22.54 22.65 22.71 22.71 23.08 23.15 23.3 23.4]
+tst2013 BLEU:
+[23.41 24.79 25.11 25.12 25.19 25.24 25.39 25.61 25.61 25.63]
+```
diff --git a/seq2seq/reader.py b/seq2seq/reader.py
index 6f007bdde0ca167a7172def436c9ec98e1b75059..ebdbb47266e2c43b6e1ec862951f0f83bfe5cab0 100644
--- a/seq2seq/reader.py
+++ b/seq2seq/reader.py
@@ -17,13 +17,58 @@ from __future__ import division
from __future__ import print_function
import glob
+import six
+import os
import io
-import numpy as np
import itertools
+from functools import partial
+
+import numpy as np
+import paddle.fluid as fluid
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.io import BatchSampler, DataLoader, Dataset
+def create_data_loader(args, device, for_train=True):
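+    """Build DataLoaders from the data/vocab prefixes given in `args`.
+
+    Returns [train_loader, eval_loader]; eval_loader is None when
+    args.eval_data_prefix is not set. Also fills in args.src_vocab_size
+    and args.tar_vocab_size from the loaded vocabularies.
+    """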
+ data_loaders = [None, None]
+ data_prefixes = [args.train_data_prefix, args.eval_data_prefix
+ ] if args.eval_data_prefix else [args.train_data_prefix]
+ for i, data_prefix in enumerate(data_prefixes):
+ dataset = Seq2SeqDataset(
+ fpattern=data_prefix + "." + args.src_lang,
+ trg_fpattern=data_prefix + "." + args.tar_lang,
+ src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
+ trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
+ token_delimiter=None,
+            start_mark="<s>",
+            end_mark="</s>",
+            unk_mark="<unk>",
+ max_length=args.max_len if i == 0 else None,
+ truncate=True)
+ (args.src_vocab_size, args.tar_vocab_size, bos_id, eos_id,
+ unk_id) = dataset.get_vocab_summary()
+ batch_sampler = Seq2SeqBatchSampler(
+ dataset=dataset,
+ use_token_batch=False,
+ batch_size=args.batch_size,
+ pool_size=args.batch_size * 20,
+ sort_type=SortType.POOL,
+ shuffle=False if args.enable_ce else True)
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_sampler=batch_sampler,
+ places=device,
+ collate_fn=partial(
+ prepare_train_input,
+ bos_id=bos_id,
+ eos_id=eos_id,
+ pad_id=eos_id),
+ num_workers=0,
+ return_list=True)
+ data_loaders[i] = data_loader
+ return data_loaders
+
+
def prepare_train_input(insts, bos_id, eos_id, pad_id):
src, src_length = pad_batch_data(
[inst[0] for inst in insts], pad_id=pad_id)
@@ -118,10 +163,11 @@ class TokenBatchCreator(object):
class SampleInfo(object):
- def __init__(self, i, max_len, min_len):
+ def __init__(self, i, lens):
self.i = i
- self.min_len = min_len
- self.max_len = max_len
+        # use the source-side length for both, to be consistent with the
+        # original reader implementation
+ self.min_len = lens[0]
+ self.max_len = lens[0]
class MinMaxFilter(object):
@@ -131,9 +177,8 @@ class MinMaxFilter(object):
self._creator = underlying_creator
def append(self, info):
- if info.max_len > self._max_len or info.min_len < self._min_len:
- return
- else:
+ if (self._min_len is None or info.min_len >= self._min_len) and (
+ self._max_len is None or info.max_len <= self._max_len):
return self._creator.append(info)
@property
@@ -151,22 +196,30 @@ class Seq2SeqDataset(Dataset):
                 start_mark="<s>",
                 end_mark="</s>",
                 unk_mark="<unk>",
- only_src=False,
- trg_fpattern=None):
- # convert str to bytes, and use byte data
- # field_delimiter = field_delimiter.encode("utf8")
- # token_delimiter = token_delimiter.encode("utf8")
- # start_mark = start_mark.encode("utf8")
- # end_mark = end_mark.encode("utf8")
- # unk_mark = unk_mark.encode("utf8")
- self._src_vocab = self.load_dict(src_vocab_fpath)
- self._trg_vocab = self.load_dict(trg_vocab_fpath)
+ trg_fpattern=None,
+ byte_data=False,
+ min_length=None,
+ max_length=None,
+ truncate=False):
+ if byte_data:
+            # The WMT16 BPE data used here seems to include bytes that cannot
+            # be decoded as utf8, thus convert str to bytes and use byte data
+ field_delimiter = field_delimiter.encode("utf8")
+ token_delimiter = token_delimiter.encode("utf8")
+ start_mark = start_mark.encode("utf8")
+ end_mark = end_mark.encode("utf8")
+ unk_mark = unk_mark.encode("utf8")
+ self._byte_data = byte_data
+ self._src_vocab = self.load_dict(src_vocab_fpath, byte_data=byte_data)
+ self._trg_vocab = self.load_dict(trg_vocab_fpath, byte_data=byte_data)
self._bos_idx = self._src_vocab[start_mark]
self._eos_idx = self._src_vocab[end_mark]
self._unk_idx = self._src_vocab[unk_mark]
- self._only_src = only_src
self._field_delimiter = field_delimiter
self._token_delimiter = token_delimiter
+ self._min_length = min_length
+ self._max_length = max_length
+ self._truncate = truncate
self.load_src_trg_ids(fpattern, trg_fpattern)
def load_src_trg_ids(self, fpattern, trg_fpattern=None):
@@ -195,26 +248,32 @@ class Seq2SeqDataset(Dataset):
self._sample_infos = []
slots = [self._src_seq_ids, self._trg_seq_ids]
- lens = []
for i, line in enumerate(self._load_lines(fpattern, trg_fpattern)):
- lens = []
- for field, slot in zip(converters(line), slots):
- slot.append(field)
- lens.append(len(field))
- # self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
- self._sample_infos.append(SampleInfo(i, lens[0], lens[0]))
+ fields = converters(line)
+ lens = [len(field) for field in fields]
+ sample = SampleInfo(i, lens)
+ if (self._min_length is None or
+ sample.min_len >= self._min_length) and (
+ self._max_length is None or
+ sample.max_len <= self._max_length or self._truncate):
+ for field, slot in zip(fields, slots):
+ slot.append(field[:self._max_length] if self._truncate and
+ self._max_length is not None else field)
+ self._sample_infos.append(sample)
def _load_lines(self, fpattern, trg_fpattern=None):
fpaths = glob.glob(fpattern)
        fpaths = sorted(fpaths) # TODO: Add custom sort
assert len(fpaths) > 0, "no matching file to the provided data path"
+ (f_mode, f_encoding,
+ endl) = ("rb", None, b"\n") if self._byte_data else ("r", "utf8",
+ "\n")
if trg_fpattern is None:
for fpath in fpaths:
- # with io.open(fpath, "rb") as f:
- with io.open(fpath, "r", encoding="utf8") as f:
+ with io.open(fpath, f_mode, encoding=f_encoding) as f:
for line in f:
- fields = line.strip("\n").split(self._field_delimiter)
+ fields = line.strip(endl).split(self._field_delimiter)
yield fields
else:
# separated source and target language data files
@@ -228,24 +287,24 @@ class Seq2SeqDataset(Dataset):
with that of source language"
for fpath, trg_fpath in zip(fpaths, trg_fpaths):
- # with io.open(fpath, "rb") as f:
- # with io.open(trg_fpath, "rb") as trg_f:
- with io.open(fpath, "r", encoding="utf8") as f:
- with io.open(trg_fpath, "r", encoding="utf8") as trg_f:
+ with io.open(fpath, f_mode, encoding=f_encoding) as f:
+ with io.open(
+ trg_fpath, f_mode, encoding=f_encoding) as trg_f:
for line in zip(f, trg_f):
- fields = [field.strip("\n") for field in line]
+ fields = [field.strip(endl) for field in line]
yield fields
@staticmethod
- def load_dict(dict_path, reverse=False):
+ def load_dict(dict_path, reverse=False, byte_data=False):
word_dict = {}
- # with io.open(dict_path, "rb") as fdict:
- with io.open(dict_path, "r", encoding="utf8") as fdict:
+ (f_mode, f_encoding,
+ endl) = ("rb", None, b"\n") if byte_data else ("r", "utf8", "\n")
+ with io.open(dict_path, f_mode, encoding=f_encoding) as fdict:
for idx, line in enumerate(fdict):
if reverse:
- word_dict[idx] = line.strip("\n")
+ word_dict[idx] = line.strip(endl)
else:
- word_dict[line.strip("\n")] = idx
+ word_dict[line.strip(endl)] = idx
return word_dict
def get_vocab_summary(self):
@@ -266,19 +325,21 @@ class Seq2SeqBatchSampler(BatchSampler):
batch_size,
pool_size=10000,
sort_type=SortType.NONE,
- min_length=0,
- max_length=100,
+ min_length=None,
+ max_length=None,
shuffle=False,
shuffle_batch=False,
use_token_batch=False,
clip_last_batch=False,
- seed=None):
+ distribute_mode=True,
+ seed=0):
for arg, value in locals().items():
if arg != "self":
setattr(self, "_" + arg, value)
self._random = np.random
self._random.seed(seed)
# for multi-devices
+ self._distribute_mode = distribute_mode
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self._device_id = ParallelEnv().dev_id
@@ -337,11 +398,14 @@ class Seq2SeqBatchSampler(BatchSampler):
# for multi-device
for batch_id, batch in enumerate(batches):
- if batch_id % self._nranks == self._local_rank:
+ if not self._distribute_mode or (
+ batch_id % self._nranks == self._local_rank):
batch_indices = [info.i for info in batch]
yield batch_indices
- if self._local_rank > len(batches) % self._nranks:
- yield batch_indices
+ if self._distribute_mode and len(batches) % self._nranks != 0:
+ if self._local_rank >= len(batches) % self._nranks:
+ # use previous data to pad
+ yield batch_indices
def __len__(self):
if not self._use_token_batch:
@@ -349,5 +413,6 @@ class Seq2SeqBatchSampler(BatchSampler):
len(self._dataset) + self._batch_size * self._nranks - 1) // (
self._batch_size * self._nranks)
else:
- batch_number = 100
+ # TODO(guosheng): fix the uncertain length
+ batch_number = 1
return batch_number
diff --git a/seq2seq/run.sh b/seq2seq/run.sh
index 2fe8b7a0700ae434c1015375a6080ccfeaf0ca03..4872fc996a8a86118acf5f47d8ccfd8e9fc48f11 100644
--- a/seq2seq/run.sh
+++ b/seq2seq/run.sh
@@ -1,3 +1,5 @@
+export CUDA_VISIBLE_DEVICES=0
+
python train.py \
--src_lang en --tar_lang vi \
--attention True \
diff --git a/seq2seq/seq2seq_add_attn.py b/seq2seq/seq2seq_add_attn.py
deleted file mode 100644
index ca0a4739f0c26860f00c486f58efbea638cd0ad5..0000000000000000000000000000000000000000
--- a/seq2seq/seq2seq_add_attn.py
+++ /dev/null
@@ -1,293 +0,0 @@
-import numpy as np
-import paddle.fluid as fluid
-import paddle.fluid.layers as layers
-from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm, Embedding, GRUUnit
-
-from text import DynamicDecode, RNN, RNNCell
-from model import Model, Loss
-
-
-class ConvBNPool(fluid.dygraph.Layer):
- def __init__(self,
- out_ch,
- channels,
- act="relu",
- is_test=False,
- pool=True,
- use_cudnn=True):
- super(ConvBNPool, self).__init__()
- self.pool = pool
-
- filter_size = 3
- conv_std_0 = (2.0 / (filter_size**2 * channels[0]))**0.5
- conv_param_0 = fluid.ParamAttr(
- initializer=fluid.initializer.Normal(0.0, conv_std_0))
-
- conv_std_1 = (2.0 / (filter_size**2 * channels[1]))**0.5
- conv_param_1 = fluid.ParamAttr(
- initializer=fluid.initializer.Normal(0.0, conv_std_1))
-
- self.conv_0_layer = Conv2D(
- channels[0],
- out_ch[0],
- 3,
- padding=1,
- param_attr=conv_param_0,
- bias_attr=False,
- act=None,
- use_cudnn=use_cudnn)
- self.bn_0_layer = BatchNorm(out_ch[0], act=act, is_test=is_test)
- self.conv_1_layer = Conv2D(
- out_ch[0],
- num_filters=out_ch[1],
- filter_size=3,
- padding=1,
- param_attr=conv_param_1,
- bias_attr=False,
- act=None,
- use_cudnn=use_cudnn)
- self.bn_1_layer = BatchNorm(out_ch[1], act=act, is_test=is_test)
-
- if self.pool:
- self.pool_layer = Pool2D(
- pool_size=2,
- pool_type='max',
- pool_stride=2,
- use_cudnn=use_cudnn,
- ceil_mode=True)
-
- def forward(self, inputs):
- conv_0 = self.conv_0_layer(inputs)
- bn_0 = self.bn_0_layer(conv_0)
- conv_1 = self.conv_1_layer(bn_0)
- bn_1 = self.bn_1_layer(conv_1)
- if self.pool:
- bn_pool = self.pool_layer(bn_1)
-
- return bn_pool
- return bn_1
-
-
-class OCRConv(fluid.dygraph.Layer):
- def __init__(self, is_test=False, use_cudnn=True):
- super(OCRConv, self).__init__()
- self.conv_bn_pool_1 = ConvBNPool(
- [16, 16], [1, 16], is_test=is_test, use_cudnn=use_cudnn)
- self.conv_bn_pool_2 = ConvBNPool(
- [32, 32], [16, 32], is_test=is_test, use_cudnn=use_cudnn)
- self.conv_bn_pool_3 = ConvBNPool(
- [64, 64], [32, 64], is_test=is_test, use_cudnn=use_cudnn)
- self.conv_bn_pool_4 = ConvBNPool(
- [128, 128], [64, 128],
- is_test=is_test,
- pool=False,
- use_cudnn=use_cudnn)
-
- def forward(self, inputs):
- inputs_1 = self.conv_bn_pool_1(inputs)
- inputs_2 = self.conv_bn_pool_2(inputs_1)
- inputs_3 = self.conv_bn_pool_3(inputs_2)
- inputs_4 = self.conv_bn_pool_4(inputs_3)
-
- return inputs_4
-
-
-class SimpleAttention(fluid.dygraph.Layer):
- def __init__(self, decoder_size):
- super(SimpleAttention, self).__init__()
-
- self.fc1 = Linear(decoder_size, decoder_size, bias_attr=False)
- self.fc2 = Linear(decoder_size, 1, bias_attr=False)
-
- def forward(self, encoder_vec, encoder_proj, decoder_state):
- decoder_state = self.fc1(decoder_state)
- decoder_state = fluid.layers.unsqueeze(decoder_state, [1])
-
- mix = fluid.layers.elementwise_add(encoder_proj, decoder_state)
- mix = fluid.layers.tanh(x=mix)
-
- attn_score = self.fc2(mix)
- attn_scores = layers.squeeze(attn_score, [2])
- attn_scores = fluid.layers.softmax(attn_scores)
-
- scaled = fluid.layers.elementwise_mul(
- x=encoder_vec, y=attn_scores, axis=0)
-
- context = fluid.layers.reduce_sum(scaled, dim=1)
- return context
-
-
-class GRUCell(RNNCell):
- def __init__(self,
- input_size,
- hidden_size,
- param_attr=None,
- bias_attr=None,
- gate_activation='sigmoid',
- candidate_activation='tanh',
- origin_mode=False):
- super(GRUCell, self).__init__()
- self.hidden_size = hidden_size
- self.fc_layer = Linear(
- input_size,
- hidden_size * 3,
- param_attr=param_attr,
- bias_attr=False)
-
- self.gru_unit = GRUUnit(
- hidden_size * 3,
- param_attr=param_attr,
- bias_attr=bias_attr,
- activation=candidate_activation,
- gate_activation=gate_activation,
- origin_mode=origin_mode)
-
- def forward(self, inputs, states):
- # step_outputs, new_states = cell(step_inputs, states)
- # for GRUCell, `step_outputs` and `new_states` both are hidden
- x = self.fc_layer(inputs)
- hidden, _, _ = self.gru_unit(x, states)
- return hidden, hidden
-
- @property
- def state_shape(self):
- return [self.hidden_size]
-
-
-class EncoderNet(fluid.dygraph.Layer):
- def __init__(self,
- decoder_size,
- rnn_hidden_size=200,
- is_test=False,
- use_cudnn=True):
- super(EncoderNet, self).__init__()
- self.rnn_hidden_size = rnn_hidden_size
- para_attr = fluid.ParamAttr(
- initializer=fluid.initializer.Normal(0.0, 0.02))
- bias_attr = fluid.ParamAttr(
- initializer=fluid.initializer.Normal(0.0, 0.02), learning_rate=2.0)
- self.ocr_convs = OCRConv(is_test=is_test, use_cudnn=use_cudnn)
-
- self.gru_forward_layer = RNN(
- cell=GRUCell(
- input_size=128 * 6, # channel * h
- hidden_size=rnn_hidden_size,
- param_attr=para_attr,
- bias_attr=bias_attr,
- candidate_activation='relu'),
- is_reverse=False,
- time_major=False)
- self.gru_backward_layer = RNN(
- cell=GRUCell(
- input_size=128 * 6, # channel * h
- hidden_size=rnn_hidden_size,
- param_attr=para_attr,
- bias_attr=bias_attr,
- candidate_activation='relu'),
- is_reverse=True,
- time_major=False)
-
- self.encoded_proj_fc = Linear(
- rnn_hidden_size * 2, decoder_size, bias_attr=False)
-
- def forward(self, inputs):
- conv_features = self.ocr_convs(inputs)
- transpose_conv_features = fluid.layers.transpose(
- conv_features, perm=[0, 3, 1, 2])
-
- sliced_feature = fluid.layers.reshape(
- transpose_conv_features, [
- -1, transpose_conv_features.shape[1],
- transpose_conv_features.shape[2] *
- transpose_conv_features.shape[3]
- ],
- inplace=False)
-
- gru_forward, _ = self.gru_forward_layer(sliced_feature)
-
- gru_backward, _ = self.gru_backward_layer(sliced_feature)
-
- encoded_vector = fluid.layers.concat(
- input=[gru_forward, gru_backward], axis=2)
-
- encoded_proj = self.encoded_proj_fc(encoded_vector)
-
- return gru_backward, encoded_vector, encoded_proj
-
-
-class DecoderCell(RNNCell):
- def __init__(self, encoder_size, decoder_size):
- super(DecoderCell, self).__init__()
- self.attention = SimpleAttention(decoder_size)
- self.gru_cell = GRUCell(
- input_size=encoder_size * 2 +
- decoder_size, # encoded_vector.shape[-1] + embed_size
- hidden_size=decoder_size)
-
- def forward(self, current_word, states, encoder_vec, encoder_proj):
- context = self.attention(encoder_vec, encoder_proj, states)
- decoder_inputs = layers.concat([current_word, context], axis=1)
- hidden, _ = self.gru_cell(decoder_inputs, states)
- return hidden, hidden
-
-
-class GRUDecoderWithAttention(fluid.dygraph.Layer):
- def __init__(self, encoder_size, decoder_size, num_classes):
- super(GRUDecoderWithAttention, self).__init__()
- self.gru_attention = RNN(DecoderCell(encoder_size, decoder_size),
- is_reverse=False,
- time_major=False)
- self.out_layer = Linear(
- input_dim=decoder_size,
- output_dim=num_classes + 2,
- bias_attr=None,
- act='softmax')
-
- def forward(self, inputs, decoder_initial_states, encoder_vec,
- encoder_proj):
- out, _ = self.gru_attention(
- inputs,
- initial_states=decoder_initial_states,
- encoder_vec=encoder_vec,
- encoder_proj=encoder_proj)
- predict = self.out_layer(out)
- return predict
-
-
-class OCRAttention(Model):
- def __init__(self, num_classes, encoder_size, decoder_size,
- word_vector_dim):
- super(OCRAttention, self).__init__()
- self.encoder_net = EncoderNet(decoder_size)
- self.fc = Linear(
- input_dim=encoder_size,
- output_dim=decoder_size,
- bias_attr=False,
- act='relu')
- self.embedding = Embedding(
- [num_classes + 2, word_vector_dim], dtype='float32')
- self.gru_decoder_with_attention = GRUDecoderWithAttention(
- encoder_size, decoder_size, num_classes)
-
- def forward(self, inputs, label_in):
- gru_backward, encoded_vector, encoded_proj = self.encoder_net(inputs)
-
- decoder_boot = self.fc(gru_backward[:, 0])
- trg_embedding = self.embedding(label_in)
- prediction = self.gru_decoder_with_attention(
- trg_embedding, decoder_boot, encoded_vector, encoded_proj)
-
- return prediction
-
-
-class CrossEntropyCriterion(Loss):
- def __init__(self):
- super(CrossEntropyCriterion, self).__init__()
-
- def forward(self, outputs, labels):
- predict, (label, mask) = outputs[0], labels
-
- loss = layers.cross_entropy(predict, label=label, soft_label=False)
- loss = layers.elementwise_mul(loss, mask, axis=0)
- loss = layers.reduce_sum(loss)
- return loss
diff --git a/seq2seq/train.py b/seq2seq/train.py
index 23dfccb5ad3c5f2527e43874da59b46dde3a51c0..9e809b0000052ef8e482c8a0acf8ca95f955e880 100644
--- a/seq2seq/train.py
+++ b/seq2seq/train.py
@@ -28,7 +28,7 @@ from callbacks import ProgBarLogger
from args import parse_args
from seq2seq_base import BaseModel, CrossEntropyCriterion
from seq2seq_attn import AttentionModel
-from reader import Seq2SeqDataset, Seq2SeqBatchSampler, SortType, prepare_train_input
+from reader import create_data_loader
def do_train(args):
@@ -38,7 +38,6 @@ def do_train(args):
if args.enable_ce:
fluid.default_main_program().random_seed = 102
fluid.default_startup_program().random_seed = 102
- args.shuffle = False
# define model
inputs = [
@@ -54,64 +53,25 @@ def do_train(args):
labels = [Input([None, None, 1], "int64", name="label"), ]
# def dataloader
- data_loaders = [None, None]
- data_prefixes = [args.train_data_prefix, args.eval_data_prefix
- ] if args.eval_data_prefix else [args.train_data_prefix]
- for i, data_prefix in enumerate(data_prefixes):
- dataset = Seq2SeqDataset(
- fpattern=data_prefix + "." + args.src_lang,
- trg_fpattern=data_prefix + "." + args.tar_lang,
- src_vocab_fpath=args.vocab_prefix + "." + args.src_lang,
- trg_vocab_fpath=args.vocab_prefix + "." + args.tar_lang,
- token_delimiter=None,
-            start_mark="<s>",
-            end_mark="</s>",
-            unk_mark="<unk>",
- (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
- unk_id) = dataset.get_vocab_summary()
- batch_sampler = Seq2SeqBatchSampler(
- dataset=dataset,
- use_token_batch=False,
- batch_size=args.batch_size,
- pool_size=args.batch_size * 20,
- sort_type=SortType.POOL,
- shuffle=args.shuffle)
- data_loader = DataLoader(
- dataset=dataset,
- batch_sampler=batch_sampler,
- places=device,
- feed_list=None if fluid.in_dygraph_mode() else
- [x.forward() for x in inputs + labels],
- collate_fn=partial(
- prepare_train_input,
- bos_id=bos_id,
- eos_id=eos_id,
- pad_id=eos_id),
- num_workers=0,
- return_list=True)
- data_loaders[i] = data_loader
- train_loader, eval_loader = data_loaders
+ train_loader, eval_loader = create_data_loader(args, device)
model_maker = AttentionModel if args.attention else BaseModel
model = model_maker(args.src_vocab_size, args.tar_vocab_size,
args.hidden_size, args.hidden_size, args.num_layers,
args.dropout)
-
+ optimizer = fluid.optimizer.Adam(
+ learning_rate=args.learning_rate, parameter_list=model.parameters())
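+    # clip gradients by their global norm (clip_norm=args.max_grad_norm)
+    # via the optimizer's clip attribute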
+ optimizer._grad_clip = fluid.clip.GradientClipByGlobalNorm(
+ clip_norm=args.max_grad_norm)
model.prepare(
- fluid.optimizer.Adam(
- learning_rate=args.learning_rate,
- parameter_list=model.parameters()),
- CrossEntropyCriterion(),
- inputs=inputs,
- labels=labels)
+ optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)
model.fit(train_data=train_loader,
eval_data=eval_loader,
epochs=args.max_epoch,
eval_freq=1,
save_freq=1,
save_dir=args.model_path,
- log_freq=1,
- verbose=2)
+ log_freq=1)
if __name__ == "__main__":
diff --git a/seq2seq/train_ocr.py b/seq2seq/train_ocr.py
deleted file mode 100644
index 2dd7835b2258256f17bfa8ece70875531a5e8c39..0000000000000000000000000000000000000000
--- a/seq2seq/train_ocr.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from __future__ import print_function
-
-import os
-import sys
-sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-
-import paddle.fluid.profiler as profiler
-import paddle.fluid as fluid
-
-import data_reader
-
-from paddle.fluid.dygraph.base import to_variable
-import argparse
-import functools
-from utility import add_arguments, print_arguments, get_attention_feeder_data
-from model import Input, set_device
-from nets import OCRAttention, CrossEntropyCriterion
-from eval import evaluate
-
-parser = argparse.ArgumentParser(description=__doc__)
-add_arg = functools.partial(add_arguments, argparser=parser)
-# yapf: disable
-add_arg('batch_size', int, 32, "Minibatch size.")
-add_arg('epoch_num', int, 30, "Epoch number.")
-add_arg('lr', float, 0.001, "Learning rate.")
-add_arg('lr_decay_strategy', str, "", "Learning rate decay strategy.")
-add_arg('log_period', int, 200, "Log period.")
-add_arg('save_model_period', int, 2000, "Save model period. '-1' means never saving the model.")
-add_arg('eval_period', int, 2000, "Evaluate period. '-1' means never evaluating the model.")
-add_arg('save_model_dir', str, "./output", "The directory the model to be saved to.")
-add_arg('train_images', str, None, "The directory of images to be used for training.")
-add_arg('train_list', str, None, "The list file of images to be used for training.")
-add_arg('test_images', str, None, "The directory of images to be used for test.")
-add_arg('test_list', str, None, "The list file of images to be used for training.")
-add_arg('init_model', str, None, "The init model file of directory.")
-add_arg('use_gpu', bool, True, "Whether use GPU to train.")
-add_arg('parallel', bool, False, "Whether use parallel training.")
-add_arg('profile', bool, False, "Whether to use profiling.")
-add_arg('skip_batch_num', int, 0, "The number of first minibatches to skip as warm-up for better performance test.")
-add_arg('skip_test', bool, False, "Whether to skip test phase.")
-# model hyper paramters
-add_arg('encoder_size', int, 200, "Encoder size.")
-add_arg('decoder_size', int, 128, "Decoder size.")
-add_arg('word_vector_dim', int, 128, "Word vector dim.")
-add_arg('num_classes', int, 95, "Number classes.")
-add_arg('gradient_clip', float, 5.0, "Gradient clip value.")
-add_arg('dynamic', bool, False, "Whether to use dygraph.")
-
-
-def train(args):
- device = set_device("gpu" if args.use_gpu else "cpu")
- fluid.enable_dygraph(device) if args.dynamic else None
-
- ocr_attention = OCRAttention(encoder_size=args.encoder_size, decoder_size=args.decoder_size,
- num_classes=args.num_classes, word_vector_dim=args.word_vector_dim)
- LR = args.lr
- if args.lr_decay_strategy == "piecewise_decay":
- learning_rate = fluid.layers.piecewise_decay([200000, 250000], [LR, LR * 0.1, LR * 0.01])
- else:
- learning_rate = LR
- optimizer = fluid.optimizer.Adam(learning_rate=learning_rate, parameter_list=ocr_attention.parameters())
- # grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(args.gradient_clip)
-
- inputs = [
- Input([None, 1, 48, 384], "float32", name="pixel"),
- Input([None, None], "int64", name="label_in"),
- ]
- labels = [
- Input([None, None], "int64", name="label_out"),
- Input([None, None], "float32", name="mask")]
-
- ocr_attention.prepare(optimizer, CrossEntropyCriterion(), inputs=inputs, labels=labels)
-
-
- train_reader = data_reader.data_reader(
- args.batch_size,
- shuffle=True,
- images_dir=args.train_images,
- list_file=args.train_list,
- data_type='train')
-
- # test_reader = data_reader.data_reader(
- # args.batch_size,
- # images_dir=args.test_images,
- # list_file=args.test_list,
- # data_type="test")
-
- # if not os.path.exists(args.save_model_dir):
- # os.makedirs(args.save_model_dir)
- total_step = 0
- epoch_num = args.epoch_num
- for epoch in range(epoch_num):
- batch_id = 0
- total_loss = 0.0
-
- for data in train_reader():
-
- total_step += 1
- data_dict = get_attention_feeder_data(data)
- pixel = data_dict["pixel"]
- label_in = data_dict["label_in"].reshape([pixel.shape[0], -1])
- label_out = data_dict["label_out"].reshape([pixel.shape[0], -1])
- mask = data_dict["mask"].reshape(label_out.shape).astype("float32")
-
- avg_loss = ocr_attention.train(inputs=[pixel, label_in], labels=[label_out, mask])[0]
- total_loss += avg_loss
-
- if True:#batch_id > 0 and batch_id % args.log_period == 0:
- print("epoch: {}, batch_id: {}, loss {}".format(epoch, batch_id,
- total_loss / args.batch_size / args.log_period))
- total_loss = 0.0
-
- batch_id += 1
-
-
-if __name__ == '__main__':
- args = parser.parse_args()
- print_arguments(args)
- if args.profile:
- if args.use_gpu:
- with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
- train(args)
- else:
- with profiler.profiler("CPU", sorted_key='total') as cpuprof:
- train(args)
- else:
- train(args)
\ No newline at end of file
diff --git a/transformer/reader.py b/transformer/reader.py
index 66fb8dc02b99f345f337d8a91b6c7eeaff71fe18..8b2d8fa028aff2170a3d3f8bb43ff7c3a93abf8e 100644
--- a/transformer/reader.py
+++ b/transformer/reader.py
@@ -289,7 +289,6 @@ class Seq2SeqDataset(Dataset):
                 start_mark="<s>",
                 end_mark="</s>",
                 unk_mark="<unk>",
- only_src=False,
trg_fpattern=None,
byte_data=False):
if byte_data: