From 9a97c7f745b0014504f9c52897b72fd58147e70a Mon Sep 17 00:00:00 2001
From: ying
Date: Wed, 17 Jan 2018 20:09:44 +0800
Subject: [PATCH] add wmt16 into dataset.

---
 python/paddle/v2/dataset/__init__.py          |  16 +-
 python/paddle/v2/dataset/common.py            |  21 +-
 python/paddle/v2/dataset/tests/wmt16_test.py  |  66 ++++
 python/paddle/v2/dataset/wmt14.py             |  18 +-
 python/paddle/v2/dataset/wmt16.py             | 348 ++++++++++++++++++
 python/paddle/v2/fluid/layers/control_flow.py |  33 +-
 6 files changed, 482 insertions(+), 20 deletions(-)
 create mode 100644 python/paddle/v2/dataset/tests/wmt16_test.py
 create mode 100644 python/paddle/v2/dataset/wmt16.py

diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py
index 90830515c1e..c1acbecd9c3 100644
--- a/python/paddle/v2/dataset/__init__.py
+++ b/python/paddle/v2/dataset/__init__.py
@@ -24,11 +24,23 @@ import conll05
 import uci_housing
 import sentiment
 import wmt14
+import wmt16
 import mq2007
 import flowers
 import voc2012
 
 __all__ = [
-    'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
-    'uci_housing', 'wmt14', 'mq2007', 'flowers', 'voc2012'
+    'mnist',
+    'imikolov',
+    'imdb',
+    'cifar',
+    'movielens',
+    'conll05',
+    'sentiment',
+    'uci_housing',
+    'wmt14',
+    'wmt16',
+    'mq2007',
+    'flowers',
+    'voc2012',
 ]
diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py
index fab8a68b0be..9aba35a6481 100644
--- a/python/paddle/v2/dataset/common.py
+++ b/python/paddle/v2/dataset/common.py
@@ -25,8 +25,12 @@ import glob
 import cPickle as pickle
 
 __all__ = [
-    'DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader',
-    'convert'
+    'DATA_HOME',
+    'download',
+    'md5file',
+    'split',
+    'cluster_files_reader',
+    'convert',
 ]
 
 DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')
@@ -58,12 +62,15 @@ def md5file(fname):
     return hash_md5.hexdigest()
 
 
-def download(url, module_name, md5sum):
+def download(url, module_name, md5sum, save_name=None):
     dirname = os.path.join(DATA_HOME, module_name)
     if not os.path.exists(dirname):
         os.makedirs(dirname)
 
-    filename = os.path.join(dirname, url.split('/')[-1])
+    filename = os.path.join(dirname,
+                            url.split('/')[-1]
+                            if save_name is None else save_name)
+
     retry = 0
     retry_limit = 3
     while not (os.path.exists(filename) and md5file(filename) == md5sum):
@@ -196,9 +203,11 @@
     Convert data from reader to recordio format files.
 
     :param output_path: directory in which output files will be saved.
-    :param reader: a data reader, from which the convert program will read data instances.
+    :param reader: a data reader, from which the convert program will read
+        data instances.
     :param name_prefix: the name prefix of generated files.
