提交 2f344e7f 编写于 作者: Y ying

fix name convention.

上级 9a97c7f7
......@@ -50,8 +50,8 @@ UNK = "<unk>"
UNK_IDX = 2
def __read_to_dict__(tar_file, dict_size):
def __to_dict__(fd, size):
def __read_to_dict(tar_file, dict_size):
def __to_dict(fd, size):
out_dict = dict()
for line_count, line in enumerate(fd):
if line_count < size:
......@@ -66,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size):
if each_item.name.endswith("src.dict")
]
assert len(names) == 1
src_dict = __to_dict__(f.extractfile(names[0]), dict_size)
src_dict = __to_dict(f.extractfile(names[0]), dict_size)
names = [
each_item.name for each_item in f
if each_item.name.endswith("trg.dict")
]
assert len(names) == 1
trg_dict = __to_dict__(f.extractfile(names[0]), dict_size)
trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
return src_dict, trg_dict
def reader_creator(tar_file, file_name, dict_size):
def reader():
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
with tarfile.open(tar_file, mode='r') as f:
names = [
each_item.name for each_item in f
......@@ -160,7 +160,7 @@ def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse:
src_dict = {v: k for k, v in src_dict.items()}
trg_dict = {v: k for k, v in trg_dict.items()}
......
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
ACL2016 Multimodal Machine Translation. Please see this websit for more details:
http://www.statmt.org/wmt16/multimodal-task.html#task1
ACL2016 Multimodal Machine Translation. Please see this website for more
details: http://www.statmt.org/wmt16/multimodal-task.html#task1
If you use the dataset created for your task, please cite the following paper:
Multi30K: Multilingual English-German Image Descriptions.
......@@ -56,7 +56,7 @@ END_MARK = "<e>"
UNK_MARK = "<unk>"
def __build_dict__(tar_file, dict_size, save_path, lang):
def __build_dict(tar_file, dict_size, save_path, lang):
word_dict = defaultdict(int)
with tarfile.open(tar_file, mode="r") as f:
for line in f.extractfile("wmt16/train"):
......@@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang):
fout.write("%s\n" % (word[0]))
def __load_dict__(tar_file, dict_size, lang, reverse=False):
def __load_dict(tar_file, dict_size, lang, reverse=False):
dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or (
len(open(dict_path, "r").readlines()) != dict_size):
__build_dict__(tar_file, dict_size, dict_path, lang)
__build_dict(tar_file, dict_size, dict_path, lang)
word_dict = {}
with open(dict_path, "r") as fdict:
......@@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False):
return word_dict
def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
TOTAL_DE_WORDS))
trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
......@@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
def reader():
src_dict = __load_dict__(tar_file, src_dict_size, src_lang)
trg_dict = __load_dict__(tar_file, trg_dict_size,
("de" if src_lang == "en" else "en"))
src_dict = __load_dict(tar_file, src_dict_size, src_lang)
trg_dict = __load_dict(tar_file, trg_dict_size,
("de" if src_lang == "en" else "en"))
# the indice for start mark, end mark, and unk are the same in source
# language and target language. Here uses the source language
......@@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):
assert (src_lang in ["en", "de"], ("An error language type. Only support: "
"en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
trg_dict_size, src_lang)
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang)
return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
......@@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
("An error language type. "
"Only support: en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
trg_dict_size, src_lang)
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang)
return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
......@@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
assert (src_lang in ["en", "de"],
("An error language type. "
"Only support: en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
trg_dict_size, src_lang)
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang)
return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
......@@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False):
"Please invoke paddle.dataset.wmt16.train/test/validation "
"first to build the dictionary.")
tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
return __load_dict__(tar_file, dict_size, lang, reverse)
return __load_dict(tar_file, dict_size, lang, reverse)
def fetch():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册