imikolov.py 5.4 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
"""
Q
qijun 已提交
15
imikolov's simple dataset.
Y
Yu Yang 已提交
16

M
minqiyang 已提交
17
This module will download dataset from
Q
qijun 已提交
18 19
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
into paddle reader creators.
20
"""
21

22
import collections
23 24
import tarfile

25 26 27
import paddle.dataset.common
import paddle.utils.deprecated as deprecated

28 29
__all__ = []

# The corpus was originally published as
#   http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
# but it is downloaded from the mirror below. MD5 is the expected digest of
# that archive and is passed to paddle.dataset.common.download together
# with URL.
URL = 'https://dataset.bj.bcebos.com/imikolov%2Fsimple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


35
class DataType:
    """Sample layouts the imikolov readers can produce.

    Members are plain ints so they can be compared with ``==``.
    """

    # Each sample is a tuple of ``n`` consecutive word IDs (sliding window).
    NGRAM = 1
    # Each sample is a ``(src_seq, trg_seq)`` pair of word-ID lists.
    SEQ = 2


40
def word_count(f, word_freq=None):
    """
    Accumulate token frequencies over an iterable of text lines.

    Each line is whitespace-split and every token's count is incremented.
    The sentence markers ``'<s>'`` and ``'<e>'`` are each incremented once
    per line. The markers are always ``str`` keys, even when the lines
    (and therefore the tokens) are ``bytes``.

    :param f: iterable of lines (``str`` or ``bytes``).
    :param word_freq: mapping to update in place. It must tolerate ``+= 1``
        on missing keys (e.g. a ``defaultdict(int)``). A fresh
        ``collections.defaultdict(int)`` is created when it is None.
    :return: the updated mapping (the same object as ``word_freq`` when one
        is given).
    """
    counts = collections.defaultdict(int) if word_freq is None else word_freq

    for line in f:
        tokens = line.strip().split()
        for token in tokens:
            counts[token] += 1
        # One start and one end marker per line (sentence).
        for marker in ('<s>', '<e>'):
            counts[marker] += 1

    return counts


Y
yangyaming 已提交
53
def build_dict(min_word_freq=50):
    """
    Build a word dictionary from the corpus. Keys of the dictionary are words,
    and values are zero-based IDs of these words.

    Counts are taken over the training file and the validation file of the
    archive. The counts include the ``'<s>'``/``'<e>'`` markers, which
    word_count adds once per line. Only words whose count is strictly greater
    than ``min_word_freq`` are kept. They are numbered by descending count,
    with ties broken by the word itself. ``'<unk>'`` always receives the last
    ID, ``len(word_idx) - 1``.

    :param min_word_freq: a word is kept only if it occurs more than this
        many times.
    :type min_word_freq: int
    :return: word dictionary mapping each kept word to its ID.
    :rtype: dict
    """
    train_filename = './simple-examples/data/ptb.train.txt'
    test_filename = './simple-examples/data/ptb.valid.txt'
    with tarfile.open(
        paddle.dataset.common.download(URL, 'imikolov', MD5)
    ) as tf:
        trainf = tf.extractfile(train_filename)
        testf = tf.extractfile(test_filename)
        word_freq = word_count(testf, word_count(trainf))

    # Archive members are read in binary mode, so corpus tokens, including
    # the corpus's own '<unk>', are bytes keys. The '<s>'/'<e>' markers are
    # str keys. Remove '<unk>' in both forms: it is re-added below with the
    # last index, and corpus '<unk>' tokens then fall back to that same ID
    # through the readers' word_idx.get(w, UNK) lookup.
    word_freq.pop('<unk>', None)
    word_freq.pop(b'<unk>', None)

    kept = [x for x in word_freq.items() if x[1] > min_word_freq]

    # Descending count, then the word itself. The isinstance component only
    # separates str from bytes keys on a count tie. Those keys are not
    # mutually orderable and would otherwise raise TypeError. Among keys of
    # the same type, the order is exactly (-count, word).
    word_freq_sorted = sorted(
        kept, key=lambda x: (-x[1], isinstance(x[0], str), x[0])
    )
    words = [w for w, _ in word_freq_sorted]

    # An empty `words` (e.g. a very high min_word_freq) yields {'<unk>': 0}
    # instead of failing to unpack an empty zip().
    word_idx = {w: i for i, w in enumerate(words)}
    word_idx['<unk>'] = len(words)

    return word_idx