-    :param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
+    :param max_lines_to_shuffle: the maximum number of lines to shuffle
+        before writing.
     """
     assert line_count >= 1
diff --git a/python/paddle/v2/dataset/tests/wmt16_test.py b/python/paddle/v2/dataset/tests/wmt16_test.py
new file mode 100644
index 00000000000..cef6c3216e7
--- /dev/null
+++ b/python/paddle/v2/dataset/tests/wmt16_test.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.v2.dataset.wmt16
+import unittest
+
+
+class TestWMT16(unittest.TestCase):
+    def checkout_one_sample(self, sample):
+        # train data has 3 fields: source language word indices,
+        # target language word indices, and target next word indices.
+        self.assertEqual(len(sample), 3)
+
+        # test start mark and end mark in source word indices.
+        self.assertEqual(sample[0][0], 0)
+        self.assertEqual(sample[0][-1], 1)
+
+        # test start mark in target word indices.
+        self.assertEqual(sample[1][0], 0)
+
+        # test end mark in target next word indices.
+        self.assertEqual(sample[2][-1], 1)
+
+    def test_train(self):
+        for idx, sample in enumerate(
+                paddle.v2.dataset.wmt16.train(
+                    src_dict_size=100000, trg_dict_size=100000)()):
+            if idx >= 10: break
+            self.checkout_one_sample(sample)
+
+    def test_test(self):
+        for idx, sample in enumerate(
+                paddle.v2.dataset.wmt16.test(
+                    src_dict_size=1000, trg_dict_size=1000)()):
+            if idx >= 10: break
+            self.checkout_one_sample(sample)
+
+    def test_val(self):
+        for idx, sample in enumerate(
+                paddle.v2.dataset.wmt16.validation(
+                    src_dict_size=1000, trg_dict_size=1000)()):
+            if idx >= 10: break
+            self.checkout_one_sample(sample)
+
+    def test_get_dict(self):
+        dict_size = 1000
+        word_dict = paddle.v2.dataset.wmt16.get_dict("en", dict_size, True)
+        self.assertEqual(len(word_dict), dict_size)
+        self.assertEqual(word_dict[0], "<s>")
+        self.assertEqual(word_dict[1], "<e>")
+        self.assertEqual(word_dict[2], "<unk>")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py
index 95a35d97ce9..1e54a4999b4 100644
--- a/python/paddle/v2/dataset/wmt14.py
+++ b/python/paddle/v2/dataset/wmt14.py
@@ -25,12 +25,20 @@ import gzip
 import paddle.v2.dataset.common
 from paddle.v2.parameters import Parameters
 
-__all__ = ['train', 'test', 'build_dict', 'convert']
-
-URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
+__all__ = [
+    'train',
+    'test',
+    'get_dict',
+    'convert',
+]
+
+URL_DEV_TEST = ('http://www-lium.univ-lemans.fr/~schwenk/'
+                'cslm_joint_paper/data/dev+test.tgz')
 MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
-# this is a small set of data for test. The original data is too large and will be add later.
-URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
+# this is a small set of data for test. The original data is too large and
+# will be added later.
+URL_TRAIN = ('http://paddlepaddle.cdn.bcebos.com/demo/'
+             'wmt_shrinked_data/wmt14.tgz')
 MD5_TRAIN = '0791583d57d5beb693b9414c5b36798c'
 # BLEU of this trained model is 26.92
 URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
diff --git a/python/paddle/v2/dataset/wmt16.py b/python/paddle/v2/dataset/wmt16.py
new file mode 100644
index 00000000000..a1899f20b55
--- /dev/null
+++ b/python/paddle/v2/dataset/wmt16.py
@@ -0,0 +1,348 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+ACL2016 Multimodal Machine Translation. Please see this website for more details:
+http://www.statmt.org/wmt16/multimodal-task.html#task1
+
+If you use the dataset created for your task, please cite the following paper:
+Multi30K: Multilingual English-German Image Descriptions.
+
+@article{elliott-EtAl:2016:VL16,
+  author = {{Elliott}, D. and {Frank}, S. and {Sima'an}, K. and {Specia}, L.},
+  title = {Multi30K: Multilingual English-German Image Descriptions},
+  booktitle = {Proceedings of the 6th Workshop on Vision and Language},
+  year = {2016},
+  pages = {70--74},
+  year = 2016
+}
+"""
+
+import os
+import tarfile
+import gzip
+from collections import defaultdict
+
+import paddle.v2.dataset.common
+
+__all__ = [
+    "train",
+    "test",
+    "validation",
+    "convert",
+    "fetch",
+    "get_dict",
+]
+
+DATA_URL = ("http://cloud.dlnel.org/filepub/"
+            "?uuid=46a0808e-ddd8-427c-bacd-0dbc6d045fed")
+DATA_MD5 = "0c38be43600334966403524a40dcd81e"
+
+TOTAL_EN_WORDS = 11250
+TOTAL_DE_WORDS = 19220
+
+START_MARK = "<s>"
+END_MARK = "<e>"
+UNK_MARK = "<unk>"
+
+
+def __build_dict__(tar_file, dict_size, save_path, lang):
+    word_dict = defaultdict(int)
+    with tarfile.open(tar_file, mode="r") as f:
+        for line in f.extractfile("wmt16/train"):
+            line_split = line.strip().split("\t")
+            if len(line_split) != 2: continue
+            sen = line_split[0] if lang == "en" else line_split[1]
+            for w in sen.split():
+                word_dict[w] += 1
+
+    with open(save_path, "w") as fout:
+        fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
+        for idx, word in enumerate(
+                sorted(
+                    word_dict.iteritems(), key=lambda x: x[1], reverse=True)):
+            if idx + 3 == dict_size: break
+            fout.write("%s\n" % (word[0]))
+
+
+def __load_dict__(tar_file, dict_size, lang, reverse=False):
+    dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
+                             "wmt16/%s_%d.dict" % (lang, dict_size))
+    if not os.path.exists(dict_path) or (
+            len(open(dict_path, "r").readlines()) != dict_size):
+        __build_dict__(tar_file, dict_size, dict_path, lang)
+
+    word_dict = {}
+    with open(dict_path, "r") as fdict:
+        for idx, line in enumerate(fdict):
+            if reverse:
+                word_dict[idx] = line.strip()
+            else:
+                word_dict[line.strip()] = idx
+    return word_dict
+
+
+def __get_dict_size__(src_dict_size, trg_dict_size, src_lang):
+    src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else
+                                        TOTAL_DE_WORDS))
+    trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else
+                                        TOTAL_EN_WORDS))
+    return src_dict_size, trg_dict_size
+
+
+def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
+    def reader():
+        src_dict = __load_dict__(tar_file, src_dict_size, src_lang)
+        trg_dict = __load_dict__(tar_file, trg_dict_size,
+                                 ("de" if src_lang == "en" else "en"))
+
+        # The indices for the start mark, end mark, and unknown word are the
+        # same in the source and target languages, so the source language
+        # dictionary is used to determine their indices.
+        start_id = src_dict[START_MARK]
+        end_id = src_dict[END_MARK]
+        unk_id = src_dict[UNK_MARK]
+
+        src_col = 0 if src_lang == "en" else 1
+        trg_col = 1 - src_col
+
+        with tarfile.open(tar_file, mode="r") as f:
+            for line in f.extractfile(file_name):
+                line_split = line.strip().split("\t")
+                if len(line_split) != 2:
+                    continue
+                src_words = line_split[src_col].split()
+                src_ids = [start_id] + [
+                    src_dict.get(w, unk_id) for w in src_words
+                ] + [end_id]
+
+                trg_words = line_split[trg_col].split()
+                trg_ids = [trg_dict.get(w, unk_id) for w in trg_words]
+
+                trg_ids_next = trg_ids + [end_id]
+                trg_ids = [start_id] + trg_ids
+
+                yield src_ids, trg_ids, trg_ids_next
+
+    return reader
+
+
+def train(src_dict_size, trg_dict_size, src_lang="en"):
+    """
+    WMT16 train set reader.
+
+    This function returns the reader for the training data. Each sample the
+    reader returns is made up of three fields: the source language word index
+    sequence, the target language word index sequence, and the next word indices.
+
+
+    NOTE:
+    The original link for the training data is:
+    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz
+
+    paddle.dataset.wmt16 provides a tokenized version of the original dataset by
+    using Moses' tokenization script:
+    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl
+
+    Args:
+        src_dict_size(int): Size of the source language dictionary. Three
+                    special tokens will be added into the dictionary:
+                    <s> for start mark, <e> for end mark, and <unk> for
+                    unknown word.
+        trg_dict_size(int): Size of the target language dictionary. Three
+                    special tokens will be added into the dictionary:
+                    <s> for start mark, <e> for end mark, and <unk> for
+                    unknown word.
+        src_lang(string): A string indicating which language is the source
+                    language. Available options are: "en" for English
+                    and "de" for German.
+
+    Returns:
+        callable: The train reader.
+    """
+
+    assert src_lang in ["en", "de"], ("Invalid src_lang. Only en (English) "
+                                      "and de (German) are supported.")
+    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
+                                                     trg_dict_size, src_lang)
+
+    return reader_creator(
+        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
+                                                   "wmt16.tar.gz"),
+        file_name="wmt16/train",
+        src_dict_size=src_dict_size,
+        trg_dict_size=trg_dict_size,
+        src_lang=src_lang)
+
+
+def test(src_dict_size, trg_dict_size, src_lang="en"):
+    """
+    WMT16 test set reader.
+
+    This function returns the reader for the test data. Each sample the reader
+    returns is made up of three fields: the source language word index sequence,
+    the target language word index sequence, and the next word indices.
+
+    NOTE:
+    The original link for the test data is:
+    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/mmt16_task1_test.tar.gz
+
+    paddle.dataset.wmt16 provides a tokenized version of the original dataset by
+    using Moses' tokenization script:
+    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl
+
+    Args:
+        src_dict_size(int): Size of the source language dictionary. Three
+                    special tokens will be added into the dictionary:
+                    <s> for start mark, <e> for end mark, and <unk> for
+                    unknown word.
+        trg_dict_size(int): Size of the target language dictionary. Three
+                    special tokens will be added into the dictionary:
+                    <s> for start mark, <e> for end mark, and <unk> for
+                    unknown word.
+        src_lang(string): A string indicating which language is the source
+                    language. Available options are: "en" for English
+                    and "de" for German.
+
+    Returns:
+        callable: The test reader.
+    """
+
+    assert src_lang in ["en", "de"], (
+        "Invalid src_lang. "
+        "Only en (English) and de (German) are supported.")
+
+    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
+                                                     trg_dict_size, src_lang)
+
+    return reader_creator(
+        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
+                                                   "wmt16.tar.gz"),
+        file_name="wmt16/test",
+        src_dict_size=src_dict_size,
+        trg_dict_size=trg_dict_size,
+        src_lang=src_lang)
+
+
+def validation(src_dict_size, trg_dict_size, src_lang="en"):
+    """
+    WMT16 validation set reader.
+
+    This function returns the reader for the validation data. Each sample the
+    reader returns is made up of three fields: the source language word index
+    sequence, the target language word index sequence, and the next word indices.
+
+    NOTE:
+    The original link for the validation data is:
+    http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz
+
+    paddle.dataset.wmt16 provides a tokenized version of the original dataset by
+    using Moses' tokenization script:
+    https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl
+
+    Args:
+        src_dict_size(int): Size of the source language dictionary. Three
+                    special tokens will be added into the dictionary:
+                    <s> for start mark, <e> for end mark, and <unk> for
+                    unknown word.
+        trg_dict_size(int): Size of the target language dictionary. Three
+                    special tokens will be added into the dictionary:
+                    <s> for start mark, <e> for end mark, and <unk> for
+                    unknown word.
+        src_lang(string): A string indicating which language is the source
+                    language. Available options are: "en" for English
+                    and "de" for German.
+
+    Returns:
+        callable: The validation reader.
+    """
+    assert src_lang in ["en", "de"], (
+        "Invalid src_lang. "
+        "Only en (English) and de (German) are supported.")
+    src_dict_size, trg_dict_size = __get_dict_size__(src_dict_size,
+                                                     trg_dict_size, src_lang)
+
+    return reader_creator(
+        tar_file=paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
+                                                   "wmt16.tar.gz"),
+        file_name="wmt16/val",
+        src_dict_size=src_dict_size,
+        trg_dict_size=trg_dict_size,
+        src_lang=src_lang)
+
+
+def get_dict(lang, dict_size, reverse=False):
+    """
+    Return the word dictionary for the specified language.
+
+    Args:
+        lang(string): A string indicating which language's dictionary to
+                    return. Available options are: "en" for English
+                    and "de" for German.
+        dict_size(int): Size of the specified language dictionary.
+        reverse(bool): If reverse is set to False, the returned python
+                    dictionary will use word as key and use index as value.
+                    If reverse is set to True, the returned python
+                    dictionary will use index as key and word as value.
+
+    Returns:
+        dict: The word dictionary for the specified language.
+    """
+
+    if lang == "en": dict_size = min(dict_size, TOTAL_EN_WORDS)
+    else: dict_size = min(dict_size, TOTAL_DE_WORDS)
+
+    dict_path = os.path.join(paddle.v2.dataset.common.DATA_HOME,
+                             "wmt16/%s_%d.dict" % (lang, dict_size))
+    assert os.path.exists(dict_path), ("Word dictionary does not exist. "
+            "Please invoke paddle.dataset.wmt16.train/test/validation "
+            "first to build the dictionary.")
+    tar_file = os.path.join(paddle.v2.dataset.common.DATA_HOME, "wmt16", "wmt16.tar.gz")
+    return __load_dict__(tar_file, dict_size, lang, reverse)
+
+
+def fetch():
+    """Download the entire dataset.
+    """
+    paddle.v2.dataset.common.download(DATA_URL, "wmt16", DATA_MD5,
+                                      "wmt16.tar.gz")
+
+
+def convert(path, src_dict_size, trg_dict_size, src_lang):
+    """Converts dataset to recordio format.
+ """ + + paddle.v2.dataset.common.convert( + path, + train( + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang), + 1000, + "wmt16_train") + paddle.v2.dataset.common.convert( + path, + test( + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang), + 1000, + "wmt16_test") + paddle.v2.dataset.common.convert( + path, + validation( + src_dict_size=src_dict_size, + trg_dict_size=trg_dict_size, + src_lang=src_lang), + 1000, + "wmt16_validation") diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index e72b22c83f6..b2183ebda1e 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -19,13 +19,32 @@ import contextlib from ..registry import autodoc __all__ = [ - 'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', - 'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While', - 'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array', - 'array_to_lod_tensor', 'increment', 'array_write', 'create_array', - 'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse', - 'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'reorder_lod_tensor_by_rank', - 'ParallelDo', 'Print' + 'split_lod_tensor', + 'merge_lod_tensor', + 'BlockGuard', + 'BlockGuardWithCompletion', + 'StaticRNNMemoryLink', + 'WhileGuard', + 'While', + 'lod_rank_table', + 'max_sequence_len', + 'topk', + 'lod_tensor_to_array', + 'array_to_lod_tensor', + 'increment', + 'array_write', + 'create_array', + 'less_than', + 'array_read', + 'shrink_memory', + 'array_length', + 'IfElse', + 'DynamicRNN', + 'ConditionalBlock', + 'StaticRNN', + 'reorder_lod_tensor_by_rank', + 'ParallelDo', + 'Print', ] -- GitLab