From 286696aa2b300b63d460f117c834500e87ca407a Mon Sep 17 00:00:00 2001 From: yangyaming Date: Sat, 6 May 2017 17:33:12 +0800 Subject: [PATCH] extend imikolov instead of adding ptb --- python/paddle/v2/dataset/imikolov.py | 44 +++-- python/paddle/v2/dataset/ptb.py | 169 ------------------ .../paddle/v2/dataset/tests/imikolov_test.py | 27 +++ python/paddle/v2/dataset/tests/ptb_test.py | 53 ------ 4 files changed, 60 insertions(+), 233 deletions(-) delete mode 100644 python/paddle/v2/dataset/ptb.py delete mode 100644 python/paddle/v2/dataset/tests/ptb_test.py diff --git a/python/paddle/v2/dataset/imikolov.py b/python/paddle/v2/dataset/imikolov.py index 97b9f7d915..dd3a4552d2 100644 --- a/python/paddle/v2/dataset/imikolov.py +++ b/python/paddle/v2/dataset/imikolov.py @@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' MD5 = '30177ea32e27c525793142b6bf2c8e2d' +class DataType(object): + NGRAM = 1 + SEQ = 2 + + def word_count(f, word_freq=None): if word_freq is None: word_freq = collections.defaultdict(int) @@ -69,7 +74,7 @@ def build_dict(min_word_freq=50): return word_idx -def reader_creator(filename, word_idx, n): +def reader_creator(filename, word_idx, n, data_type): def reader(): with tarfile.open( paddle.v2.dataset.common.download( @@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n): UNK = word_idx[''] for l in f: - l = [''] + l.strip().split() + [''] - if len(l) >= n: + if DataType.NGRAM == data_type: + assert n > -1, 'Invalid gram length' + l = [''] + l.strip().split() + [''] + if len(l) >= n: + l = [word_idx.get(w, UNK) for w in l] + for i in range(n, len(l) + 1): + yield tuple(l[i - n:i]) + elif DataType.SEQ == data_type: + l = l.strip().split() l = [word_idx.get(w, UNK) for w in l] - for i in range(n, len(l) + 1): - yield tuple(l[i - n:i]) + src_seq = [word_idx['']] + l + trg_seq = l + [word_idx['']] + if n > 0 and len(src_seq) > n: continue + yield src_seq, trg_seq + else: + assert False, 'Unknow data type' return reader -def train(word_idx, n): +def train(word_idx, n, data_type=DataType.NGRAM): """ imikolov training set creator. @@ -97,15 +113,18 @@ def train(word_idx, n): :param word_idx: word dictionary :type word_idx: dict - :param n: sliding window size + :param n: sliding window size if type is ngram, otherwise max length of sequence :type n: int + :param data_type: data type (ngram or sequence) + :type data_type: member variable of DataType (NGRAM or SEQ) :return: Training reader creator :rtype: callable """ - return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n) + return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n, + data_type) -def test(word_idx, n): +def test(word_idx, n, data_type=DataType.NGRAM): """ imikolov test set creator. @@ -114,12 +133,15 @@ def test(word_idx, n): :param word_idx: word dictionary :type word_idx: dict - :param n: sliding window size + :param n: sliding window size if type is ngram, otherwise max length of sequence :type n: int + :param data_type: data type (ngram or sequence) + :type data_type: member variable of DataType (NGRAM or SEQ) :return: Test reader creator :rtype: callable """ - return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) + return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n, + data_type) def fetch(): diff --git a/python/paddle/v2/dataset/ptb.py b/python/paddle/v2/dataset/ptb.py deleted file mode 100644 index ea68daf9c6..0000000000 --- a/python/paddle/v2/dataset/ptb.py +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -langauge model's simple dataset. - -This module will download dataset from -http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set -into paddle reader creators. -""" -import paddle.v2.dataset.common -import collections -import tarfile - -__all__ = ['train', 'test', 'build_dict'] - -URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' -MD5 = '30177ea32e27c525793142b6bf2c8e2d' - - -def word_count(f, word_freq=None): - if word_freq is None: - word_freq = collections.defaultdict(int) - - for l in f: - for w in l.strip().split(): - word_freq[w] += 1 - word_freq[''] += 1 - word_freq[''] += 1 - - return word_freq - - -def build_dict(min_word_freq=50): - """ - Build a word dictionary from the corpus, Keys of the dictionary are words, - and values are zero-based IDs of these words. - """ - train_filename = './simple-examples/data/ptb.train.txt' - test_filename = './simple-examples/data/ptb.valid.txt' - with tarfile.open( - paddle.v2.dataset.common.download( - paddle.v2.dataset.imikolov.URL, 'imikolov', - paddle.v2.dataset.imikolov.MD5)) as tf: - trainf = tf.extractfile(train_filename) - testf = tf.extractfile(test_filename) - word_freq = word_count(testf, word_count(trainf)) - if '' in word_freq: - # remove for now, since we will set it as last index - del word_freq[''] - - word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items()) - - word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) - words, _ = list(zip(*word_freq_sorted)) - word_idx = dict(zip(words, xrange(len(words)))) - word_idx[''] = len(words) - - return word_idx - - -def reader_creator(filename, reader_type, word_idx, n=-1): - def reader(): - with tarfile.open( - paddle.v2.dataset.common.download( - paddle.v2.dataset.imikolov.URL, 'imikolov', - paddle.v2.dataset.imikolov.MD5)) as tf: - f = tf.extractfile(filename) - - UNK = word_idx[''] - - for l in f: - if 'ngram' == reader_type: - assert n > -1, 'Invalid gram length' - l = [''] + l.strip().split() + [''] - if len(l) < n: continue - l = [word_idx.get(w, UNK) for w in l] - for i in range(n, len(l) + 1): - yield tuple(l[i - n:i]) - elif 'seq' == reader_type: - l = l.strip().split() - l = [word_idx.get(w, UNK) for w in l] - src_seq = [word_idx['']] + l - trg_seq = l + [word_idx['']] - yield src_seq, trg_seq - - return reader - - -def ngram_train(word_idx, n): - """ - ptb ngram type training set creator. - - It returns a reader creator, each sample in the reader is a word ID - tuple. - - :param word_idx: word dictionary - :type word_idx: dict - :param n: sliding window size - :type n: int - :return: Training reader creator - :rtype: callable - """ - return reader_creator('./simple-examples/data/ptb.train.txt', 'ngram', - word_idx, n) - - -def ngram_test(word_idx, n): - """ - ptb ngram test set creator. - - It returns a reader creator, each sample in the reader is a word ID - tuple. - - :param word_idx: word dictionary - :type word_idx: dict - :param n: sliding window size - :type n: int - :return: Test reader creator - :rtype: callable - """ - return reader_creator('./simple-examples/data/ptb.valid.txt', 'ngram', - word_idx, n) - - -def seq_train(word_idx): - """ - ptb sequence type training set creator. - - It returns a reader creator, each sample in the reader is a word ID - pair. - - :param word_idx: word dictionary - :type word_idx: dict - :return: Test reader creator - :rtype: callable - """ - return reader_creator('./simple-examples/data/ptb.train.txt', 'seq', - word_idx) - - -def seq_test(word_idx): - """ - ptb sequence type test set creator. - - It returns a reader creator, each sample in the reader is a word ID - pair. - - :param word_idx: word dictionary - :type word_idx: dict - :return: Test reader creator - :rtype: callable - """ - return reader_creator('./simple-examples/data/ptb.valid.txt', 'seq', - word_idx) - - -def fetch(): - paddle.v2.dataset.common.download(URL, "imikolov", MD5) diff --git a/python/paddle/v2/dataset/tests/imikolov_test.py b/python/paddle/v2/dataset/tests/imikolov_test.py index 009e55243a..4e52810e6b 100644 --- a/python/paddle/v2/dataset/tests/imikolov_test.py +++ b/python/paddle/v2/dataset/tests/imikolov_test.py @@ -13,10 +13,37 @@ class TestMikolov(unittest.TestCase): n = 5 self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n) + first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\ + 'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\ + 'rake regatta rubens sim snack-food ssangyong swapo wachter' + first_line = [ + WORD_DICT.get(ch, WORD_DICT['']) + for ch in first_line.split(' ') + ] + for l in paddle.v2.dataset.imikolov.train( + WORD_DICT, n=-1, + data_type=paddle.v2.dataset.imikolov.DataType.SEQ)(): + read_line = l[0][1:] + break + self.assertEqual(first_line, read_line) + def test_test(self): n = 5 self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n) + first_line = 'consumers may want to move their telephones a little '\ + 'closer to the tv set' + first_line = [ + WORD_DICT.get(ch, WORD_DICT['']) + for ch in first_line.split(' ') + ] + for l in paddle.v2.dataset.imikolov.test( + WORD_DICT, n=-1, + data_type=paddle.v2.dataset.imikolov.DataType.SEQ)(): + read_line = l[0][1:] + break + self.assertEqual(first_line, read_line) + def test_total(self): _, idx = zip(*WORD_DICT.items()) self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1) diff --git a/python/paddle/v2/dataset/tests/ptb_test.py b/python/paddle/v2/dataset/tests/ptb_test.py deleted file mode 100644 index 5e584a734d..0000000000 --- a/python/paddle/v2/dataset/tests/ptb_test.py +++ /dev/null @@ -1,53 +0,0 @@ -import paddle.v2.dataset.ptb -import unittest - -WORD_DICT = paddle.v2.dataset.ptb.build_dict() - - -class TestMikolov(unittest.TestCase): - def check_reader(self, reader, n): - for l in reader(): - self.assertEqual(len(l), n) - - def test_ngram_train(self): - n = 5 - self.check_reader(paddle.v2.dataset.ptb.ngram_train(WORD_DICT, n), n) - - def test_ngram_test(self): - n = 5 - self.check_reader(paddle.v2.dataset.ptb.ngram_test(WORD_DICT, n), n) - - def test_seq_train(self): - first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\ - 'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\ - 'rake regatta rubens sim snack-food ssangyong swapo wachter' - first_line = [ - WORD_DICT.get(ch, WORD_DICT['']) - for ch in first_line.split(' ') - ] - for l in paddle.v2.dataset.ptb.seq_train(WORD_DICT)(): - read_line = l[0][1:] - break - - self.assertEqual(first_line, read_line) - - def test_seq_test(self): - first_line = 'consumers may want to move their telephones a little '\ - 'closer to the tv set' - first_line = [ - WORD_DICT.get(ch, WORD_DICT['']) - for ch in first_line.split(' ') - ] - for l in paddle.v2.dataset.ptb.seq_test(WORD_DICT)(): - read_line = l[0][1:] - break - - self.assertEqual(first_line, read_line) - - def test_total(self): - _, idx = zip(*WORD_DICT.items()) - self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1) - - -if __name__ == '__main__': - unittest.main() -- GitLab