From 8aa6740fe8dc6e3d1d781b73701f07afd00c0990 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Wed, 9 May 2018 09:58:19 +0800
Subject: [PATCH] Append args related to dict in Transformer

---
 .../transformer/infer.py  | 14 ++++++++++++--
 .../transformer/reader.py |  7 ++++---
 .../transformer/train.py  | 12 +++++++++++-
 3 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py
index c11330a5..e8f7f47d 100644
--- a/fluid/neural_machine_translation/transformer/infer.py
+++ b/fluid/neural_machine_translation/transformer/infer.py
@@ -51,7 +51,17 @@ def parse_args():
         default=None,
         nargs=argparse.REMAINDER)
     args = parser.parse_args()
-    merge_cfg_from_list(args.opts, [InferTaskConfig, ModelHyperParams])
+    # Append args related to dict
+    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
+    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
+    dict_args = [
+        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
+        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
+        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
+        str(src_dict[args.special_token[2]])
+    ]
+    merge_cfg_from_list(args.opts + dict_args,
+                        [InferTaskConfig, ModelHyperParams])
     return args


@@ -351,7 +361,7 @@ def infer(args):
         unk_mark=args.special_token[2],
         clip_last_batch=False)

-    trg_idx2word = test_data._load_dict(
+    trg_idx2word = test_data.load_dict(
         dict_path=args.trg_vocab_fpath, reverse=True)

     def post_process_seq(seq,
diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py
index 7029d358..5daa70a2 100644
--- a/fluid/neural_machine_translation/transformer/reader.py
+++ b/fluid/neural_machine_translation/transformer/reader.py
@@ -136,10 +136,10 @@ class DataReader(object):
         :param seed: The seed for random.
         :type seed: int
         """
-        self._src_vocab = self._load_dict(src_vocab_fpath)
+        self._src_vocab = self.load_dict(src_vocab_fpath)
         self._only_src = True
         if trg_vocab_fpath is not None:
-            self._trg_vocab = self._load_dict(trg_vocab_fpath)
+            self._trg_vocab = self.load_dict(trg_vocab_fpath)
             self._only_src = False
         self._pool_size = pool_size
         self._batch_size = batch_size
@@ -237,7 +237,8 @@ class DataReader(object):

         return src_seq_words, trg_seq_words

-    def _load_dict(self, dict_path, reverse=False):
+    @staticmethod
+    def load_dict(dict_path, reverse=False):
         word_dict = {}
         with open(dict_path, "r") as fdict:
             for idx, line in enumerate(fdict):
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index c65ec8bb..90fa7572 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -79,7 +79,17 @@ def parse_args():
         default=None,
         nargs=argparse.REMAINDER)
     args = parser.parse_args()
-    merge_cfg_from_list(args.opts, [TrainTaskConfig, ModelHyperParams])
+    # Append args related to dict
+    src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
+    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
+    dict_args = [
+        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
+        str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
+        "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
+        str(src_dict[args.special_token[2]])
+    ]
+    merge_cfg_from_list(args.opts + dict_args,
+                        [TrainTaskConfig, ModelHyperParams])
     return args

--
GitLab
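
Note on how the appended arguments take effect: the patch builds dict_args as a
flat [key1, value1, key2, value2, ...] list and hands it, together with
args.opts, to merge_cfg_from_list, so the vocabulary sizes and special-token
indices are derived from the dictionary files instead of being set by hand.
The helper itself lives in the repo's config.py and is not part of this patch,
so the snippet below is only a minimal sketch of how such a helper is assumed
to behave; the ModelHyperParams fields and the example values are illustrative
placeholders, not real vocabulary statistics.

def merge_cfg_from_list(cfg_list, g_cfgs):
    """Sketch of a merge_cfg_from_list-style helper (assumed behavior)."""
    assert len(cfg_list) % 2 == 0, "expect alternating keys and values"
    for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
        for g_cfg in g_cfgs:
            if hasattr(g_cfg, key):
                try:
                    value = eval(value)  # numeric strings become numbers
                except Exception:
                    pass  # leave plain strings untouched
                setattr(g_cfg, key, value)
                break


class ModelHyperParams(object):
    # Illustrative defaults; the real class defines many more hyperparameters.
    src_vocab_size = 10000
    trg_vocab_size = 10000
    bos_idx = 0
    eos_idx = 1
    unk_idx = 2


# Example values standing in for len(src_dict), len(trg_dict) and the indices
# of the begin/end/unknown tokens looked up via args.special_token.
dict_args = [
    "src_vocab_size", "37007", "trg_vocab_size", "37007",
    "bos_idx", "0", "eos_idx", "1", "unk_idx", "2"
]
merge_cfg_from_list(dict_args, [ModelHyperParams])
print(ModelHyperParams.src_vocab_size)  # -> 37007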