diff --git a/python/paddle/v2/dataset/__init__.py b/python/paddle/v2/dataset/__init__.py index 82f11a7c41149c2231130dc7c2205debb643aa89..6c371d3c9bdee94a91b9a48ff7c4a006c8d7eb21 100644 --- a/python/paddle/v2/dataset/__init__.py +++ b/python/paddle/v2/dataset/__init__.py @@ -20,8 +20,9 @@ import movielens import conll05 import uci_housing import sentiment +import wmt14 __all__ = [ 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' - 'uci_housing' + 'uci_housing', 'wmt14' ] diff --git a/python/paddle/v2/dataset/wmt14.py b/python/paddle/v2/dataset/wmt14.py new file mode 100644 index 0000000000000000000000000000000000000000..9904848b5d3ef95dc331fc0ba1a98f29f8b1dfeb --- /dev/null +++ b/python/paddle/v2/dataset/wmt14.py @@ -0,0 +1,142 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +wmt14 dataset +""" +import paddle.v2.dataset.common +import tarfile +import os.path +import itertools + +__all__ = ['train', 'test', 'build_dict'] + +URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz' +MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5' +URL_TRAIN = 'http://localhost:8000/train.tgz' +MD5_TRAIN = '72de99da2830ea5a3a2c4eb36092bbc7' + + +def word_count(f, word_freq=None): + add = paddle.v2.dataset.common.dict_add + if word_freq == None: + word_freq = {} + + for l in f: + for w in l.strip().split(): + add(word_freq, w) + add(word_freq, '') + add(word_freq, '') + + return word_freq + + +def get_word_dix(word_freq): + TYPO_FREQ = 50 + word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items()) + word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*word_freq_sorted)) + word_idx = dict(zip(words, xrange(len(words)))) + word_idx[''] = len(words) + return word_idx + + +def get_word_freq(train, dev): + word_freq = word_count(train, word_count(dev)) + if '' in word_freq: + # remove for now, since we will set it as last index + del word_freq[''] + return word_freq + + +def build_dict(): + base_dir = './wmt14-data' + train_en_filename = base_dir + '/train/train.en' + train_fr_filename = base_dir + '/train/train.fr' + dev_en_filename = base_dir + '/dev/ntst1213.en' + dev_fr_filename = base_dir + '/dev/ntst1213.fr' + + if not os.path.exists(train_en_filename) or not os.path.exists( + train_fr_filename): + with tarfile.open( + paddle.v2.dataset.common.download(URL_TRAIN, 'wmt14', + MD5_TRAIN)) as tf: + tf.extractall(base_dir) + + if not os.path.exists(dev_en_filename) or not os.path.exists( + dev_fr_filename): + with tarfile.open( + paddle.v2.dataset.common.download(URL_DEV_TEST, 'wmt14', + MD5_DEV_TEST)) as tf: + tf.extractall(base_dir) + + f_en = open(train_en_filename) + f_fr = open(train_fr_filename) + f_en_dev = open(dev_en_filename) + f_fr_dev = open(dev_fr_filename) + + word_freq_en = get_word_freq(f_en, f_en_dev) + word_freq_fr = get_word_freq(f_fr, f_fr_dev) + + f_en.close() + f_fr.close() + f_en_dev.close() + f_fr_dev.close() + + return get_word_dix(word_freq_en), get_word_dix(word_freq_fr) + + +def reader_creator(directory, path_en, path_fr, URL, MD5, dict_en, dict_fr): + def reader(): + if not os.path.exists(path_en) or not os.path.exists(path_fr): + with tarfile.open( + paddle.v2.dataset.common.download(URL, 'wmt14', MD5)) as tf: + tf.extractall(directory) + + f_en = open(path_en) + f_fr = open(path_fr) + UNK_en = dict_en[''] + UNK_fr = dict_fr[''] + + for en, fr in itertools.izip(f_en, f_fr): + src_ids = [dict_en.get(w, UNK_en) for w in en.strip().split()] + tar_ids = [ + dict_fr.get(w, UNK_fr) + for w in [''] + fr.strip().split() + [''] + ] + + # remove sequence whose length > 80 in training mode + if len(src_ids) == 0 or len(tar_ids) <= 1 or len( + src_ids) > 80 or len(tar_ids) > 80: + continue + + yield src_ids, tar_ids[:-1], tar_ids[1:] + + f_en.close() + f_fr.close() + + return reader + + +def train(dict_en, dict_fr): + directory = './wmt14-data' + return reader_creator(directory, directory + '/train/train.en', + directory + '/train/train.fr', URL_TRAIN, MD5_TRAIN, + dict_en, dict_fr) + + +def test(dict_en, dict_fr): + directory = './wmt14-data' + return reader_creator(directory, directory + '/dev/ntst1213.en', + directory + '/dev/ntst1213.fr', URL_DEV_TEST, + MD5_DEV_TEST, dict_en, dict_fr)