diff --git a/python/paddle/v2/dataset/imikolov.py b/python/paddle/v2/dataset/imikolov.py new file mode 100644 index 0000000000000000000000000000000000000000..b3791ddad66e588356338150fccadbcc8fa113ca --- /dev/null +++ b/python/paddle/v2/dataset/imikolov.py @@ -0,0 +1,79 @@ +""" +imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/ +""" +import paddle.v2.dataset.common +import tarfile + +__all__ = ['train', 'test'] + +URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' +MD5 = '30177ea32e27c525793142b6bf2c8e2d' + + +def word_count(f, word_freq=None): + add = paddle.v2.dataset.common.dict_add + if word_freq == None: + word_freq = {} + + for l in f: + for w in l.strip().split(): + add(word_freq, w) + add(word_freq, '') + add(word_freq, '') + + return word_freq + + +def build_dict(train_filename, test_filename): + with tarfile.open( + paddle.v2.dataset.common.download( + paddle.v2.dataset.imikolov.URL, 'imikolov', + paddle.v2.dataset.imikolov.MD5)) as tf: + trainf = tf.extractfile(train_filename) + testf = tf.extractfile(test_filename) + word_freq = word_count(testf, word_count(trainf)) + + TYPO_FREQ = 50 + word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items()) + + dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*dictionary)) + word_idx = dict(zip(words, xrange(len(words)))) + word_idx[''] = len(words) + + return word_idx + + +word_idx = {} + + +def reader_creator(filename, n): + global word_idx + if len(word_idx) == 0: + word_idx = build_dict('./simple-examples/data/ptb.train.txt', + './simple-examples/data/ptb.valid.txt') + + def reader(): + with tarfile.open( + paddle.v2.dataset.common.download( + paddle.v2.dataset.imikolov.URL, 'imikolov', + paddle.v2.dataset.imikolov.MD5)) as tf: + f = tf.extractfile(filename) + + UNK = word_idx[''] + for l in f: + l = [''] + l.strip().split() + [''] + if len(l) >= n: + l = [word_idx.get(w, UNK) for w in l] + for i in range(n, len(l) + 1): + yield tuple(l[i - n:i]) + + return reader + + +def train(n): + return reader_creator('./simple-examples/data/ptb.train.txt', n) + + +def test(n): + return reader_creator('./simple-examples/data/ptb.valid.txt', n) diff --git a/python/paddle/v2/dataset/tests/imikolov_test.py b/python/paddle/v2/dataset/tests/imikolov_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9b1748eaaa7f913a6b94f2087a8089fb998570aa --- /dev/null +++ b/python/paddle/v2/dataset/tests/imikolov_test.py @@ -0,0 +1,20 @@ +import paddle.v2.dataset.imikolov +import unittest + + +class TestMikolov(unittest.TestCase): + def check_reader(self, reader, n): + for l in reader(): + self.assertEqual(len(l), n) + + def test_train(self): + n = 5 + self.check_reader(paddle.v2.dataset.imikolov.train(n), n) + + def test_test(self): + n = 5 + self.check_reader(paddle.v2.dataset.imikolov.test(n), n) + + +if __name__ == '__main__': + unittest.main()