# imikolov.py
"""
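Reader creators for imikolov's simple-examples dataset (the ptb.train.txt /
ptb.valid.txt text files), downloaded from
http://www.fit.vutbr.cz/~imikolov/rnnlm/.

train(n) and test(n) return readers that yield n-gram tuples of word ids.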
"""
import paddle.v2.dataset.common
import tarfile

# Public API of this module: the two reader-creator factories.
__all__ = ['train', 'test']

# Location of the simple-examples tarball and the MD5 checksum handed to
# paddle.v2.dataset.common.download together with it (presumably used to
# verify the cached/downloaded archive -- confirm against common.download).
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


def word_count(f, word_freq=None):
    """Count word occurrences in an iterable of text lines.

    Each line is stripped and whitespace-split into words, and every word
    is counted. In addition, one '<s>' and one '<e>' are counted per line,
    matching the start/end markers reader_creator puts around every line.

    :param f: iterable of text lines (e.g. a file object from the tarball).
    :param word_freq: optional existing word -> count dict to accumulate
        into; a new empty dict is used when omitted.
    :return: the word -> count dict, updated through
        ``paddle.v2.dataset.common.dict_add``.
    """
    add = paddle.v2.dataset.common.dict_add

    if word_freq is None:
        word_freq = {}

    for line in f:
        for w in line.strip().split():
            add(word_freq, w)
        # One sentence-start and one sentence-end marker per line.
        add(word_freq, '<s>')
        add(word_freq, '<e>')

    return word_freq


def build_dict(train_filename, test_filename, min_word_freq=50):
    """Build a word -> integer id dictionary from the imikolov tarball.

    Downloads (via paddle.v2.dataset.common.download) the simple-examples
    tarball, counts words in both ``train_filename`` and ``test_filename``
    (paths of members inside the tarball), and keeps only words whose
    count is strictly greater than ``min_word_freq``.

    Ids start at 0 and are assigned by descending frequency, with ties
    broken alphabetically. The extra token '<unk>' receives the next free
    id, so that out-of-vocabulary words can be mapped to it.

    :param train_filename: path of the training text inside the tarball.
    :param test_filename: path of the validation text inside the tarball.
    :param min_word_freq: words occurring this many times or fewer are
        dropped. Defaults to 50, the previously hard-coded threshold.
    :return: dict mapping each kept word (and '<unk>') to its id.
    """
    with tarfile.open(
            paddle.v2.dataset.common.download(URL, 'imikolov', MD5)) as tf:
        trainf = tf.extractfile(train_filename)
        testf = tf.extractfile(test_filename)
        word_freq = word_count(testf, word_count(trainf))

    # Drop rare words (likely typos); strict '>' as in the original check.
    kept = [(w, c) for w, c in word_freq.items() if c > min_word_freq]
    # Most frequent first; alphabetical order keeps ties deterministic.
    kept.sort(key=lambda wc: (-wc[1], wc[0]))

    # enumerate instead of zip(*kept): works when nothing passes the
    # threshold (zip(*[]) cannot be unpacked) and avoids Py2-only xrange.
    word_idx = {w: i for i, (w, _) in enumerate(kept)}
    word_idx['<unk>'] = len(kept)

    return word_idx


# Vocabulary cache (word -> id) shared by every reader. It starts empty and
# is filled by reader_creator on its first call, using build_dict.
word_idx = {}


def reader_creator(filename, n):
    """Create a reader that yields n-gram tuples of word ids from a file.

    On first use, the shared module-level ``word_idx`` vocabulary is built
    from the ptb train and valid files; later calls reuse it.

    :param filename: path of the text file inside the tarball.
    :param n: n-gram length; must be >= 1.
    :return: a zero-argument generator function. For every line it wraps
        the whitespace-split words in '<s>' ... '<e>'. If the line then has
        at least ``n`` tokens, it yields every length-``n`` sliding window
        as a tuple of ids. Words missing from the vocabulary map to the id
        of '<unk>'.
    :raises ValueError: if ``n`` < 1 (such windows are meaningless and
        previously produced empty or garbage tuples).
    """
    global word_idx
    if n < 1:
        raise ValueError('n must be a positive integer, got %r' % (n, ))
    if not word_idx:
        word_idx = build_dict('./simple-examples/data/ptb.train.txt',
                              './simple-examples/data/ptb.valid.txt')

    def reader():
        with tarfile.open(
                paddle.v2.dataset.common.download(URL, 'imikolov',
                                                  MD5)) as tf:
            f = tf.extractfile(filename)

            UNK = word_idx['<unk>']
            for line in f:
                tokens = ['<s>'] + line.strip().split() + ['<e>']
                if len(tokens) >= n:
                    ids = [word_idx.get(w, UNK) for w in tokens]
                    # Windows ending at positions n..len(ids), inclusive.
                    for i in range(n, len(ids) + 1):
                        yield tuple(ids[i - n:i])

    return reader


def train(n):
    """Return a reader of n-gram word-id tuples over ptb.train.txt.

    :param n: n-gram length, forwarded to reader_creator.
    """
    train_path = './simple-examples/data/ptb.train.txt'
    return reader_creator(train_path, n)


def test(n):
    """Return a reader of n-gram word-id tuples over ptb.valid.txt.

    :param n: n-gram length, forwarded to reader_creator.
    """
    valid_path = './simple-examples/data/ptb.valid.txt'
    return reader_creator(valid_path, n)