imikolov.py 5.6 KB
Newer Older
D
dangqingqing 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
"""
Q
qijun 已提交
15
imikolov's simple dataset.
Y
Yu Yang 已提交
16

M
minqiyang 已提交
17
This module will download dataset from
Q
qijun 已提交
18 19
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
into paddle reader creators.
20
"""
21 22 23

from __future__ import print_function

24
import paddle.dataset.common
25
import paddle.utils.deprecated as deprecated
26
import collections
27
import tarfile
28
import six
29

30 31
__all__ = []

# Mirror of the original archive, which was published at
# http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
URL = ('https://dataset.bj.bcebos.com/'
       'imikolov%2Fsimple-examples.tgz')
# Expected MD5 of the archive; passed to paddle.dataset.common.download.
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


37 38 39 40 41
class DataType(object):
    """Sample formats understood by reader_creator().

    NGRAM: fixed-length tuples of word IDs taken with a sliding window.
    SEQ:   (source, target) pairs of word-ID lists.
    """
    NGRAM, SEQ = 1, 2


42
def word_count(f, word_freq=None):
    """
    Count token occurrences in an iterable of lines.

    Each line is whitespace-split into tokens. The sentence markers '<s>'
    and '<e>' are also each counted once per line.

    :param f: iterable of lines (str or bytes)
    :param word_freq: optional counter to update in place. It must accept
                      ``counter[key] += 1`` for unseen keys, e.g. a
                      ``collections.defaultdict(int)``. A fresh
                      ``defaultdict(int)`` is created when omitted.
    :return: the updated counter (the same object as ``word_freq`` if given)
    """
    counts = collections.defaultdict(int) if word_freq is None else word_freq
    for line in f:
        tokens = line.strip().split()
        for token in tokens:
            counts[token] += 1
        for marker in ('<s>', '<e>'):
            counts[marker] += 1
    return counts


Y
yangyaming 已提交
55
def build_dict(min_word_freq=50):
    """
    Build a word dictionary from the imikolov corpus.

    Frequencies are counted over both ptb.train.txt and ptb.valid.txt from
    the downloaded archive. Keys of the returned dictionary are words and
    values are zero-based IDs. IDs are assigned in order of decreasing
    frequency, with ties broken by the word itself. '<unk>' is always present
    and always receives the last ID.

    :param min_word_freq: a word is kept only if it occurs strictly more
                          than this many times
    :type min_word_freq: int
    :return: mapping from word to ID
    :rtype: dict
    """
    train_filename = './simple-examples/data/ptb.train.txt'
    test_filename = './simple-examples/data/ptb.valid.txt'
    with tarfile.open(paddle.dataset.common.download(URL, 'imikolov',
                                                     MD5)) as tf:
        trainf = tf.extractfile(train_filename)
        testf = tf.extractfile(test_filename)
        word_freq = word_count(testf, word_count(trainf))

    # Remove <unk> for now, since it is re-added below with the last index.
    # Under Python 3, tarfile.extractfile yields bytes, so the corpus token
    # is b'<unk>'. Removing only the str key would leave a second '<unk>'
    # entry in the vocabulary. Under Python 2 both keys are the same.
    for unk in ('<unk>', b'<unk>'):
        word_freq.pop(unk, None)

    kept = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq]
    kept = sorted(kept, key=lambda x: (-x[1], x[0]))

    # enumerate() (rather than zip(*kept) unpacking) also handles the case
    # where no word survives the cut-off.
    word_idx = {word: idx for idx, (word, _) in enumerate(kept)}
    word_idx['<unk>'] = len(word_idx)
    return word_idx


85
def reader_creator(filename, word_idx, n, data_type):
    """
    Create a reader over one text file inside the imikolov archive.

    :param filename: path of the text file inside the downloaded tarball
    :param word_idx: word -> ID mapping. It must contain '<unk>', and for
                     DataType.SEQ also '<s>' and '<e>'.
    :param n: for DataType.NGRAM, the window size (must be positive). For
              DataType.SEQ, the maximum source-sequence length; lines whose
              source sequence is longer are skipped, and n <= 0 disables
              the limit.
    :param data_type: DataType.NGRAM or DataType.SEQ
    :return: a no-argument callable that returns a generator of samples:
             tuples of n word IDs (NGRAM) or (src_seq, trg_seq) lists (SEQ)
    :raises ValueError: (when iterated) on a non-positive NGRAM window or an
                        unknown data_type
    """

    def reader():
        # Validate once, before downloading. Explicit raises are used
        # because assert statements are stripped under ``python -O``.
        if data_type == DataType.NGRAM:
            if n < 1:
                raise ValueError('Invalid gram length: %r' % (n, ))
        elif data_type != DataType.SEQ:
            raise ValueError('Unknown data type: %r' % (data_type, ))

        with tarfile.open(
                paddle.dataset.common.download(URL, 'imikolov', MD5)) as tf:
            f = tf.extractfile(filename)

            UNK = word_idx['<unk>']
            for line in f:
                # Under Python 3 the tokens are bytes while the markers are
                # str. This matches the keys produced by word_count().
                if data_type == DataType.NGRAM:
                    tokens = ['<s>'] + line.strip().split() + ['<e>']
                    if len(tokens) < n:
                        continue
                    ids = [word_idx.get(w, UNK) for w in tokens]
                    for i in six.moves.range(n, len(ids) + 1):
                        yield tuple(ids[i - n:i])
                else:
                    ids = [word_idx.get(w, UNK) for w in line.strip().split()]
                    src_seq = [word_idx['<s>']] + ids
                    trg_seq = ids + [word_idx['<e>']]
                    if n > 0 and len(src_seq) > n:
                        continue
                    yield src_seq, trg_seq

    return reader


116 117 118
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def train(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov training set creator.

    Builds a reader over ptb.train.txt. With DataType.NGRAM each sample is
    a tuple of n word IDs. With DataType.SEQ each sample is a
    (source, target) pair of word-ID lists.

    :param word_idx: word dictionary, e.g. as returned by build_dict()
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Training reader creator
    :rtype: callable
    """
    train_file = './simple-examples/data/ptb.train.txt'
    return reader_creator(train_file, word_idx, n, data_type)
139 140


141 142 143
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def test(word_idx, n, data_type=DataType.NGRAM):
    """
    imikolov test set creator.

    Builds a reader over ptb.valid.txt. With DataType.NGRAM each sample is
    a tuple of n word IDs. With DataType.SEQ each sample is a
    (source, target) pair of word-ID lists.

    :param word_idx: word dictionary, e.g. as returned by build_dict()
    :type word_idx: dict
    :param n: sliding window size if type is ngram, otherwise max length of sequence
    :type n: int
    :param data_type: data type (ngram or sequence)
    :type data_type: member variable of DataType (NGRAM or SEQ)
    :return: Test reader creator
    :rtype: callable
    """
    valid_file = './simple-examples/data/ptb.valid.txt'
    return reader_creator(valid_file, word_idx, n, data_type)
Y
Yancey1989 已提交
164 165


166 167 168
@deprecated(
    since="2.0.0",
    update_to="paddle.text.datasets.Imikolov",
    level=1,
    reason="Please use new dataset API which supports paddle.io.DataLoader")
def fetch():
    """Download the imikolov archive via paddle.dataset.common.download."""
    archive_url, checksum = URL, MD5
    paddle.dataset.common.download(archive_url, "imikolov", checksum)