diff --git a/fluid/neural_machine_translation/transformer/infer.py b/fluid/neural_machine_translation/transformer/infer.py index c11330a5467e5a56a21d9e7f70ffe802253d0cf3..e8f7f47dd5c0dc4937b73bd1693b2fd14fb8d55c 100644 --- a/fluid/neural_machine_translation/transformer/infer.py +++ b/fluid/neural_machine_translation/transformer/infer.py @@ -51,7 +51,17 @@ def parse_args(): default=None, nargs=argparse.REMAINDER) args = parser.parse_args() - merge_cfg_from_list(args.opts, [InferTaskConfig, ModelHyperParams]) + # Append args related to dict + src_dict = reader.DataReader.load_dict(args.src_vocab_fpath) + trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath) + dict_args = [ + "src_vocab_size", str(len(src_dict)), "trg_vocab_size", + str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]), + "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx", + str(src_dict[args.special_token[2]]) + ] + merge_cfg_from_list(args.opts + dict_args, + [InferTaskConfig, ModelHyperParams]) return args @@ -351,7 +361,7 @@ def infer(args): unk_mark=args.special_token[2], clip_last_batch=False) - trg_idx2word = test_data._load_dict( + trg_idx2word = test_data.load_dict( dict_path=args.trg_vocab_fpath, reverse=True) def post_process_seq(seq, diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py index 7029d35803138ea497aeef476a47009e6a78f37b..5daa70a2336cd5b3d8e1c9568174832219d3a9c6 100644 --- a/fluid/neural_machine_translation/transformer/reader.py +++ b/fluid/neural_machine_translation/transformer/reader.py @@ -136,10 +136,10 @@ class DataReader(object): :param seed: The seed for random. :type seed: int """ - self._src_vocab = self._load_dict(src_vocab_fpath) + self._src_vocab = self.load_dict(src_vocab_fpath) self._only_src = True if trg_vocab_fpath is not None: - self._trg_vocab = self._load_dict(trg_vocab_fpath) + self._trg_vocab = self.load_dict(trg_vocab_fpath) self._only_src = False self._pool_size = pool_size self._batch_size = batch_size @@ -237,7 +237,8 @@ class DataReader(object): return src_seq_words, trg_seq_words - def _load_dict(self, dict_path, reverse=False): + @staticmethod + def load_dict(dict_path, reverse=False): word_dict = {} with open(dict_path, "r") as fdict: for idx, line in enumerate(fdict): diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index c65ec8bbc36cb17c54e62196580d0597f38c6d19..90fa75728679f1153daf5e9533e63b9c1ae64b9d 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -79,7 +79,17 @@ def parse_args(): default=None, nargs=argparse.REMAINDER) args = parser.parse_args() - merge_cfg_from_list(args.opts, [TrainTaskConfig, ModelHyperParams]) + # Append args related to dict + src_dict = reader.DataReader.load_dict(args.src_vocab_fpath) + trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath) + dict_args = [ + "src_vocab_size", str(len(src_dict)), "trg_vocab_size", + str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]), + "eos_idx", str(src_dict[args.special_token[1]]), "unk_idx", + str(src_dict[args.special_token[2]]) + ] + merge_cfg_from_list(args.opts + dict_args, + [TrainTaskConfig, ModelHyperParams]) return args