82
def reader_creator(filename, word_idx, n, data_type):
    """
    Create a reader over one text file of the imikolov archive.

    :param filename: path of the member inside the downloaded archive.
    :param word_idx: word dictionary. It must contain '<unk>', and for SEQ
        data also '<s>' and '<e>'. Tokens missing from it map to the '<unk>'
        ID.
    :param n: for NGRAM, the window length. Lines that are shorter than n
        after adding '<s>'/'<e>' are skipped. For SEQ, the maximum length of
        the source sequence; n <= 0 means no limit.
    :param data_type: DataType.NGRAM or DataType.SEQ.
    :return: a zero-argument callable returning a generator of samples:
        tuples of n word IDs for NGRAM, or (src_seq, trg_seq) lists for SEQ.
    :raises AssertionError: when iterated with an unknown data_type, or with
        n < 0 for NGRAM data.
    """

    def reader():
        # Validate before downloading or reading anything. The checks raise
        # explicitly instead of using `assert` so they still run under
        # `python -O`. AssertionError is kept for compatibility with the
        # previous asserts.
        if data_type == DataType.NGRAM:
            if n <= -1:
                raise AssertionError('Invalid gram length')
        elif data_type != DataType.SEQ:
            raise AssertionError('Unknown data type')

        with tarfile.open(
            paddle.dataset.common.download(URL, 'imikolov', MD5)
        ) as tf:
            f = tf.extractfile(filename)

            unk = word_idx['<unk>']
            # Lines (and thus tokens) are bytes because the member is read in
            # binary mode. The '<s>'/'<e>'/'<unk>' markers are str; a
            # word_idx from build_dict holds both kinds of key.
            if data_type == DataType.NGRAM:
                for line in f:
                    tokens = ['<s>'] + line.strip().split() + ['<e>']
                    if len(tokens) < n:
                        continue
                    ids = [word_idx.get(w, unk) for w in tokens]
                    for i in range(n, len(ids) + 1):
                        yield tuple(ids[i - n : i])
            else:
                start_id = word_idx['<s>']
                end_id = word_idx['<e>']
                for line in f:
                    ids = [word_idx.get(w, unk) for w in line.strip().split()]
                    src_seq = [start_id] + ids
                    trg_seq = ids + [end_id]
                    if n > 0 and len(src_seq) > n:
                        continue
                    yield src_seq, trg_seq

    return reader


116 117 118
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def train(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov training set creator.

    Reads ./simple-examples/data/ptb.train.txt from the downloaded archive.
    With DataType.NGRAM each sample is a tuple of n word IDs. With
    DataType.SEQ each sample is a (src_seq, trg_seq) pair of word-ID lists.

    :param word_idx: word dictionary
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Training reader creator
    :rtype: callable
    """
    train_file = './simple-examples/data/ptb.train.txt'
    return reader_creator(train_file, word_idx, n, data_type)
141 142


143 144 145
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def test(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov test set creator.

    Reads the validation file ./simple-examples/data/ptb.valid.txt from the
    downloaded archive. With DataType.NGRAM each sample is a tuple of n word
    IDs. With DataType.SEQ each sample is a (src_seq, trg_seq) pair of
    word-ID lists.

    :param word_idx: word dictionary
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Test reader creator
    :rtype: callable
    """
    valid_file = './simple-examples/data/ptb.valid.txt'
    return reader_creator(valid_file, word_idx, n, data_type)
Y
Yancey1989 已提交
168 169


170 171 172
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader",
)
def fetch():
    """
    Download the imikolov archive ahead of time.

    Calls paddle.dataset.common.download(URL, "imikolov", MD5). The value
    that call returns is discarded, so fetch() returns None.
    """
    paddle.dataset.common.download(URL, "imikolov", MD5)
    return None