提交 2f16f47e 编写于 作者: Y Yu Yang

Fix dataset wmt16

上级 6ca37448
...@@ -78,7 +78,8 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -78,7 +78,8 @@ def __build_dict(tar_file, dict_size, save_path, lang):
six.iteritems(word_dict), key=lambda x: x[1], six.iteritems(word_dict), key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0])) fout.write(word[0].encode('utf-8'))
fout.write('\n')
def __load_dict(tar_file, dict_size, lang, reverse=False): def __load_dict(tar_file, dict_size, lang, reverse=False):
......
...@@ -72,7 +72,8 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -72,7 +72,8 @@ def __build_dict(tar_file, dict_size, save_path, lang):
sorted( sorted(
word_dict.iteritems(), key=lambda x: x[1], reverse=True)): word_dict.iteritems(), key=lambda x: x[1], reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0])) fout.write(word[0].encode('utf-8'))
fout.write('\n')
def __load_dict(tar_file, dict_size, lang, reverse=False): def __load_dict(tar_file, dict_size, lang, reverse=False):
...@@ -300,8 +301,10 @@ def get_dict(lang, dict_size, reverse=False): ...@@ -300,8 +301,10 @@ def get_dict(lang, dict_size, reverse=False):
dict: The word dictionary for the specific language. dict: The word dictionary for the specific language.
""" """
if lang == "en": dict_size = min(dict_size, TOTAL_EN_WORDS) if lang == "en":
else: dict_size = min(dict_size, TOTAL_DE_WORDS) dict_size = min(dict_size, TOTAL_EN_WORDS)
else:
dict_size = min(dict_size, TOTAL_DE_WORDS)
dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME, dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size)) "wmt16/%s_%d.dict" % (lang, dict_size))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册