From 21f50136c3d2361f0657388770b628a2a2c6b2e2 Mon Sep 17 00:00:00 2001 From: guosheng Date: Tue, 31 Mar 2020 03:20:14 +0800 Subject: [PATCH] Add seq2seq --- rnn_api.py | 94 ++++++++---- seq2seq/args.py | 138 +++++++++++++++++ seq2seq/configure.py | 350 +++++++++++++++++++++++++++++++++++++++++++ seq2seq/reader.py | 334 +++++++++++++++++++++++++++++++++++++++++ seq2seq/seq2seq.py | 257 +++++++++++++++++++++++++++++++ seq2seq/seq2seq.yaml | 83 ++++++++++ seq2seq/train.py | 135 +++++++++++++++++ 7 files changed, 1358 insertions(+), 33 deletions(-) create mode 100644 seq2seq/args.py create mode 100644 seq2seq/configure.py create mode 100644 seq2seq/reader.py create mode 100644 seq2seq/seq2seq.py create mode 100644 seq2seq/seq2seq.yaml create mode 100644 seq2seq/train.py diff --git a/rnn_api.py b/rnn_api.py index c947a40..9ccae3e 100644 --- a/rnn_api.py +++ b/rnn_api.py @@ -1,4 +1,5 @@ import collections +import copy import six import sys from functools import partial, reduce @@ -172,15 +173,15 @@ class BasicLSTMCell(RNNCell): """ def __init__(self, - hidden_size, input_size, + hidden_size, param_attr=None, bias_attr=None, gate_activation=None, activation=None, forget_bias=1.0, dtype='float32'): - super(BasicLSTMCell, self).__init__(dtype) + super(BasicLSTMCell, self).__init__() self._hidden_size = hidden_size self._param_attr = param_attr @@ -227,7 +228,7 @@ class BasicLSTMCell(RNNCell): return [[self._hidden_size], [self._hidden_size]] -class BasicGRUUnit(RNNCell): +class BasicGRUCell(RNNCell): """ **** BasicGRUUnit class, using basic operators to build GRU @@ -262,57 +263,84 @@ class BasicGRUUnit(RNNCell): """ def __init__(self, + input_size, hidden_size, param_attr=None, bias_attr=None, gate_activation=None, activation=None, dtype='float32'): - super(BasicGRUUnit, self).__init__(dtype) - - self._hidden_size = hidden_size + super(BasicGRUCell, self).__init__() + self._input_size = input_size + self._hiden_size = hidden_size self._param_attr = param_attr self._bias_attr = bias_attr self._gate_activation = gate_activation or layers.sigmoid self._activation = activation or layers.tanh - self._forget_bias = layers.fill_constant( - [1], dtype=dtype, value=forget_bias) - self._forget_bias.stop_gradient = False self._dtype = dtype - self._input_size = input_size - self._weight = self.create_parameter( - attr=self._param_attr, - shape=[ - self._input_size + self._hidden_size, 4 * self._hidden_size - ], + if self._param_attr is not None and self._param_attr.name is not None: + gate_param_attr = copy.deepcopy(self._param_attr) + candidate_param_attr = copy.deepcopy(self._param_attr) + gate_param_attr.name += "_gate" + candidate_param_attr.name += "_candidate" + else: + gate_param_attr = self._param_attr + candidate_param_attr = self._param_attr + + self._gate_weight = self.create_parameter( + attr=gate_param_attr, + shape=[self._input_size + self._hiden_size, 2 * self._hiden_size], dtype=self._dtype) - self._bias = self.create_parameter( - attr=self._bias_attr, - shape=[4 * self._hidden_size], - dtype=self._dtype, - is_bias=True) + self._candidate_weight = self.create_parameter( + attr=candidate_param_attr, + shape=[self._input_size + self._hiden_size, self._hiden_size], + dtype=self._dtype) + + if self._bias_attr is not None and self._bias_attr.name is not None: + gate_bias_attr = copy.deepcopy(self._bias_attr) + candidate_bias_attr = copy.deepcopy(self._bias_attr) + gate_bias_attr.name += "_gate" + candidate_bias_attr.name += "_candidate" + else: + gate_bias_attr = self._bias_attr + candidate_bias_attr = self._bias_attr + + self._gate_bias = self.create_parameter(attr=gate_bias_attr, + shape=[2 * self._hiden_size], + dtype=self._dtype, + is_bias=True) + self._candidate_bias = self.create_parameter(attr=candidate_bias_attr, + shape=[self._hiden_size], + dtype=self._dtype, + is_bias=True) def forward(self, input, state): - pre_hidden, pre_cell = state - concat_input_hidden = layers.concat([input, pre_hidden], 1) - gate_input = layers.matmul(x=concat_input_hidden, y=self._weight) + pre_hidden = state + concat_input_hidden = layers.concat([input, pre_hidden], axis=1) - gate_input = layers.elementwise_add(gate_input, self._bias) - i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1) - new_cell = layers.elementwise_add( - layers.elementwise_mul( - pre_cell, - layers.sigmoid(layers.elementwise_add(f, self._forget_bias))), - layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j))) - new_hidden = layers.tanh(new_cell) * layers.sigmoid(o) + gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight) - return new_hidden, [new_hidden, new_cell] + gate_input = layers.elementwise_add(gate_input, self._gate_bias) + + gate_input = self._gate_activation(gate_input) + r, u = layers.split(gate_input, num_or_sections=2, dim=1) + + r_hidden = r * pre_hidden + + candidate = layers.matmul(layers.concat([input, r_hidden], 1), + self._candidate_weight) + candidate = layers.elementwise_add(candidate, self._candidate_bias) + + c = self._activation(candidate) + new_hidden = u * pre_hidden + (1 - u) * c + + return new_hidden @property def state_shape(self): - return [[self._hidden_size], [self._hidden_size]] + return [self._hidden_size] class RNN(fluid.dygraph.Layer): diff --git a/seq2seq/args.py b/seq2seq/args.py new file mode 100644 index 0000000..928bb1e --- /dev/null +++ b/seq2seq/args.py @@ -0,0 +1,138 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import distutils.util + + +def parse_args(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--train_data_prefix", + type=str, + help="file prefix for train data") + parser.add_argument("--eval_data_prefix", + type=str, + help="file prefix for eval data") + parser.add_argument("--test_data_prefix", + type=str, + help="file prefix for test data") + parser.add_argument("--vocab_prefix", + type=str, + help="file prefix for vocab") + parser.add_argument("--src_lang", type=str, help="source language suffix") + parser.add_argument("--tar_lang", type=str, help="target language suffix") + + parser.add_argument("--attention", + type=eval, + default=False, + help="Whether use attention model") + + parser.add_argument("--optimizer", + type=str, + default='adam', + help="optimizer to use, only supprt[sgd|adam]") + + parser.add_argument("--learning_rate", + type=float, + default=0.001, + help="learning rate for optimizer") + + parser.add_argument("--num_layers", + type=int, + default=1, + help="layers number of encoder and decoder") + parser.add_argument("--hidden_size", + type=int, + default=100, + help="hidden size of encoder and decoder") + parser.add_argument("--src_vocab_size", type=int, help="source vocab size") + parser.add_argument("--tar_vocab_size", type=int, help="target vocab size") + + parser.add_argument("--batch_size", + type=int, + help="batch size of each step") + + parser.add_argument("--max_epoch", + type=int, + default=12, + help="max epoch for the training") + + parser.add_argument("--max_len", + type=int, + default=50, + help="max length for source and target sentence") + parser.add_argument("--dropout", + type=float, + default=0.0, + help="drop probability") + parser.add_argument("--init_scale", + type=float, + default=0.0, + help="init scale for parameter") + parser.add_argument("--max_grad_norm", + type=float, + default=5.0, + help="max grad norm for global norm clip") + + parser.add_argument("--model_path", + type=str, + default='model', + help="model path for model to save") + + parser.add_argument("--reload_model", + type=str, + help="reload model to inference") + + parser.add_argument("--infer_file", + type=str, + help="file name for inference") + parser.add_argument("--infer_output_file", + type=str, + default='infer_output', + help="file name for inference output") + parser.add_argument("--beam_size", + type=int, + default=10, + help="file name for inference") + + parser.add_argument('--use_gpu', + type=eval, + default=False, + help='Whether using gpu [True|False]') + + parser.add_argument('--eager_run', + type=eval, + default=False, + help='Whether to use dygraph') + + parser.add_argument("--enable_ce", + action='store_true', + help="The flag indicating whether to run the task " + "for continuous evaluation.") + + parser.add_argument("--profile", + action='store_true', + help="Whether enable the profile.") + # NOTE: profiler args, used for benchmark + parser.add_argument( + "--profiler_path", + type=str, + default='./seq2seq.profile', + help="the profiler output file path. (used for benchmark)") + args = parser.parse_args() + return args diff --git a/seq2seq/configure.py b/seq2seq/configure.py new file mode 100644 index 0000000..67e6012 --- /dev/null +++ b/seq2seq/configure.py @@ -0,0 +1,350 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import sys +import argparse +import json +import yaml +import six +import logging + +logging_only_message = "%(message)s" +logging_details = "%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s" + + +class JsonConfig(object): + """ + A high-level api for handling json configure file. + """ + + def __init__(self, config_path): + self._config_dict = self._parse(config_path) + + def _parse(self, config_path): + try: + with open(config_path) as json_file: + config_dict = json.load(json_file) + except: + raise IOError("Error in parsing bert model config file '%s'" % + config_path) + else: + return config_dict + + def __getitem__(self, key): + return self._config_dict[key] + + def print_config(self): + for arg, value in sorted(six.iteritems(self._config_dict)): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +class ArgumentGroup(object): + def __init__(self, parser, title, des): + self._group = parser.add_argument_group(title=title, description=des) + + def add_arg(self, name, type, default, help, **kwargs): + type = str2bool if type == bool else type + self._group.add_argument( + "--" + name, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) + + +class ArgConfig(object): + """ + A high-level api for handling argument configs. + """ + + def __init__(self): + parser = argparse.ArgumentParser() + + train_g = ArgumentGroup(parser, "training", "training options.") + train_g.add_arg("epoch", int, 3, "Number of epoches for fine-tuning.") + train_g.add_arg("learning_rate", float, 5e-5, + "Learning rate used to train with warmup.") + train_g.add_arg( + "lr_scheduler", + str, + "linear_warmup_decay", + "scheduler of learning rate.", + choices=['linear_warmup_decay', 'noam_decay']) + train_g.add_arg("weight_decay", float, 0.01, + "Weight decay rate for L2 regularizer.") + train_g.add_arg( + "warmup_proportion", float, 0.1, + "Proportion of training steps to perform linear learning rate warmup for." + ) + train_g.add_arg("save_steps", int, 1000, + "The steps interval to save checkpoints.") + train_g.add_arg("use_fp16", bool, False, + "Whether to use fp16 mixed precision training.") + train_g.add_arg( + "loss_scaling", float, 1.0, + "Loss scaling factor for mixed precision training, only valid when use_fp16 is enabled." + ) + train_g.add_arg("pred_dir", str, None, + "Path to save the prediction results") + + log_g = ArgumentGroup(parser, "logging", "logging related.") + log_g.add_arg("skip_steps", int, 10, + "The steps interval to print loss.") + log_g.add_arg("verbose", bool, False, "Whether to output verbose log.") + + run_type_g = ArgumentGroup(parser, "run_type", "running type options.") + run_type_g.add_arg("use_cuda", bool, True, + "If set, use GPU for training.") + run_type_g.add_arg( + "use_fast_executor", bool, False, + "If set, use fast parallel executor (in experiment).") + run_type_g.add_arg( + "num_iteration_per_drop_scope", int, 1, + "Ihe iteration intervals to clean up temporary variables.") + run_type_g.add_arg("do_train", bool, True, + "Whether to perform training.") + run_type_g.add_arg("do_predict", bool, True, + "Whether to perform prediction.") + + custom_g = ArgumentGroup(parser, "customize", "customized options.") + + self.custom_g = custom_g + + self.parser = parser + + def add_arg(self, name, dtype, default, descrip): + self.custom_g.add_arg(name, dtype, default, descrip) + + def build_conf(self): + return self.parser.parse_args() + + +def str2bool(v): + # because argparse does not support to parse "true, False" as python + # boolean directly + return v.lower() in ("true", "t", "1") + + +def print_arguments(args, log=None): + if not log: + print('----------- Configuration Arguments -----------') + for arg, value in sorted(six.iteritems(vars(args))): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + else: + log.info('----------- Configuration Arguments -----------') + for arg, value in sorted(six.iteritems(vars(args))): + log.info('%s: %s' % (arg, value)) + log.info('------------------------------------------------') + + +class PDConfig(object): + """ + A high-level API for managing configuration files in PaddlePaddle. + Can jointly work with command-line-arugment, json files and yaml files. + """ + + def __init__(self, json_file="", yaml_file="", fuse_args=True): + """ + Init funciton for PDConfig. + json_file: the path to the json configure file. + yaml_file: the path to the yaml configure file. + fuse_args: if fuse the json/yaml configs with argparse. + """ + assert isinstance(json_file, str) + assert isinstance(yaml_file, str) + + if json_file != "" and yaml_file != "": + raise Warning( + "json_file and yaml_file can not co-exist for now. please only use one configure file type." + ) + return + + self.args = None + self.arg_config = {} + self.json_config = {} + self.yaml_config = {} + + parser = argparse.ArgumentParser() + + self.default_g = ArgumentGroup(parser, "default", "default options.") + self.yaml_g = ArgumentGroup(parser, "yaml", "options from yaml.") + self.json_g = ArgumentGroup(parser, "json", "options from json.") + self.com_g = ArgumentGroup(parser, "custom", "customized options.") + + self.default_g.add_arg("do_train", bool, False, + "Whether to perform training.") + self.default_g.add_arg("do_predict", bool, False, + "Whether to perform predicting.") + self.default_g.add_arg("do_eval", bool, False, + "Whether to perform evaluating.") + self.default_g.add_arg("do_save_inference_model", bool, False, + "Whether to perform model saving for inference.") + + # NOTE: args for profiler + self.default_g.add_arg("is_profiler", int, 0, "the switch of profiler tools. (used for benchmark)") + self.default_g.add_arg("profiler_path", str, './', "the profiler output file path. (used for benchmark)") + self.default_g.add_arg("max_iter", int, 0, "the max train batch num.(used for benchmark)") + + self.parser = parser + + if json_file != "": + self.load_json(json_file, fuse_args=fuse_args) + + if yaml_file: + self.load_yaml(yaml_file, fuse_args=fuse_args) + + def load_json(self, file_path, fuse_args=True): + + if not os.path.exists(file_path): + raise Warning("the json file %s does not exist." % file_path) + return + + with open(file_path, "r") as fin: + self.json_config = json.loads(fin.read()) + fin.close() + + if fuse_args: + for name in self.json_config: + if isinstance(self.json_config[name], list): + self.json_g.add_arg( + name, + type(self.json_config[name][0]), + self.json_config[name], + "This is from %s" % file_path, + nargs=len(self.json_config[name])) + continue + if not isinstance(self.json_config[name], int) \ + and not isinstance(self.json_config[name], float) \ + and not isinstance(self.json_config[name], str) \ + and not isinstance(self.json_config[name], bool): + + continue + + self.json_g.add_arg(name, + type(self.json_config[name]), + self.json_config[name], + "This is from %s" % file_path) + + def load_yaml(self, file_path, fuse_args=True): + + if not os.path.exists(file_path): + raise Warning("the yaml file %s does not exist." % file_path) + return + + with open(file_path, "r") as fin: + self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader) + fin.close() + + if fuse_args: + for name in self.yaml_config: + if isinstance(self.yaml_config[name], list): + self.yaml_g.add_arg( + name, + type(self.yaml_config[name][0]), + self.yaml_config[name], + "This is from %s" % file_path, + nargs=len(self.yaml_config[name])) + continue + + if not isinstance(self.yaml_config[name], int) \ + and not isinstance(self.yaml_config[name], float) \ + and not isinstance(self.yaml_config[name], str) \ + and not isinstance(self.yaml_config[name], bool): + + continue + + self.yaml_g.add_arg(name, + type(self.yaml_config[name]), + self.yaml_config[name], + "This is from %s" % file_path) + + def build(self): + self.args = self.parser.parse_args() + self.arg_config = vars(self.args) + + def __add__(self, new_arg): + assert isinstance(new_arg, list) or isinstance(new_arg, tuple) + assert len(new_arg) >= 3 + assert self.args is None + + name = new_arg[0] + dtype = new_arg[1] + dvalue = new_arg[2] + desc = new_arg[3] if len( + new_arg) == 4 else "Description is not provided." + + self.com_g.add_arg(name, dtype, dvalue, desc) + + return self + + def __getattr__(self, name): + if name in self.arg_config: + return self.arg_config[name] + + if name in self.json_config: + return self.json_config[name] + + if name in self.yaml_config: + return self.yaml_config[name] + + raise Warning("The argument %s is not defined." % name) + + def Print(self): + + print("-" * 70) + for name in self.arg_config: + print("%s:\t\t\t\t%s" % (str(name), str(self.arg_config[name]))) + + for name in self.json_config: + if name not in self.arg_config: + print("%s:\t\t\t\t%s" % + (str(name), str(self.json_config[name]))) + + for name in self.yaml_config: + if name not in self.arg_config: + print("%s:\t\t\t\t%s" % + (str(name), str(self.yaml_config[name]))) + + print("-" * 70) + + +if __name__ == "__main__": + """ + pd_config = PDConfig(json_file = "./test/bert_config.json") + pd_config.build() + + print(pd_config.do_train) + print(pd_config.hidden_size) + + pd_config = PDConfig(yaml_file = "./test/bert_config.yaml") + pd_config.build() + + print(pd_config.do_train) + print(pd_config.hidden_size) + """ + + pd_config = PDConfig(yaml_file="./test/bert_config.yaml") + pd_config += ("my_age", int, 18, "I am forever 18.") + pd_config.build() + + print(pd_config.do_train) + print(pd_config.hidden_size) + print(pd_config.my_age) diff --git a/seq2seq/reader.py b/seq2seq/reader.py new file mode 100644 index 0000000..3bdbe05 --- /dev/null +++ b/seq2seq/reader.py @@ -0,0 +1,334 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import six +import os +import tarfile +import itertools + +import numpy as np +import paddle.fluid as fluid +from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.fluid.io import BatchSampler, DataLoader, Dataset + + +def prepare_train_input(insts, src_pad_idx, trg_pad_idx): + """ + Put all padded data needed by training into a list. + """ + src, src_length = pad_batch_data([inst[0] for inst in insts], src_pad_idx) + trg, trg_length = pad_batch_data([inst[1] for inst in insts], trg_pad_idx) + label, _ = pad_batch_data([inst[2] for inst in insts], trg_pad_idx) + return src, src_length, trg, trg_length, np.expand_dims(label, -1) + + +def pad_batch_data(insts, pad_idx): + """ + Pad the instances to the max sequence length in batch, and generate the + corresponding position data and attention bias. + """ + inst_length = np.array([len(inst) for inst in insts], dtype="int64") + max_len = np.max(inst_length) + inst_data = np.array( + [inst + [pad_idx] * (max_len - len(inst)) for inst in insts]) + return inst_data, inst_length + + +class SortType(object): + GLOBAL = 'global' + POOL = 'pool' + NONE = "none" + + +class Converter(object): + def __init__(self, vocab, beg, end, unk, delimiter, add_beg): + self._vocab = vocab + self._beg = beg + self._end = end + self._unk = unk + self._delimiter = delimiter + self._add_beg = add_beg + + def __call__(self, sentence): + return ([self._beg] if self._add_beg else []) + [ + self._vocab.get(w, self._unk) + for w in sentence.split(self._delimiter) + ] + [self._end] + + +class ComposedConverter(object): + def __init__(self, converters): + self._converters = converters + + def __call__(self, parallel_sentence): + return [ + self._converters[i](parallel_sentence[i]) + for i in range(len(self._converters)) + ] + + +class SentenceBatchCreator(object): + def __init__(self, batch_size): + self.batch = [] + self._batch_size = batch_size + + def append(self, info): + self.batch.append(info) + if len(self.batch) == self._batch_size: + tmp = self.batch + self.batch = [] + return tmp + + +class TokenBatchCreator(object): + def __init__(self, batch_size): + self.batch = [] + self.max_len = -1 + self._batch_size = batch_size + + def append(self, info): + cur_len = info.max_len + max_len = max(self.max_len, cur_len) + if max_len * (len(self.batch) + 1) > self._batch_size: + result = self.batch + self.batch = [info] + self.max_len = cur_len + return result + else: + self.max_len = max_len + self.batch.append(info) + + +class SampleInfo(object): + def __init__(self, i, max_len, min_len): + self.i = i + self.min_len = min_len + self.max_len = max_len + + +class MinMaxFilter(object): + def __init__(self, max_len, min_len, underlying_creator): + self._min_len = min_len + self._max_len = max_len + self._creator = underlying_creator + + def append(self, info): + if info.max_len > self._max_len or info.min_len < self._min_len: + return + else: + return self._creator.append(info) + + @property + def batch(self): + return self._creator.batch + + +class Seq2SeqDataset(Dataset): + def __init__(self, + src_vocab_fpath, + trg_vocab_fpath, + fpattern, + tar_fname=None, + field_delimiter="\t", + token_delimiter=" ", + start_mark="", + end_mark="", + unk_mark="", + only_src=False): + # convert str to bytes, and use byte data + field_delimiter = field_delimiter.encode("utf8") + token_delimiter = token_delimiter.encode("utf8") + start_mark = start_mark.encode("utf8") + end_mark = end_mark.encode("utf8") + unk_mark = unk_mark.encode("utf8") + self._src_vocab = self.load_dict(src_vocab_fpath) + self._trg_vocab = self.load_dict(trg_vocab_fpath) + self._bos_idx = self._src_vocab[start_mark] + self._eos_idx = self._src_vocab[end_mark] + self._unk_idx = self._src_vocab[unk_mark] + self._only_src = only_src + self._field_delimiter = field_delimiter + self._token_delimiter = token_delimiter + self.load_src_trg_ids(fpattern, tar_fname) + + def load_src_trg_ids(self, fpattern, tar_fname): + converters = [ + Converter(vocab=self._src_vocab, + beg=self._bos_idx, + end=self._eos_idx, + unk=self._unk_idx, + delimiter=self._token_delimiter, + add_beg=False) + ] + if not self._only_src: + converters.append( + Converter(vocab=self._trg_vocab, + beg=self._bos_idx, + end=self._eos_idx, + unk=self._unk_idx, + delimiter=self._token_delimiter, + add_beg=True)) + + converters = ComposedConverter(converters) + + self._src_seq_ids = [] + self._trg_seq_ids = None if self._only_src else [] + self._sample_infos = [] + + for i, line in enumerate(self._load_lines(fpattern, tar_fname)): + src_trg_ids = converters(line) + self._src_seq_ids.append(src_trg_ids[0]) + lens = [len(src_trg_ids[0])] + if not self._only_src: + self._trg_seq_ids.append(src_trg_ids[1]) + lens.append(len(src_trg_ids[1])) + self._sample_infos.append(SampleInfo(i, max(lens), min(lens))) + + def _load_lines(self, fpattern, tar_fname): + fpaths = glob.glob(fpattern) + assert len(fpaths) > 0, "no matching file to the provided data path" + + if len(fpaths) == 1 and tarfile.is_tarfile(fpaths[0]): + if tar_fname is None: + raise Exception("If tar file provided, please set tar_fname.") + + f = tarfile.open(fpaths[0], "rb") + for line in f.extractfile(tar_fname): + fields = line.strip(b"\n").split(self._field_delimiter) + if (not self._only_src + and len(fields) == 2) or (self._only_src + and len(fields) == 1): + yield fields + else: + for fpath in fpaths: + if not os.path.isfile(fpath): + raise IOError("Invalid file: %s" % fpath) + + with open(fpath, "rb") as f: + for line in f: + fields = line.strip(b"\n").split(self._field_delimiter) + if (not self._only_src and len(fields) == 2) or ( + self._only_src and len(fields) == 1): + yield fields + + @staticmethod + def load_dict(dict_path, reverse=False): + word_dict = {} + with open(dict_path, "rb") as fdict: + for idx, line in enumerate(fdict): + if reverse: + word_dict[idx] = line.strip(b"\n") + else: + word_dict[line.strip(b"\n")] = idx + return word_dict + + def get_vocab_summary(self): + return len(self._src_vocab), len( + self._trg_vocab), self._bos_idx, self._eos_idx, self._unk_idx + + def __getitem__(self, idx): + return (self._src_seq_ids[idx], self._trg_seq_ids[idx][:-1], + self._trg_seq_ids[idx][1:] + ) if not self._only_src else self._src_seq_ids[idx] + + def __len__(self): + return len(self._sample_infos) + + +class Seq2SeqBatchSampler(BatchSampler): + def __init__(self, + dataset, + batch_size, + pool_size, + sort_type=SortType.GLOBAL, + min_length=0, + max_length=100, + shuffle=True, + shuffle_batch=False, + use_token_batch=False, + clip_last_batch=False, + seed=0): + for arg, value in locals().items(): + if arg != "self": + setattr(self, "_" + arg, value) + self._random = np.random + self._random.seed(seed) + # for multi-devices + self._nranks = ParallelEnv().nranks + self._local_rank = ParallelEnv().local_rank + self._device_id = ParallelEnv().dev_id + + def __iter__(self): + # global sort or global shuffle + if self._sort_type == SortType.GLOBAL: + infos = sorted(self._dataset._sample_infos, + key=lambda x: x.max_len) + else: + if self._shuffle: + infos = self._dataset._sample_infos + self._random.shuffle(infos) + else: + infos = self._dataset._sample_infos + + if self._sort_type == SortType.POOL: + reverse = True + for i in range(0, len(infos), self._pool_size): + # to avoid placing short next to long sentences + reverse = not reverse + infos[i:i + self._pool_size] = sorted( + infos[i:i + self._pool_size], + key=lambda x: x.max_len, + reverse=reverse) + + batches = [] + batch_creator = TokenBatchCreator( + self._batch_size + ) if self._use_token_batch else SentenceBatchCreator(self._batch_size * + self._nranks) + batch_creator = MinMaxFilter(self._max_length, self._min_length, + batch_creator) + + for info in infos: + batch = batch_creator.append(info) + if batch is not None: + batches.append(batch) + + if not self._clip_last_batch and len(batch_creator.batch) != 0: + batches.append(batch_creator.batch) + + if self._shuffle_batch: + self._random.shuffle(batches) + + if not self._use_token_batch: + # when producing batches according to sequence number, to confirm + # neighbor batches which would be feed and run parallel have similar + # length (thus similar computational cost) after shuffle, we as take + # them as a whole when shuffling and split here + batches = [[ + batch[self._batch_size * i:self._batch_size * (i + 1)] + for i in range(self._nranks) + ] for batch in batches] + batches = itertools.chain.from_iterable(batches) + + # for multi-device + for batch_id, batch in enumerate(batches): + if batch_id % self._nranks == self._local_rank: + batch_indices = [info.i for info in batch] + yield batch_indices + if self._local_rank > len(batches) % self._nranks: + yield batch_indices + + def __len__(self): + return 100 diff --git a/seq2seq/seq2seq.py b/seq2seq/seq2seq.py new file mode 100644 index 0000000..74431b9 --- /dev/null +++ b/seq2seq/seq2seq.py @@ -0,0 +1,257 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +from paddle.fluid import ParamAttr +from paddle.fluid.initializer import UniformInitializer +from paddle.fluid.dygraph import Embedding, Linear, Layer +from rnn_api import DynamicDecode, RNN, BasicLSTMCell, RNNCell +from model import Model, Loss + + +class CrossEntropyCriterion(Loss): + def __init__(self): + super(CrossEntropyCriterion, self).__init__() + + def forward(self, outputs, labels): + (predict, mask), label = outputs, labels[0] + + cost = layers.softmax_with_cross_entropy(logits=predict, + label=label, + soft_label=False) + masked_cost = layers.elementwise_mul(cost, mask, axis=0) + batch_mean_cost = layers.reduce_mean(masked_cost, dim=[0]) + seq_cost = layers.reduce_sum(batch_mean_cost) + return seq_cost + + +class EncoderCell(RNNCell): + def __init__(self, + num_layers, + input_size, + hidden_size, + dropout_prob=0., + init_scale=0.1): + super(EncoderCell, self).__init__() + self.dropout_prob = dropout_prob + # use add_sublayer to add multi-layers + self.lstm_cells = [] + for i in range(num_layers): + self.lstm_cells.append( + self.add_sublayer( + "lstm_%d" % i, + BasicLSTMCell( + input_size=input_size if i == 0 else hidden_size, + hidden_size=hidden_size, + param_attr=ParamAttr(initializer=UniformInitializer( + low=-init_scale, high=init_scale))))) + + def forward(self, step_input, states): + new_states = [] + for i, lstm_cell in enumerate(self.lstm_cells): + out, new_state = lstm_cell(step_input, states[i]) + step_input = layers.dropout( + out, self.dropout_prob) if self.dropout_prob > 0 else out + new_states.append(new_state) + return step_input, new_states + + @property + def state_shape(self): + return [cell.state_shape for cell in self.lstm_cells] + + +class Encoder(Layer): + def __init__(self, + vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob=0., + init_scale=0.1): + super(Encoder, self).__init__() + self.embedder = Embedding( + size=[vocab_size, embed_dim], + param_attr=ParamAttr(initializer=UniformInitializer( + low=-init_scale, high=init_scale))) + self.stack_lstm = RNN(EncoderCell(num_layers, embed_dim, hidden_size, + init_scale), + is_reverse=False, + time_major=False) + + def forward(self, sequence, sequence_length): + inputs = self.embedder(sequence) + encoder_output, encoder_state = self.stack_lstm( + inputs, sequence_length=sequence_length) + return encoder_output, encoder_state + + +class AttentionLayer(Layer): + def __init__(self, hidden_size, bias=False, init_scale=0.1): + super(AttentionLayer, self).__init__() + self.input_proj = Linear( + hidden_size, + hidden_size, + param_attr=ParamAttr(initializer=UniformInitializer( + low=-init_scale, high=init_scale)), + bias_attr=bias) + self.output_proj = Linear( + hidden_size + hidden_size, + hidden_size, + param_attr=ParamAttr(initializer=UniformInitializer( + low=-init_scale, high=init_scale)), + bias_attr=bias) + + def forward(self, hidden, encoder_output, encoder_padding_mask): + query = self.input_proj(hidden) + attn_scores = layers.matmul(layers.unsqueeze(query, [1]), + encoder_output, + transpose_y=True) + if encoder_padding_mask is not None: + attn_scores = layers.elementwise_add(attn_scores, + encoder_padding_mask) + attn_scores = layers.softmax(attn_scores) + attn_out = layers.squeeze(layers.matmul(attn_scores, encoder_output), + [1]) + attn_out = layers.concat([attn_out, hidden], 1) + attn_out = self.output_proj(attn_out) + return attn_out + + +class DecoderCell(RNNCell): + def __init__(self, + num_layers, + input_size, + hidden_size, + dropout_prob=0., + init_scale=0.1): + super(DecoderCell, self).__init__() + self.dropout_prob = dropout_prob + # use add_sublayer to add multi-layers + self.lstm_cells = [] + for i in range(num_layers): + self.lstm_cells.append( + self.add_sublayer( + "lstm_%d" % i, + BasicLSTMCell(input_size=input_size + + hidden_size if i == 0 else hidden_size, + hidden_size=hidden_size))) + self.attention_layer = AttentionLayer(hidden_size) + + def forward(self, + step_input, + states, + encoder_output, + encoder_padding_mask=None): + lstm_states, input_feed = states + new_lstm_states = [] + step_input = layers.concat([step_input, input_feed], 1) + for i, lstm_cell in enumerate(self.lstm_cells): + out, new_lstm_state = lstm_cell(step_input, lstm_states[i]) + step_input = layers.dropout( + out, self.dropout_prob) if self.dropout_prob > 0 else out + new_lstm_states.append(new_lstm_state) + out = self.attention_layer(step_input, encoder_output, + encoder_padding_mask) + return out, [new_lstm_states, out] + + +class Decoder(Layer): + def __init__(self, + vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob=0., + init_scale=0.1): + super(Decoder, self).__init__() + self.embedder = Embedding( + size=[vocab_size, embed_dim], + param_attr=ParamAttr(initializer=UniformInitializer( + low=-init_scale, high=init_scale))) + self.lstm_attention = RNN(DecoderCell(num_layers, embed_dim, + hidden_size, init_scale), + is_reverse=False, + time_major=False) + self.output_layer = Linear( + hidden_size, + vocab_size, + param_attr=ParamAttr(initializer=UniformInitializer( + low=-init_scale, high=init_scale)), + bias_attr=False) + + def forward(self, target, decoder_initial_states, encoder_output, + encoder_padding_mask): + inputs = self.embedder(target) + decoder_output, _ = self.lstm_attention( + inputs, + decoder_initial_states, + encoder_output=encoder_output, + encoder_padding_mask=encoder_padding_mask) + predict = self.output_layer(decoder_output) + return predict + + +class Seq2Seq(Model): + def __init__(self, + src_vocab_size, + trg_vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob=0., + init_scale=0.1): + super(Seq2Seq, self).__init__() + self.hidden_size = hidden_size + self.encoder = Encoder(src_vocab_size, embed_dim, hidden_size, + num_layers, dropout_prob, init_scale) + self.decoder = Decoder(trg_vocab_size, embed_dim, hidden_size, + num_layers, dropout_prob, init_scale) + + def forward(self, src, src_length, trg, trg_length): + # encoder + encoder_output, encoder_final_state = self.encoder(src, src_length) + + # decoder initial states + decoder_initial_states = [ + encoder_final_state, + self.decoder.lstm_attention.cell.get_initial_states( + batch_ref=encoder_output, shape=[self.hidden_size]) + ] + # attention mask to avoid paying attention on padddings + src_mask = layers.sequence_mask(src_length, + maxlen=layers.shape(src)[1], + dtype=encoder_output.dtype) + encoder_padding_mask = (src_mask - 1.0) * 1e9 + encoder_padding_mask = layers.unsqueeze(encoder_padding_mask, [1]) + + # decoder with attentioon + predict = self.decoder(trg, decoder_initial_states, encoder_output, + encoder_padding_mask) + + # for target padding mask + mask = layers.sequence_mask(trg_length, + maxlen=layers.shape(trg)[1], + dtype=predict.dtype) + return predict, mask + + +class Seq2SeqInferModel(Seq2Seq): + def __init__(self, + vocab_size, + embed_dim, + hidden_size, + num_layers, + dropout_prob=0.): + pass diff --git a/seq2seq/seq2seq.yaml b/seq2seq/seq2seq.yaml new file mode 100644 index 0000000..8e0edb7 --- /dev/null +++ b/seq2seq/seq2seq.yaml @@ -0,0 +1,83 @@ +# used for continuous evaluation +enable_ce: False + +eager_run: False + +# The frequency to save trained models when training. +save_step: 10000 +# The frequency to fetch and print output when training. +print_step: 100 +# path of the checkpoint, to resume the previous training +init_from_checkpoint: "" +# path of the pretrain model, to better solve the current task +init_from_pretrain_model: "" +# path of trained parameter, to make prediction +init_from_params: "trained_params/step_100000/" +# the directory for saving model +save_model: "trained_models" +# the directory for saving inference model. +inference_model_dir: "infer_model" +# Set seed for CE or debug +random_seed: None +# The pattern to match training data files. +training_file: "wmt16_ende_data_bpe/train.tok.clean.bpe.32000.en-de" +# The pattern to match validation data files. +validation_file: "wmt16_ende_data_bpe/newstest2014.tok.bpe.32000.en-de" +# The pattern to match test data files. +predict_file: "wmt16_ende_data_bpe/newstest2016.tok.bpe.32000.en-de" +# The file to output the translation results of predict_file to. +output_file: "predict.txt" +# The path of vocabulary file of source language. +src_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000" +# The path of vocabulary file of target language. +trg_vocab_fpath: "wmt16_ende_data_bpe/vocab_all.bpe.32000" +# The , and tokens in the dictionary. +special_token: ["", "", ""] +# max length of sequences +max_length: 256 + +# whether to use cuda +use_cuda: True + +# args for reader, see reader.py for details +token_delimiter: " " +use_token_batch: True +pool_size: 200000 +sort_type: "pool" +shuffle: True +shuffle_batch: True +batch_size: 4096 + +# Hyparams for training: +# the number of epoches for training +epoch: 30 +# the hyper parameters for Adam optimizer. +# This static learning_rate will be multiplied to the LearningRateScheduler +# derived learning rate the to get the final learning rate. +learning_rate: 0.001 + + +# Hyparams for generation: +# the parameters for beam search. +beam_size: 5 +max_out_len: 256 +# the number of decoded sentences to output. +n_best: 1 + +# Hyparams for model: +# These following five vocabularies related configurations will be set +# automatically according to the passed vocabulary path and special tokens. +# size of source word dictionary. +src_vocab_size: 10000 +# size of target word dictionay +trg_vocab_size: 10000 +# index for token +bos_idx: 0 +# index for token +eos_idx: 1 +# index for token +unk_idx: 2 +embed_dim: 512 +hidden_size: 512 +num_layers: 2 +dropout: 0.1 diff --git a/seq2seq/train.py b/seq2seq/train.py new file mode 100644 index 0000000..04c2420 --- /dev/null +++ b/seq2seq/train.py @@ -0,0 +1,135 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import six +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +import time +import contextlib +from functools import partial + +import numpy as np +import paddle +import paddle.fluid as fluid +from paddle.fluid.dygraph import to_variable +from paddle.fluid.io import DataLoader + +from configure import PDConfig +from reader import prepare_train_input, Seq2SeqDataset, Seq2SeqBatchSampler +from seq2seq import Seq2Seq, CrossEntropyCriterion +from model import Input, set_device +from callbacks import ProgBarLogger + + +class LoggerCallback(ProgBarLogger): + def __init__(self, log_freq=1, verbose=2, loss_normalizer=0.): + super(LoggerCallback, self).__init__(log_freq, verbose) + # TODO: wrap these override function to simplify + self.loss_normalizer = loss_normalizer + + def on_train_begin(self, logs=None): + super(LoggerCallback, self).on_train_begin(logs) + self.train_metrics += ["normalized loss", "ppl"] + + def on_train_batch_end(self, step, logs=None): + logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer + logs["ppl"] = np.exp(min(logs["loss"][0], 100)) + super(LoggerCallback, self).on_train_batch_end(step, logs) + + def on_eval_begin(self, logs=None): + super(LoggerCallback, self).on_eval_begin(logs) + self.eval_metrics += ["normalized loss", "ppl"] + + def on_eval_batch_end(self, step, logs=None): + logs["normalized loss"] = logs["loss"][0] - self.loss_normalizer + logs["ppl"] = np.exp(min(logs["loss"][0], 100)) + super(LoggerCallback, self).on_eval_batch_end(step, logs) + + +def do_train(args): + device = set_device("gpu" if args.use_cuda else "cpu") + fluid.enable_dygraph(device) if args.eager_run else None + + # set seed for CE + random_seed = eval(str(args.random_seed)) + if random_seed is not None: + fluid.default_main_program().random_seed = random_seed + fluid.default_startup_program().random_seed = random_seed + + # define model + inputs = [ + Input([None, None], "int64", name="src_word"), + Input([None], "int64", name="src_length"), + Input([None, None], "int64", name="trg_word"), + Input([None], "int64", name="trg_length"), + ] + labels = [ + Input([None, None, 1], "int64", name="label"), + ] + + dataset = Seq2SeqDataset(fpattern=args.training_file, + src_vocab_fpath=args.src_vocab_fpath, + trg_vocab_fpath=args.trg_vocab_fpath, + token_delimiter=args.token_delimiter, + start_mark=args.special_token[0], + end_mark=args.special_token[1], + unk_mark=args.special_token[2]) + args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \ + args.unk_idx = dataset.get_vocab_summary() + batch_sampler = Seq2SeqBatchSampler(dataset=dataset, + use_token_batch=args.use_token_batch, + batch_size=args.batch_size, + pool_size=args.pool_size, + sort_type=args.sort_type, + shuffle=args.shuffle, + shuffle_batch=args.shuffle_batch, + max_length=args.max_length) + train_loader = DataLoader(dataset=dataset, + batch_sampler=batch_sampler, + places=device, + feed_list=[x.forward() for x in inputs + labels], + collate_fn=partial(prepare_train_input, + src_pad_idx=args.eos_idx, + trg_pad_idx=args.eos_idx), + num_workers=0, + return_list=True) + + model = Seq2Seq(args.src_vocab_size, args.trg_vocab_size, args.embed_dim, + args.hidden_size, args.num_layers, args.dropout) + + model.prepare(fluid.optimizer.Adam(learning_rate=args.learning_rate, + parameter_list=model.parameters()), + CrossEntropyCriterion(), + inputs=inputs, + labels=labels) + + model.fit(train_data=train_loader, + eval_data=None, + epochs=1, + eval_freq=1, + save_freq=1, + verbose=2, + callbacks=[ + LoggerCallback(log_freq=args.print_step) + ]) + + +if __name__ == "__main__": + args = PDConfig(yaml_file="./seq2seq.yaml") + args.build() + args.Print() + + do_train(args) -- GitLab