imikolov.py 5.5 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
"""
imikolov's simple dataset.

This module downloads the dataset published at
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parses the training set and the
test set into paddle reader creators.
"""
21

22
import paddle.dataset.common
23
import paddle.utils.deprecated as deprecated
24
import collections
25 26
import tarfile

27 28
# Empty export list: ``from <this module> import *`` brings in no names.
__all__ = []

29
# Previous archive location, kept for reference:
# URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
# Location the archive is downloaded from.
URL = 'https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz'
# Passed to paddle.dataset.common.download together with URL
# (presumably the expected MD5 digest of the archive -- confirm in that helper).
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


34 35 36 37 38
class DataType:
    """Sample formats selectable through the ``data_type`` argument."""

    # Sliding windows of word IDs, emitted as fixed-length tuples.
    NGRAM = 1
    # Whole sentences, emitted as (source, target) word ID sequences.
    SEQ = 2


39
def word_count(f, word_freq=None):
    """
    Count word occurrences in an iterable of lines.

    Every line is stripped and split on whitespace, and each resulting token
    is counted. The sentence markers ``'<s>'`` and ``'<e>'`` are additionally
    counted once per line. Tokens are used as keys exactly as ``split()``
    returns them (no decoding is performed).

    :param f: iterable of lines, e.g. an open file object
    :param word_freq: existing counts to update in place; when omitted, a new
        ``collections.defaultdict(int)`` is created
    :type word_freq: dict or None
    :return: the frequency mapping (the same object as ``word_freq`` when one
        was passed in)
    :rtype: dict
    """
    if word_freq is None:
        word_freq = collections.defaultdict(int)

    for line in f:
        for word in line.strip().split():
            word_freq[word] += 1
        # One sentence-start and one sentence-end marker per line.
        word_freq['<s>'] += 1
        word_freq['<e>'] += 1

    return word_freq


Y
yangyaming 已提交
52
def build_dict(min_word_freq=50):
    """
    Build a word dictionary from the corpus. Keys of the dictionary are words,
    and values are zero-based IDs of these words.

    Frequencies are counted over ``ptb.train.txt`` and ``ptb.valid.txt`` from
    the downloaded archive. Words are ranked by descending frequency, ties
    being broken by the word itself, and ``'<unk>'`` always receives the last
    ID. If no word passes the frequency filter the result is
    ``{'<unk>': 0}``.

    :param min_word_freq: only words whose count is strictly greater than
        this value are kept
    :type min_word_freq: int
    :return: mapping from word to zero-based ID
    :rtype: dict
    """
    train_filename = './simple-examples/data/ptb.train.txt'
    test_filename = './simple-examples/data/ptb.valid.txt'
    with tarfile.open(
        paddle.dataset.common.download(
            paddle.dataset.imikolov.URL, 'imikolov', paddle.dataset.imikolov.MD5
        )
    ) as tf:
        trainf = tf.extractfile(train_filename)
        testf = tf.extractfile(test_filename)
        # The counts from the training file are extended with the second file.
        word_freq = word_count(testf, word_count(trainf))

    # Remove <unk> for now, since it is set as the last index below.
    # NOTE(review): tarfile members are binary file objects and word_count
    # does not decode, so corpus tokens may be bytes while this key is str --
    # confirm that the corpus's own <unk> token is really dropped here.
    word_freq.pop('<unk>', None)

    kept = [item for item in word_freq.items() if item[1] > min_word_freq]
    ranked = sorted(kept, key=lambda item: (-item[1], item[0]))

    # enumerate() also handles an empty ranking, where unpacking the result
    # of zip(*ranked) into two names would raise ValueError.
    word_idx = {word: idx for idx, (word, _) in enumerate(ranked)}
    word_idx['<unk>'] = len(word_idx)

    return word_idx


