wmt14.py 4.8 KB
Newer Older
H
Helin Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
WMT14 dataset.
The original WMT14 dataset is too large, so only a small subset of it is
provided here.

This module downloads the dataset from
http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and
parses the train/test sets into paddle reader creators.

"""
Q
qiaolongfei 已提交
22
import tarfile
L
Luo Tao 已提交
23
import gzip
Q
qiaolongfei 已提交
24

25
from paddle.v2.dataset.common import download
L
Luo Tao 已提交
26
from paddle.v2.parameters import Parameters
H
Helin Wang 已提交
27 28 29 30 31

# Only names actually defined in this module may be listed here; an undefined
# entry makes `from ... import *` fail with AttributeError.
__all__ = ['train', 'test']

URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'
# A small subset of the data, used for testing. The original data is too large
# and will be added later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'
# The pretrained model, whose BLEU score is 26.92.
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '6b097d23e15654608c6f74923e975535'

# Tokens that mark the start and end of a sequence, and the unknown word.
START = "<s>"
END = "<e>"
UNK = "<unk>"
# Id assigned to words missing from a dictionary. Presumably this is the line
# of UNK in the *.dict files (after START and END); confirm against the data.
UNK_IDX = 2

Q
qiaolongfei 已提交
44 45 46

def _to_text(line):
    """Return ``line`` as a native ``str``.

    ``TarFile.extractfile`` yields ``bytes`` lines. On Python 2 those already
    are ``str`` and are returned untouched, which preserves the historical
    behaviour. On Python 3 they are decoded, so they can be split on tabs and
    compared with ``str`` tokens such as ``START``/``END``.

    NOTE(review): assumes the archive text is UTF-8; confirm against the data.
    """
    if isinstance(line, str):
        return line
    return line.decode('utf-8')


def _find_member(tar, suffix):
    """Return the name of the single member of ``tar`` ending with ``suffix``.

    :raises ValueError: if no member, or more than one member, matches.
    """
    names = [member.name for member in tar if member.name.endswith(suffix)]
    if len(names) != 1:
        raise ValueError("expected exactly one archive member ending with %r, "
                         "found %d" % (suffix, len(names)))
    return names[0]


def __read_to_dict__(tar_file, dict_size):
    """Load the source and target dictionaries from the WMT14 tarball.

    Each dictionary file has one word per line, and a word's id is its
    0-based line number. Only the first ``dict_size`` lines are kept. If a
    word appears on several of those lines, the last occurrence wins.

    :param tar_file: path of the WMT14 tarball.
    :param dict_size: maximum number of words kept per dictionary.
    :return: ``(src_dict, trg_dict)``, each mapping word -> id.
    :raises ValueError: if the archive does not contain exactly one member
        ending with ``src.dict`` and exactly one ending with ``trg.dict``.
    """

    def __to_dict__(fd, size):
        out_dict = dict()
        for line_count, line in enumerate(fd):
            # Stop as soon as the limit is hit instead of reading the rest.
            if line_count >= size:
                break
            out_dict[_to_text(line).strip()] = line_count
        return out_dict

    with tarfile.open(tar_file, mode='r') as f:
        src_dict = __to_dict__(
            f.extractfile(_find_member(f, "src.dict")), dict_size)
        trg_dict = __to_dict__(
            f.extractfile(_find_member(f, "trg.dict")), dict_size)
    return src_dict, trg_dict


def reader_creator(tar_file, file_name, dict_size):
    """Build a reader over the parallel corpus ``file_name`` in ``tar_file``.

    Every archive member whose name ends with ``file_name`` is read line by
    line. Each line is expected to be ``source<TAB>target``; lines that do
    not split into exactly two fields are skipped. Words are mapped to ids
    with the dictionaries from ``__read_to_dict__``, and unknown words map to
    ``UNK_IDX``.

    :param tar_file: path of the WMT14 tarball.
    :param file_name: suffix selecting the member(s) to read, e.g.
        ``'train/train'`` or ``'test/test'``.
    :param dict_size: maximum number of words kept per dictionary.
    :return: a reader, i.e. a generator function. Every pass over it re-reads
        the dictionaries and the archive, and yields
        ``(src_ids, trg_ids, trg_ids_next)`` tuples:

        - ``src_ids`` is the source wrapped in ``START``/``END``.
        - ``trg_ids`` is the target prefixed with ``START``.
        - ``trg_ids_next`` is the target followed by ``END``.
    """

    def reader():
        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
        with tarfile.open(tar_file, mode='r') as f:
            names = [
                each_item.name for each_item in f
                if each_item.name.endswith(file_name)
            ]
            for name in names:
                for line in f.extractfile(name):
                    line_split = _to_text(line).strip().split('\t')
                    if len(line_split) != 2:
                        continue
                    src_seq = line_split[0]  # one source sequence
                    src_words = src_seq.split()
                    src_ids = [
                        src_dict.get(w, UNK_IDX)
                        for w in [START] + src_words + [END]
                    ]

                    trg_seq = line_split[1]  # one target sequence
                    trg_words = trg_seq.split()
                    trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]

                    # Drop over-long pairs. This applies to every split, not
                    # only training. src_ids already includes START/END at
                    # this point, while trg_ids does not yet.
                    if len(src_ids) > 80 or len(trg_ids) > 80:
                        continue
                    # Plain indexing (not .get) raises KeyError if the target
                    # dictionary lacks START/END, e.g. for a tiny dict_size.
                    trg_ids_next = trg_ids + [trg_dict[END]]
                    trg_ids = [trg_dict[START]] + trg_ids

                    yield src_ids, trg_ids, trg_ids_next

    return reader


def train(dict_size):
    """
    WMT14 training set creator.

    Fetches the WMT14 archive with ``download(URL_TRAIN, 'wmt14', MD5_TRAIN)``
    and wraps its ``train/train`` split with ``reader_creator``. Each sample
    produced by the reader is a tuple of:

    - the source-language word index sequence,
    - the target-language word index sequence, and
    - the next-word index sequence.

    :param dict_size: maximum number of words kept in each of the source and
        target dictionaries.
    :return: Train reader creator
    :rtype: callable
    """
    archive_path = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
    return reader_creator(archive_path, 'train/train', dict_size)
Q
qiaolongfei 已提交
118 119 120


def test(dict_size):
    """
    WMT14 test set creator.

    Reads the ``test/test`` split of the archive obtained via
    ``download(URL_TRAIN, 'wmt14', MD5_TRAIN)``. The test split ships in the
    same tarball as the training split, which is why ``URL_TRAIN`` is used.
    Each sample produced by the reader is a tuple of:

    - the source-language word index sequence,
    - the target-language word index sequence, and
    - the next-word index sequence.

    :param dict_size: maximum number of words kept in each of the source and
        target dictionaries.
    :return: Test reader creator
    :rtype: callable
    """
    return reader_creator(
        download(URL_TRAIN, 'wmt14', MD5_TRAIN), 'test/test', dict_size)
Y
Yancey1989 已提交
132 133


L
Luo Tao 已提交
134 135 136 137 138 139 140
def model():
    """Load the pretrained WMT14 model parameters.

    Fetches the archive with ``download(URL_MODEL, 'wmt14', MD5_MODEL)``,
    opens it with ``gzip`` and restores the parameters from the
    decompressed stream.

    :return: the object produced by ``Parameters.from_tar``.
    """
    with gzip.open(download(URL_MODEL, 'wmt14', MD5_MODEL), 'r') as stream:
        return Parameters.from_tar(stream)


141 142
def fetch():
    """Download the WMT14 data archive, then the pretrained model archive."""
    for url, md5 in ((URL_TRAIN, MD5_TRAIN), (URL_MODEL, MD5_MODEL)):
        download(url, 'wmt14', md5)