wmt14.py 4.0 KB
Newer Older
H
Helin Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
wmt14 dataset
"""
Q
qiaolongfei 已提交
17
import tarfile
L
Luo Tao 已提交
18
import gzip
Q
qiaolongfei 已提交
19

20
from paddle.v2.dataset.common import download
L
Luo Tao 已提交
21
from paddle.v2.parameters import Parameters
H
Helin Wang 已提交
22 23 24 25 26

# Public names exported by ``from ... import *``. Every entry must be defined
# in this module, otherwise the star-import raises AttributeError.
__all__ = ['train', 'test']

# Full dev+test archive; not referenced by the functions in this module.
URL_DEV_TEST = 'http://www-lium.univ-lemans.fr/~schwenk/cslm_joint_paper/data/dev+test.tgz'
MD5_DEV_TEST = '7d7897317ddd8ba0ae5c5fa7248d3ff5'

# A small (shrunk) subset of WMT14 used by both ``train`` and ``test``. The
# original corpus is too large to ship here and may be added later.
URL_TRAIN = 'http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz'
MD5_TRAIN = 'a755315dd01c2c35bde29a744ede23a6'

# Pretrained translation model; its reported BLEU score is 26.92.
URL_MODEL = 'http://paddlepaddle.bj.bcebos.com/demo/wmt_14/wmt14_model.tar.gz'
MD5_MODEL = '6b097d23e15654608c6f74923e975535'

# Special tokens: START is prepended and END appended to sequences by the
# reader; UNK is the out-of-vocabulary token.
START = "<s>"
END = "<e>"
UNK = "<unk>"
# Id used for words missing from a vocabulary. Presumably "<unk>" sits on the
# third line (index 2) of each .dict file -- TODO confirm against the archive.
UNK_IDX = 2

Q
qiaolongfei 已提交
39 40 41

def _to_text(value):
    """Return ``value`` as a native ``str``, decoding ``bytes`` as UTF-8.

    ``TarFile.extractfile`` yields ``bytes`` lines on Python 3, while the
    special tokens (START/END) and the tab separator are ``str``. Decoding
    keeps vocabulary keys and corpus words comparable. On Python 2 the lines
    are already ``str`` and are returned unchanged. The archive text is
    assumed to be UTF-8 -- TODO confirm.
    """
    if isinstance(value, str):
        return value
    return value.decode('utf-8')


def __read_to_dict__(tar_file, dict_size):
    """Load the source and target vocabularies from ``tar_file``.

    The archive must contain exactly one file member whose name ends with
    ``src.dict`` and exactly one ending with ``trg.dict``. The token on the
    i-th line (0-based) of each file gets id ``i``. Only the first
    ``dict_size`` lines of each file are read.

    :param tar_file: path of the tar archive (opened with mode ``'r'``).
    :param dict_size: maximum number of entries read from each dict file.
    :return: ``(src_dict, trg_dict)``, each mapping token -> integer id.
    :raises AssertionError: if a dict file is missing or matched more than once.
    """

    def __to_dict__(fd, size):
        # The line number is the token id; stop once ``size`` ids are assigned.
        out_dict = dict()
        for line_count, line in enumerate(fd):
            if line_count >= size:
                break
            out_dict[_to_text(line).strip()] = line_count
        return out_dict

    def _load(tar, suffix):
        # Locate the single member ending in ``suffix`` and build its dict.
        names = [
            each_item.name for each_item in tar
            if each_item.isfile() and each_item.name.endswith(suffix)
        ]
        assert len(names) == 1, (
            "expected exactly one member ending with %r, found %d" %
            (suffix, len(names)))
        fd = tar.extractfile(names[0])
        try:
            return __to_dict__(fd, dict_size)
        finally:
            fd.close()

    with tarfile.open(tar_file, mode='r') as f:
        src_dict = _load(f, "src.dict")
        trg_dict = _load(f, "trg.dict")
    return src_dict, trg_dict


def reader_creator(tar_file, file_name, dict_size):
    """Create a reader over the parallel corpus stored in ``tar_file``.

    Every file member whose name ends with ``file_name`` is read line by
    line. A usable line holds a source and a target sentence separated by a
    single tab; all other lines are skipped. Words are split on whitespace and
    mapped through the vocabularies from ``__read_to_dict__``. Unknown words
    become ``UNK_IDX``.

    :param tar_file: path of the tar archive.
    :param file_name: member-name suffix selecting the data, e.g. 'train/train'.
    :param dict_size: vocabulary size passed to ``__read_to_dict__``.
    :return: a zero-argument generator function yielding
        ``(src_ids, trg_ids, trg_ids_next)``:

        - ``src_ids`` is the source wrapped in the START/END ids.
        - ``trg_ids`` is the target prefixed with the START id.
        - ``trg_ids_next`` is the target suffixed with the END id.

        Indexing ``trg_dict`` raises KeyError if START or END is absent from
        the target vocabulary.
    """

    def reader():
        # Vocabularies are reloaded on every pass over the data.
        src_dict, trg_dict = __read_to_dict__(tar_file, dict_size)
        with tarfile.open(tar_file, mode='r') as f:
            names = [
                each_item.name for each_item in f
                if each_item.isfile() and each_item.name.endswith(file_name)
            ]
            for name in names:
                fd = f.extractfile(name)
                try:
                    for line in fd:
                        line_split = _to_text(line).strip().split('\t')
                        if len(line_split) != 2:
                            continue
                        src_words = line_split[0].split()  # source sentence
                        src_ids = [
                            src_dict.get(w, UNK_IDX)
                            for w in [START] + src_words + [END]
                        ]

                        trg_words = line_split[1].split()  # target sentence
                        trg_ids = [trg_dict.get(w, UNK_IDX) for w in trg_words]

                        # Drop over-long pairs. This applies to every split,
                        # not only training. src_ids already includes
                        # START/END; trg_ids does not include its markers yet.
                        if len(src_ids) > 80 or len(trg_ids) > 80:
                            continue
                        trg_ids_next = trg_ids + [trg_dict[END]]
                        trg_ids = [trg_dict[START]] + trg_ids

                        yield src_ids, trg_ids, trg_ids_next
                finally:
                    fd.close()

    return reader


def train(dict_size):
    """Return a reader over the ``train/train`` member of the WMT14 archive.

    The archive at URL_TRAIN is obtained through ``download`` (checked
    against MD5_TRAIN). The returned path is handed to ``reader_creator``.

    :param dict_size: number of vocabulary entries loaded per language.
    :return: the zero-argument reader built by ``reader_creator``.
    """
    archive_path = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
    return reader_creator(archive_path, 'train/train', dict_size)
Q
qiaolongfei 已提交
104 105 106 107


def test(dict_size):
    """Return a reader over the ``test/test`` member of the WMT14 archive.

    This uses the same archive as ``train`` (URL_TRAIN / MD5_TRAIN); only the
    member suffix differs.

    :param dict_size: number of vocabulary entries loaded per language.
    :return: the zero-argument reader built by ``reader_creator``.
    """
    archive_path = download(URL_TRAIN, 'wmt14', MD5_TRAIN)
    return reader_creator(archive_path, 'test/test', dict_size)
Y
Yancey1989 已提交
109 110


L
Luo Tao 已提交
111 112 113 114 115 116 117
def model():
    """Load the parameters of the pretrained WMT14 translation model.

    The gzip-compressed archive at URL_MODEL is obtained through
    ``download`` (checked against MD5_MODEL). It is opened with ``gzip``
    and passed to ``Parameters.from_tar``.

    :return: the ``Parameters`` object produced by ``Parameters.from_tar``.
    """
    model_path = download(URL_MODEL, 'wmt14', MD5_MODEL)
    with gzip.open(model_path, 'r') as gz_file:
        return Parameters.from_tar(gz_file)


118 119
def fetch():
    """Download the training archive and then the pretrained model.

    Each ``download`` call is given the matching MD5 constant. Return values
    are discarded.
    """
    for url, md5 in ((URL_TRAIN, MD5_TRAIN), (URL_MODEL, MD5_MODEL)):
        download(url, 'wmt14', md5)