提交 8aa6740f 作者: guosheng

Append dictionary-derived args (source/target vocab sizes and bos/eos/unk indices) to the Transformer config

上级 bca3c03d
@@ -51,7 +51,17 @@ def parse_args():
default=None, default=None,
nargs=argparse.REMAINDER) nargs=argparse.REMAINDER)
args = parser.parse_args() args = parser.parse_args()
merge_cfg_from_list(args.opts, [InferTaskConfig, ModelHyperParams]) # Append args related to dict
src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
dict_args = [
"src_vocab_size", str(len(src_dict)), "trg_vocab_size",
str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
"eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
str(src_dict[args.special_token[2]])
]
merge_cfg_from_list(args.opts + dict_args,
[InferTaskConfig, ModelHyperParams])
return args return args
@@ -351,7 +361,7 @@ def infer(args):
unk_mark=args.special_token[2], unk_mark=args.special_token[2],
clip_last_batch=False) clip_last_batch=False)
trg_idx2word = test_data._load_dict( trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True) dict_path=args.trg_vocab_fpath, reverse=True)
def post_process_seq(seq, def post_process_seq(seq,
......
@@ -136,10 +136,10 @@ class DataReader(object):
:param seed: The seed for random. :param seed: The seed for random.
:type seed: int :type seed: int
""" """
self._src_vocab = self._load_dict(src_vocab_fpath) self._src_vocab = self.load_dict(src_vocab_fpath)
self._only_src = True self._only_src = True
if trg_vocab_fpath is not None: if trg_vocab_fpath is not None:
self._trg_vocab = self._load_dict(trg_vocab_fpath) self._trg_vocab = self.load_dict(trg_vocab_fpath)
self._only_src = False self._only_src = False
self._pool_size = pool_size self._pool_size = pool_size
self._batch_size = batch_size self._batch_size = batch_size
@@ -237,7 +237,8 @@ class DataReader(object):
return src_seq_words, trg_seq_words return src_seq_words, trg_seq_words
def _load_dict(self, dict_path, reverse=False): @staticmethod
def load_dict(dict_path, reverse=False):
word_dict = {} word_dict = {}
with open(dict_path, "r") as fdict: with open(dict_path, "r") as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
......
@@ -79,7 +79,17 @@ def parse_args():
default=None, default=None,
nargs=argparse.REMAINDER) nargs=argparse.REMAINDER)
args = parser.parse_args() args = parser.parse_args()
merge_cfg_from_list(args.opts, [TrainTaskConfig, ModelHyperParams]) # Append args related to dict
src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
dict_args = [
"src_vocab_size", str(len(src_dict)), "trg_vocab_size",
str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
"eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
str(src_dict[args.special_token[2]])
]
merge_cfg_from_list(args.opts + dict_args,
[TrainTaskConfig, ModelHyperParams])
return args return args
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册