81
def reader_creator(filename, word_idx, n, data_type):
    """
    Create a reader over one text file inside the dataset archive.

    The returned callable is a generator function. Each time it is iterated
    it obtains the archive through ``paddle.dataset.common.download``, opens
    it, and reads ``filename`` line by line. Words missing from ``word_idx``
    are mapped to ``word_idx['<unk>']``.

    * ``DataType.NGRAM``: every line is wrapped in ``'<s>'`` / ``'<e>'`` and
      each window of ``n`` consecutive word IDs is yielded as a tuple. Lines
      with fewer than ``n`` tokens (markers included) are skipped.
    * ``DataType.SEQ``: each line yields ``(src_seq, trg_seq)``, where
      ``src_seq`` is the ID of ``'<s>'`` followed by the line's IDs and
      ``trg_seq`` is the line's IDs followed by the ID of ``'<e>'``. When
      ``n > 0``, lines whose ``src_seq`` is longer than ``n`` are skipped.

    :param filename: member path inside the archive
    :type filename: str
    :param word_idx: word dictionary; must contain ``'<unk>'`` (and
        ``'<s>'`` / ``'<e>'`` for ``DataType.SEQ``)
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of
        sequence
    :type n: int
    :param data_type: ``DataType.NGRAM`` or ``DataType.SEQ``
    :return: reader creator
    :rtype: callable
    :raises AssertionError: while iterating, if ``n`` is negative in ngram
        mode or ``data_type`` is not a known value
    """

    def reader():
        archive_path = paddle.dataset.common.download(
            paddle.dataset.imikolov.URL,
            'imikolov',
            paddle.dataset.imikolov.MD5,
        )
        with tarfile.open(archive_path) as archive:
            corpus = archive.extractfile(filename)

            unk_id = word_idx['<unk>']
            for line in corpus:
                if DataType.NGRAM == data_type:
                    assert n > -1, 'Invalid gram length'
                    tokens = ['<s>'] + line.strip().split() + ['<e>']
                    if len(tokens) < n:
                        # Too short to hold a single window of size n.
                        continue
                    ids = [word_idx.get(tok, unk_id) for tok in tokens]
                    for end in range(n, len(ids) + 1):
                        yield tuple(ids[end - n : end])
                elif DataType.SEQ == data_type:
                    ids = [
                        word_idx.get(tok, unk_id)
                        for tok in line.strip().split()
                    ]
                    src_seq = [word_idx['<s>']] + ids
                    trg_seq = ids + [word_idx['<e>']]
                    if n > 0 and len(src_seq) > n:
                        # Longer than the requested maximum length.
                        continue
                    yield src_seq, trg_seq
                else:
                    assert False, 'Unknow data type'

    return reader


115 116 117
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def train(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov training set creator.

    It returns a reader creator over ``ptb.train.txt``. With
    ``DataType.NGRAM`` each sample in the reader is a tuple of ``n`` word
    IDs; with ``DataType.SEQ`` each sample is a ``(src_seq, trg_seq)`` pair
    of word ID lists.

    :param word_idx: word dictionary
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Training reader creator
    :rtype: callable
    """
    return reader_creator(
        './simple-examples/data/ptb.train.txt', word_idx, n, data_type
    )
140 141


142 143 144
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def test(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov test set creator.

    It returns a reader creator over ``ptb.valid.txt`` (the file this module
    uses as its test split). With ``DataType.NGRAM`` each sample in the
    reader is a tuple of ``n`` word IDs; with ``DataType.SEQ`` each sample is
    a ``(src_seq, trg_seq)`` pair of word ID lists.

    :param word_idx: word dictionary
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Test reader creator
    :rtype: callable
    """
    return reader_creator(
        './simple-examples/data/ptb.valid.txt', word_idx, n, data_type
    )
Y
Yancey1989 已提交
167 168


169 170 171
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def fetch():
    """
    Download the imikolov archive.

    Calls ``paddle.dataset.common.download`` with this module's ``URL`` and
    ``MD5`` under the ``"imikolov"`` name; the helper's return value is
    discarded.
    """
    paddle.dataset.common.download(URL, "imikolov", MD5)