diff --git a/python/paddle/dataset/wmt16.py b/python/paddle/dataset/wmt16.py index aa66696fae7d3adb44511417edf4a92b82a9151b..1052d24c57b79e1db921f59bb6ea6ecdc87a7f81 100644 --- a/python/paddle/dataset/wmt16.py +++ b/python/paddle/dataset/wmt16.py @@ -71,15 +71,16 @@ def __build_dict(tar_file, dict_size, save_path, lang): for w in sen.split(): word_dict[w] += 1 - with open(save_path, "w") as fout: - fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) + with open(save_path, "wb") as fout: + fout.write( + cpt.to_bytes("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))) for idx, word in enumerate( sorted( six.iteritems(word_dict), key=lambda x: x[1], reverse=True)): if idx + 3 == dict_size: break - fout.write(word[0].encode('utf-8')) - fout.write('\n') + fout.write(cpt.to_bytes(word[0])) + fout.write(cpt.to_bytes('\n')) def __load_dict(tar_file, dict_size, lang, reverse=False):