# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
WMT14 dataset.
The original WMT14 dataset is too large, so a small subset of it is provided.

This module will download the dataset from
http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and
parse the train/test sets into paddle reader creators.

"""
import tarfile

from paddle.v2.dataset.common import download

# Public API. 'build_dict' used to be listed here although this module never
# defined it, which made `from ... import *` fail with AttributeError.
__all__ = ['train', 'test', 'fetch']

# NOTE(review): the dev/test archive below is not referenced by the readers
# in this module; kept for reference.
URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# This is a small subset of the data for testing. The original data is too
# large and will be added later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'

# Sequence-start, sequence-end and unknown-word tokens.
START = "<s>"
END = "<e>"
UNK = "<unk>"
# Id substituted for any word missing from the (truncated) dictionaries;
# presumably the position of UNK in the dictionary files -- verify.
UNK_IDX = 2


def __read_to_dict__(tar_file, dict_size):
    """
    Load the source and target dictionaries stored in a WMT14 tarball.

    :param tar_file: path to the WMT14 ``.tgz`` archive.
    :param dict_size: only the first ``dict_size`` lines of each dictionary
        file are kept.
    :return: ``(src_dict, trg_dict)``; each maps a stripped dictionary line
        to its 0-based line number, which serves as the word id.
    :raises ValueError: if the archive does not hold exactly one member whose
        name ends in ``src.dict`` (resp. ``trg.dict``), or if that member is
        not a regular file.
    """

    def __to_dict__(fd, size):
        # The line number doubles as the word id, so keeping only the first
        # `size` lines bounds the vocabulary.
        out_dict = dict()
        for line_count, line in enumerate(fd):
            if line_count >= size:
                break
            out_dict[line.strip()] = line_count
        return out_dict

    def __find_member__(member_names, suffix):
        # Raise instead of `assert`: asserts are stripped under `python -O`,
        # which would silently pick the wrong member or hit an IndexError.
        matches = [name for name in member_names if name.endswith(suffix)]
        if len(matches) != 1:
            raise ValueError("expected exactly one member ending with %r in "
                             "%s, found %d" % (suffix, tar_file, len(matches)))
        return matches[0]

    def __load_dict__(tar, suffix):
        name = __find_member__(tar.getnames(), suffix)
        fd = tar.extractfile(name)
        if fd is None:  # extractfile returns None for non-regular members
            raise ValueError("%s in %s is not a regular file" %
                             (name, tar_file))
        try:
            return __to_dict__(fd, dict_size)
        finally:
            fd.close()  # release the extracted member promptly

    with tarfile.open(tar_file, mode='r') as f:
        src_dict = __load_dict__(f, "src.dict")
        trg_dict = __load_dict__(f, "trg.dict")
        return src_dict, trg_dict


def reader_creator(tar_file, file_name, dict_size, max_len=80):
    """
    Build a reader over the parallel-corpus members of a WMT14 tarball.

    :param tar_file: path to the WMT14 ``.tgz`` archive.
    :param file_name: every archive member whose name ends with this suffix
        (e.g. ``'train/train'``) is read.
    :param dict_size: number of entries kept in both dictionaries.
    :param max_len: sentence pairs whose source ids (including the start and
        end markers) or target ids (excluding them) are longer than this are
        skipped. Defaults to 80, the previously hard-coded limit.
    :return: a zero-argument callable; each call re-reads the dictionaries
        and the archive and yields ``(src_ids, trg_ids, trg_ids_next)``,
        where ``trg_ids`` is the target prefixed with the start id and
        ``trg_ids_next`` is the target followed by the end id.
    """

    def reader():
        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
        with tarfile.open(tar_file, mode='r') as f:
            names = [
                each_item.name for each_item in f
                if each_item.name.endswith(file_name)
            ]
            for name in names:
                fd = f.extractfile(name)
                if fd is None:  # non-regular member (e.g. a directory)
                    continue
                try:
                    for line in fd:
                        # Each line is "<source>\t<target>"; anything else
                        # is skipped.
                        # NOTE(review): written for Python 2 str lines; under
                        # Python 3 these are bytes and would need decoding.
                        line_split = line.strip().split('\t')
                        if len(line_split) != 2:
                            continue
                        src_seq, trg_seq = line_split
                        src_ids = [
                            src_dict.get(w, UNK_IDX)
                            for w in [START] + src_seq.split() + [END]
                        ]
                        trg_ids = [
                            trg_dict.get(w, UNK_IDX) for w in trg_seq.split()
                        ]

                        # Drop over-long pairs; this filter applies to every
                        # split, not only to training.
                        if len(src_ids) > max_len or len(trg_ids) > max_len:
                            continue
                        trg_ids_next = trg_ids + [trg_dict[END]]
                        trg_ids = [trg_dict[START]] + trg_ids

                        yield src_ids, trg_ids, trg_ids_next
                finally:
                    fd.close()  # release the extracted member

    return reader


def train(dict_size):
    """
    Create the WMT14 training-set reader.

    The archive path comes from ``download(URL_TRAIN, 'wmt14', MD5_TRAIN)``
    and the ``train/train`` members are parsed. Every sample produced by the
    reader is a tuple of source-language word ids, target-language word ids
    and next-word ids.

    :param dict_size: number of entries kept in both the source and the
        target dictionaries.
    :return: Train reader creator
    :rtype: callable
    """
    archive_path = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
    return reader_creator(archive_path, 'train/train', dict_size)


def test(dict_size):
    """
    WMT14 test set creator.

    It returns a reader creator; each sample in the reader is a source
    language word index sequence, a target language word index sequence and
    a next word index sequence.

    :param dict_size: number of entries kept in both the source and the
        target dictionaries.
    :return: Test reader creator
    :rtype: callable
    """
    # The test split is read from the same (shrunk) archive as the training
    # split, under its test/test members.
    return reader_creator(
        download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)


def fetch():
    """
    Download the WMT14 archive without building a reader.

    The value returned by ``download`` is discarded, so this returns None.
    """
    download_args = (URL_TRAIN, 'wmt14', MD5_TRAIN)
    download(*download_args)