提交 4cbbb23f 编写于 作者: H Helin Wang

expose build_dict in imikolov dataset, fix bug that len(word_dict) is not...

expose build_dict in imikolov dataset, fix bug that len(word_dict) is not bigger than all index in word_dict.
上级 1013983e
......@@ -17,7 +17,7 @@ imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/
import paddle.v2.dataset.common
import tarfile
__all__ = ['train', 'test']
__all__ = ['train', 'test', 'build_dict']
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
......@@ -37,7 +37,9 @@ def word_count(f, word_freq=None):
return word_freq
def build_dict(train_filename, test_filename):
def build_dict():
train_filename = './simple-examples/data/ptb.train.txt'
test_filename = './simple-examples/data/ptb.valid.txt'
with tarfile.open(
paddle.v2.dataset.common.download(
paddle.v2.dataset.imikolov.URL, 'imikolov',
......@@ -45,27 +47,22 @@ def build_dict(train_filename, test_filename):
trainf = tf.extractfile(train_filename)
testf = tf.extractfile(test_filename)
word_freq = word_count(testf, word_count(trainf))
if '<unk>' in word_freq:
# remove <unk> for now, since we will set it as last index
del word_freq['<unk>']
TYPO_FREQ = 50
word_freq = filter(lambda x: x[1] > TYPO_FREQ, word_freq.items())
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(zip(words, xrange(len(words))))
word_idx['<unk>'] = len(words)
return word_idx
word_idx = {}
def reader_creator(filename, n):
global word_idx
if len(word_idx) == 0:
word_idx = build_dict('./simple-examples/data/ptb.train.txt',
'./simple-examples/data/ptb.valid.txt')
def reader_creator(filename, word_idx, n):
def reader():
with tarfile.open(
paddle.v2.dataset.common.download(
......@@ -84,9 +81,9 @@ def reader_creator(filename, n):
return reader
def train(n):
return reader_creator('./simple-examples/data/ptb.train.txt', n)
def train(word_idx, n):
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n)
def test(n):
return reader_creator('./simple-examples/data/ptb.valid.txt', n)
def test(word_idx, n):
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n)
import paddle.v2.dataset.imikolov
import unittest
WORD_DICT = paddle.v2.dataset.imikolov.build_dict()
class TestMikolov(unittest.TestCase):
def check_reader(self, reader, n):
......@@ -9,11 +11,15 @@ class TestMikolov(unittest.TestCase):
def test_train(self):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.train(n), n)
self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)
def test_test(self):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.test(n), n)
self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)
def test_total(self):
_, idx = zip(*WORD_DICT.items())
self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册