diff --git a/python/paddle/v2/dataset/imikolov.py b/python/paddle/v2/dataset/imikolov.py index bf88fe15570682b544bbac802cd65545f756fc3e..dd3a4552d2e1a2b00dde5ddb7ac1d78445bdca51 100644 --- a/python/paddle/v2/dataset/imikolov.py +++ b/python/paddle/v2/dataset/imikolov.py @@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' MD5 = '30177ea32e27c525793142b6bf2c8e2d' +class DataType(object): + NGRAM = 1 + SEQ = 2 + + def word_count(f, word_freq=None): if word_freq is None: word_freq = collections.defaultdict(int) @@ -41,7 +46,7 @@ def word_count(f, word_freq=None): return word_freq -def build_dict(typo_freq=50): +def build_dict(min_word_freq=50): """ Build a word dictionary from the corpus, Keys of the dictionary are words, and values are zero-based IDs of these words. @@ -59,7 +64,7 @@ def build_dict(typo_freq=50): # remove for now, since we will set it as last index del word_freq[''] - word_freq = filter(lambda x: x[1] > typo_freq, word_freq.items()) + word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items()) word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*word_freq_sorted)) @@ -69,7 +74,7 @@ def build_dict(typo_freq=50): return word_idx -def reader_creator(filename, word_idx, n): +def reader_creator(filename, word_idx, n, data_type): def reader(): with tarfile.open( paddle.v2.dataset.common.download( @@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n): UNK = word_idx[''] for l in f: - l = [''] + l.strip().split() + [''] - if len(l) >= n: + if DataType.NGRAM == data_type: + assert n > -1, 'Invalid gram length' + l = [''] + l.strip().split() + [''] + if len(l) >= n: + l = [word_idx.get(w, UNK) for w in l] + for i in range(n, len(l) + 1): + yield tuple(l[i - n:i]) + elif DataType.SEQ == data_type: + l = l.strip().split() l = [word_idx.get(w, UNK) for w in l] - for i in range(n, len(l) + 1): - yield tuple(l[i - n:i]) + src_seq = [word_idx['']] + l + trg_seq = l + [word_idx['']] + if n > 0 and len(src_seq) > n: continue + yield src_seq, trg_seq + else: + assert False, 'Unknow data type' return reader -def train(word_idx, n): +def train(word_idx, n, data_type=DataType.NGRAM): """ imikolov training set creator. @@ -97,15 +113,18 @@ def train(word_idx, n): :param word_idx: word dictionary :type word_idx: dict - :param n: sliding window size + :param n: sliding window size if type is ngram, otherwise max length of sequence :type n: int + :param data_type: data type (ngram or sequence) + :type data_type: member variable of DataType (NGRAM or SEQ) :return: Training reader creator :rtype: callable """ - return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n) + return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n, + data_type) -def test(word_idx, n): +def test(word_idx, n, data_type=DataType.NGRAM): """ imikolov test set creator. @@ -114,12 +133,15 @@ def test(word_idx, n): :param word_idx: word dictionary :type word_idx: dict - :param n: sliding window size + :param n: sliding window size if type is ngram, otherwise max length of sequence :type n: int + :param data_type: data type (ngram or sequence) + :type data_type: member variable of DataType (NGRAM or SEQ) :return: Test reader creator :rtype: callable """ - return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) + return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n, + data_type) def fetch(): diff --git a/python/paddle/v2/dataset/tests/imikolov_test.py b/python/paddle/v2/dataset/tests/imikolov_test.py index 009e55243a594e5e235c36fb0223ec70754d17f3..4e52810e6b924e0796e3d836dbbcb27ede2c9e25 100644 --- a/python/paddle/v2/dataset/tests/imikolov_test.py +++ b/python/paddle/v2/dataset/tests/imikolov_test.py @@ -13,10 +13,37 @@ class TestMikolov(unittest.TestCase): n = 5 self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n) + first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\ + 'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\ + 'rake regatta rubens sim snack-food ssangyong swapo wachter' + first_line = [ + WORD_DICT.get(ch, WORD_DICT['']) + for ch in first_line.split(' ') + ] + for l in paddle.v2.dataset.imikolov.train( + WORD_DICT, n=-1, + data_type=paddle.v2.dataset.imikolov.DataType.SEQ)(): + read_line = l[0][1:] + break + self.assertEqual(first_line, read_line) + def test_test(self): n = 5 self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n) + first_line = 'consumers may want to move their telephones a little '\ + 'closer to the tv set' + first_line = [ + WORD_DICT.get(ch, WORD_DICT['']) + for ch in first_line.split(' ') + ] + for l in paddle.v2.dataset.imikolov.test( + WORD_DICT, n=-1, + data_type=paddle.v2.dataset.imikolov.DataType.SEQ)(): + read_line = l[0][1:] + break + self.assertEqual(first_line, read_line) + def test_total(self): _, idx = zip(*WORD_DICT.items()) self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)