diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py index 1e54a4999b43aa609ddea20b188fead109a649b4..5104e29051e4480f3a7eb18421f1b519841b009b 100644 --- a/python/paddle/v2/dataset/wmt14.py +++ b/python/paddle/v2/dataset/wmt14.py @@ -50,8 +50,8 @@ UNK = "" UNK_IDX = 2 -def __read_to_dict__(tar_file, dict_size): - def __to_dict__(fd, size): +def __read_to_dict(tar_file, dict_size): + def __to_dict(fd, size): out_dict = dict() for line_count, line in enumerate(fd): if line_count < size: @@ -66,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size): if each_item.name.endswith("src.dict") ] assert len(names) == 1 - src_dict = __to_dict__(f.extractfile(names[0]), dict_size) + src_dict = __to_dict(f.extractfile(names[0]), dict_size) names = [ each_item.name for each_item in f if each_item.name.endswith("trg.dict") ] assert len(names) == 1 - trg_dict = __to_dict__(f.extractfile(names[0]), dict_size) + trg_dict = __to_dict(f.extractfile(names[0]), dict_size) return src_dict, trg_dict def reader_creator(tar_file, file_name, dict_size): def reader(): - src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) + src_dict, trg_dict = __read_to_dict(tar_file, dict_size) with tarfile.open(tar_file, mode='r') as f: names = [ each_item.name for each_item in f @@ -160,7 +160,7 @@ def get_dict(dict_size, reverse=True): # if reverse = False, return dict = {'a':'001', 'b':'002', ...} # else reverse = true, return dict = {'001':'a', '002':'b', ...} tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) - src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) + src_dict, trg_dict = __read_to_dict(tar_file, dict_size) if reverse: src_dict = {v: k for k, v in src_dict.items()} trg_dict = {v: k for k, v in trg_dict.items()} diff --git a/python/paddle/v2/dataset/wmt16.py b/python/paddle/v2/dataset/wmt16.py index a1899f20b55e27c70e4c9031f3149a81249506de..bbc28a2da99052308471931122946d0d96b54da5 100644 --- a/python/paddle/v2/dataset/wmt16.py +++ b/python/paddle/v2/dataset/wmt16.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -ACL2016 Multimodal Machine Translation. Please see this websit for more details: -http://www.statmt.org/wmt16/multimodal-task.html#task1 +ACL2016 Multimodal Machine Translation. Please see this website for more +details: http://www.statmt.org/wmt16/multimodal-task.html#task1 If you use the dataset created for your task, please cite the following paper: Multi30K: Multilingual English-German Image Descriptions. @@ -56,7 +56,7 @@ END_MARK = "" UNK_MARK = "" -def __build_dict__(tar_file, dict_size, save_path, lang): +def __build_dict(tar_file, dict_size, save_path, lang): word_dict = defaultdict(int) with tarfile.open(tar_file, mode="r") as f: for line in f.extractfile("wmt16/train"): @@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang): fout.write("%s\n" % (word[0])) -def __load_dict__(tar_file, dict_size, lang, reverse=False): +def __load_dict(tar_file, dict_size, lang, reverse=False): dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16/%s_%d.dict" % (lang, dict_size)) if not os.path.exists(dict_path) or ( len(open(dict_path, "r").readlines()) != dict_size): - __build_dict__(tar_file, dict_size, dict_path, lang) + __build_dict(tar_file, dict_size, dict_path, lang) word_dict = {} with open(dict_path, "r") as fdict: @@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False): return word_dict -def __get_dict_size__(src_dict_size, trg_dict_size, src_lang): +def __get_dict_size(src_dict_size, trg_dict_size, src_lang): src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else TOTAL_DE_WORDS)) trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else @@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang): def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang): def reader(): - src_dict = __load_dict__(tar_file, src_dict_size, src_lang) - trg_dict = __load_dict__(tar_file, trg_dict_size, - ("de" if src_lang == "en" else "en")) + src_dict = __load_dict(tar_file, src_dict_size, src_lang) + trg_dict = __load_dict(tar_file, trg_dict_size, + ("de" if src_lang == "en" else "en")) # the indice for start mark, end mark, and unk are the same in source # language and target language. Here uses the source language @@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"): assert (src_lang in ["en", "de"], ("An error language type. Only support: " "en (for English); de(for Germany)")) - src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size, - trg_dict_size, src_lang) + src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, + src_lang) return reader_creator( tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, @@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"): ("An error language type. " "Only support: en (for English); de(for Germany)")) - src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size, - trg_dict_size, src_lang) + src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, + src_lang) return reader_creator( tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, @@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"): assert (src_lang in ["en", "de"], ("An error language type. " "Only support: en (for English); de(for Germany)")) - src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size, - trg_dict_size, src_lang) + src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, + src_lang) return reader_creator( tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, @@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False): "Please invoke paddle.dataset.wmt16.train/test/validation " "first to build the dictionary.") tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz") - return __load_dict__(tar_file, dict_size, lang, reverse) + return __load_dict(tar_file, dict_size, lang, reverse) def fetch():