提交 2f344e7f 编写于 作者: Y ying

Fix naming convention: remove trailing double underscores from private helper names (`__name__` is reserved for Python dunder methods; module-private helpers use a leading-underscore style instead).

上级 9a97c7f7
...@@ -50,8 +50,8 @@ UNK = "<unk>" ...@@ -50,8 +50,8 @@ UNK = "<unk>"
UNK_IDX = 2 UNK_IDX = 2
def __read_to_dict__(tar_file, dict_size): def __read_to_dict(tar_file, dict_size):
def __to_dict__(fd, size): def __to_dict(fd, size):
out_dict = dict() out_dict = dict()
for line_count, line in enumerate(fd): for line_count, line in enumerate(fd):
if line_count < size: if line_count < size:
...@@ -66,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size): ...@@ -66,19 +66,19 @@ def __read_to_dict__(tar_file, dict_size):
if each_item.name.endswith("src.dict") if each_item.name.endswith("src.dict")
] ]
assert len(names) == 1 assert len(names) == 1
src_dict = __to_dict__(f.extractfile(names[0]), dict_size) src_dict = __to_dict(f.extractfile(names[0]), dict_size)
names = [ names = [
each_item.name for each_item in f each_item.name for each_item in f
if each_item.name.endswith("trg.dict") if each_item.name.endswith("trg.dict")
] ]
assert len(names) == 1 assert len(names) == 1
trg_dict = __to_dict__(f.extractfile(names[0]), dict_size) trg_dict = __to_dict(f.extractfile(names[0]), dict_size)
return src_dict, trg_dict return src_dict, trg_dict
def reader_creator(tar_file, file_name, dict_size): def reader_creator(tar_file, file_name, dict_size):
def reader(): def reader():
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
with tarfile.open(tar_file, mode='r') as f: with tarfile.open(tar_file, mode='r') as f:
names = [ names = [
each_item.name for each_item in f each_item.name for each_item in f
...@@ -160,7 +160,7 @@ def get_dict(dict_size, reverse=True): ...@@ -160,7 +160,7 @@ def get_dict(dict_size, reverse=True):
# if reverse = False, return dict = {'a':'001', 'b':'002', ...} # if reverse = False, return dict = {'a':'001', 'b':'002', ...}
# else reverse = true, return dict = {'001':'a', '002':'b', ...} # else reverse = true, return dict = {'001':'a', '002':'b', ...}
tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) tar_file = paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict__(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse: if reverse:
src_dict = {v: k for k, v in src_dict.items()} src_dict = {v: k for k, v in src_dict.items()}
trg_dict = {v: k for k, v in trg_dict.items()} trg_dict = {v: k for k, v in trg_dict.items()}
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
ACL2016 Multimodal Machine Translation. Please see this websit for more details: ACL2016 Multimodal Machine Translation. Please see this website for more
http://www.statmt.org/wmt16/multimodal-task.html#task1 details: http://www.statmt.org/wmt16/multimodal-task.html#task1
If you use the dataset created for your task, please cite the following paper: If you use the dataset created for your task, please cite the following paper:
Multi30K: Multilingual English-German Image Descriptions. Multi30K: Multilingual English-German Image Descriptions.
...@@ -56,7 +56,7 @@ END_MARK = "<e>" ...@@ -56,7 +56,7 @@ END_MARK = "<e>"
UNK_MARK = "<unk>" UNK_MARK = "<unk>"
def __build_dict__(tar_file, dict_size, save_path, lang): def __build_dict(tar_file, dict_size, save_path, lang):
word_dict = defaultdict(int) word_dict = defaultdict(int)
with tarfile.open(tar_file, mode="r") as f: with tarfile.open(tar_file, mode="r") as f:
for line in f.extractfile("wmt16/train"): for line in f.extractfile("wmt16/train"):
...@@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang): ...@@ -75,12 +75,12 @@ def __build_dict__(tar_file, dict_size, save_path, lang):
fout.write("%s\n" % (word[0])) fout.write("%s\n" % (word[0]))
def __load_dict__(tar_file, dict_size, lang, reverse=False): def __load_dict(tar_file, dict_size, lang, reverse=False):
dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME, dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size)) "wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or ( if not os.path.exists(dict_path) or (
len(open(dict_path, "r").readlines()) != dict_size): len(open(dict_path, "r").readlines()) != dict_size):
__build_dict__(tar_file, dict_size, dict_path, lang) __build_dict(tar_file, dict_size, dict_path, lang)
word_dict = {} word_dict = {}
with open(dict_path, "r") as fdict: with open(dict_path, "r") as fdict:
...@@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False): ...@@ -92,7 +92,7 @@ def __load_dict__(tar_file, dict_size, lang, reverse=False):
return word_dict return word_dict
def __get_dict_size__(src_dict_size, trg_dict_size, src_lang): def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
TOTAL_DE_WORDS)) TOTAL_DE_WORDS))
trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
...@@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang): ...@@ -102,9 +102,9 @@ def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang): def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
def reader(): def reader():
src_dict = __load_dict__(tar_file, src_dict_size, src_lang) src_dict = __load_dict(tar_file, src_dict_size, src_lang)
trg_dict = __load_dict__(tar_file, trg_dict_size, trg_dict = __load_dict(tar_file, trg_dict_size,
("de" if src_lang == "en" else "en")) ("de" if src_lang == "en" else "en"))
# the indice for start mark, end mark, and unk are the same in source # the indice for start mark, end mark, and unk are the same in source
# language and target language. Here uses the source language # language and target language. Here uses the source language
...@@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"): ...@@ -173,8 +173,8 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):
assert (src_lang in ["en", "de"], ("An error language type. Only support: " assert (src_lang in ["en", "de"], ("An error language type. Only support: "
"en (for English); de(for Germany)")) "en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size, src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
trg_dict_size, src_lang) src_lang)
return reader_creator( return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
...@@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"): ...@@ -222,8 +222,8 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
("An error language type. " ("An error language type. "
"Only support: en (for English); de(for Germany)")) "Only support: en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size, src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
trg_dict_size, src_lang) src_lang)
return reader_creator( return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
...@@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"): ...@@ -269,8 +269,8 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
assert (src_lang in ["en", "de"], assert (src_lang in ["en", "de"],
("An error language type. " ("An error language type. "
"Only support: en (for English); de(for Germany)")) "Only support: en (for English); de(for Germany)"))
src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size, src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
trg_dict_size, src_lang) src_lang)
return reader_creator( return reader_creator(
tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
...@@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False): ...@@ -308,7 +308,7 @@ def get_dict(lang, dict_size, reverse=False):
"Please invoke paddle.dataset.wmt16.train/test/validation " "Please invoke paddle.dataset.wmt16.train/test/validation "
"first to build the dictionary.") "first to build the dictionary.")
tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz") tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16.tar.gz")
return __load_dict__(tar_file, dict_size, lang, reverse) return __load_dict(tar_file, dict_size, lang, reverse)
def fetch(): def fetch():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册