# imikolov.py
"""
imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/
"""
import paddle.v2.dataset.common
import tarfile
import collections

# Public API: the two reader factories defined at the bottom of this module.
__all__ = ['train', 'test']

# Archive location handed to paddle.v2.dataset.common.download, and the MD5
# string passed along with it (presumably the expected checksum of the
# archive -- confirm against common.download).
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'


def add(a_dict, ele):
    """
    Increase the count stored for ``ele`` in ``a_dict`` by one.

    A key that is not present yet is created with a count of 1. The dict is
    modified in place and nothing is returned.

    :param a_dict: dict mapping element -> count.
    :param ele: hashable element to tally.
    """
    a_dict[ele] = a_dict.get(ele, 0) + 1


def word_count(f, word_freq=None):
    """
    Count whitespace-separated tokens in an iterable of lines.

    Besides its own tokens, every line (blank ones included) contributes one
    '<s>' and one '<e>' count, so the two markers end up tallied once per
    line.

    :param f: iterable yielding lines, e.g. an open file object.
    :param word_freq: optional dict mapping token -> count to keep
        accumulating into; a fresh dict is created only when it is None.
    :return: the dict of counts (the very object passed in, when one was).
    """
    # Identity test on purpose: only a missing argument gets a fresh dict,
    # never an object that merely compares equal to None.
    if word_freq is None:
        word_freq = {}

    for line in f:
        for word in line.strip().split():
            word_freq[word] = word_freq.get(word, 0) + 1
        # Sentence start/end markers: one of each per line.
        word_freq['<s>'] = word_freq.get('<s>', 0) + 1
        word_freq['<e>'] = word_freq.get('<e>', 0) + 1

    return word_freq


def build_dict(train_filename, test_filename):
    """
    Build the word -> index vocabulary from two members of the dataset archive.

    Token counts from both members are pooled with word_count. Only words
    whose count lies strictly between TYPO_FREQ and STOPWORD_FREQ are kept,
    and the survivors are numbered from 0 by descending count, ties broken by
    the word itself. The extra key '<any>' receives the next free index (0
    when no word survives the filter).

    NOTE(review): '<s>' and '<e>' are counted once per line by word_count and
    go through the same frequency filter as ordinary words, so they can be
    dropped from the vocabulary -- confirm this is intended.

    :param train_filename: archive member name of the first text to count.
    :param test_filename: archive member name of the second text to count.
    :return: dict mapping word -> integer index, including '<any>'.
    """
    STOPWORD_FREQ = 3000
    TYPO_FREQ = 50

    # URL and MD5 are this module's own globals; reading them directly avoids
    # depending on the module being reachable as paddle.v2.dataset.imikolov.
    # The download result is handed straight to tarfile.open (presumably a
    # local path to the archive).
    archive = paddle.v2.dataset.common.download(URL, 'imikolov', MD5)
    with tarfile.open(archive) as tf:
        trainf = tf.extractfile(train_filename)
        testf = tf.extractfile(test_filename)
        # Both members must be consumed while the archive is still open.
        word_freq = word_count(testf, word_count(trainf))

    kept = [(word, freq) for word, freq in word_freq.items()
            if TYPO_FREQ < freq < STOPWORD_FREQ]
    kept.sort(key=lambda item: (-item[1], item[0]))

    # enumerate replaces zip(words, xrange(...)): no Python-2-only builtin and
    # no tuple unpacking that fails when `kept` is empty.
    word_idx = {word: idx for idx, (word, _) in enumerate(kept)}
    word_idx['<any>'] = len(kept)

    return word_idx


# Vocabulary cache (word -> index). Starts empty; reader_creator fills it via
# build_dict the first time it is called and reuses it afterwards.
word_idx = {}


def reader_creator(filename, n):
    """
    Create a reader yielding n-grams of word indices from one archive member.

    The module-level ``word_idx`` vocabulary is built on the first call, from
    the ptb.train.txt and ptb.valid.txt members, and reused by later calls.

    :param filename: archive member name whose lines are read.
    :param n: window length. Each line is wrapped in '<s>' ... '<e>'; lines
        that then hold fewer than n tokens are skipped.
    :return: a no-argument generator function. For n >= 1 every yielded item
        is a list of n integer indices, one item per window position.
    """
    global word_idx
    if not word_idx:
        word_idx = build_dict('./simple-examples/data/ptb.train.txt',
                              './simple-examples/data/ptb.valid.txt')

    def reader():
        # URL and MD5 are this module's own globals; reading them directly
        # avoids depending on the module being reachable as
        # paddle.v2.dataset.imikolov.
        archive = paddle.v2.dataset.common.download(URL, 'imikolov', MD5)
        with tarfile.open(archive) as tf:
            f = tf.extractfile(filename)

            # Index used for every token missing from the vocabulary.
            ANY = word_idx['<any>']
            for line in f:
                tokens = ['<s>'] + line.strip().split() + ['<e>']
                if len(tokens) < n:
                    continue
                ids = [word_idx.get(w, ANY) for w in tokens]
                # Slide a window of width n across the line, one step at a
                # time.
                for end in range(n, len(ids) + 1):
                    yield ids[end - n:end]

    return reader


def train(n):
    """
    Reader creator over the ptb.train.txt member of the archive.

    :param n: n-gram length forwarded to reader_creator.
    :return: the reader function produced by reader_creator.
    """
    member = './simple-examples/data/ptb.train.txt'
    return reader_creator(member, n)


def test(n):
    """
    Reader creator over the ptb.valid.txt member of the archive.

    Note that this reads the file named "valid", not a separate test split.

    :param n: n-gram length forwarded to reader_creator.
    :return: the reader function produced by reader_creator.
    """
    member = './simple-examples/data/ptb.valid.txt'
    return reader_creator(member, n)