提交 286696aa 编写于 作者: Y yangyaming

extend imikolov instead of adding ptb

上级 70d15e84
...@@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' ...@@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d' MD5 = '30177ea32e27c525793142b6bf2c8e2d'
class DataType(object):
NGRAM = 1
SEQ = 2
def word_count(f, word_freq=None): def word_count(f, word_freq=None):
if word_freq is None: if word_freq is None:
word_freq = collections.defaultdict(int) word_freq = collections.defaultdict(int)
...@@ -69,7 +74,7 @@ def build_dict(min_word_freq=50): ...@@ -69,7 +74,7 @@ def build_dict(min_word_freq=50):
return word_idx return word_idx
def reader_creator(filename, word_idx, n): def reader_creator(filename, word_idx, n, data_type):
def reader(): def reader():
with tarfile.open( with tarfile.open(
paddle.v2.dataset.common.download( paddle.v2.dataset.common.download(
...@@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n): ...@@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n):
UNK = word_idx['<unk>'] UNK = word_idx['<unk>']
for l in f: for l in f:
if DataType.NGRAM == data_type:
assert n > -1, 'Invalid gram length'
l = ['<s>'] + l.strip().split() + ['<e>'] l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n: if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l] l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1): for i in range(n, len(l) + 1):
yield tuple(l[i - n:i]) yield tuple(l[i - n:i])
elif DataType.SEQ == data_type:
l = l.strip().split()
l = [word_idx.get(w, UNK) for w in l]
src_seq = [word_idx['<s>']] + l
trg_seq = l + [word_idx['<e>']]
if n > 0 and len(src_seq) > n: continue
yield src_seq, trg_seq
else:
assert False, 'Unknow data type'
return reader return reader
def train(word_idx, n): def train(word_idx, n, data_type=DataType.NGRAM):
""" """
imikolov training set creator. imikolov training set creator.
...@@ -97,15 +113,18 @@ def train(word_idx, n): ...@@ -97,15 +113,18 @@ def train(word_idx, n):
:param word_idx: word dictionary :param word_idx: word dictionary
:type word_idx: dict :type word_idx: dict
:param n: sliding window size :param n: sliding window size if type is ngram, otherwise max length of sequence
:type n: int :type n: int
:param data_type: data type (ngram or sequence)
:type data_type: member variable of DataType (NGRAM or SEQ)
:return: Training reader creator :return: Training reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n) return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n,
data_type)
def test(word_idx, n): def test(word_idx, n, data_type=DataType.NGRAM):
""" """
imikolov test set creator. imikolov test set creator.
...@@ -114,12 +133,15 @@ def test(word_idx, n): ...@@ -114,12 +133,15 @@ def test(word_idx, n):
:param word_idx: word dictionary :param word_idx: word dictionary
:type word_idx: dict :type word_idx: dict
:param n: sliding window size :param n: sliding window size if type is ngram, otherwise max length of sequence
:type n: int :type n: int
:param data_type: data type (ngram or sequence)
:type data_type: member variable of DataType (NGRAM or SEQ)
:return: Test reader creator :return: Test reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n,
data_type)
def fetch(): def fetch():
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
langauge model's simple dataset.
This module will download dataset from
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
into paddle reader creators.
"""
import paddle.v2.dataset.common
import collections
import tarfile
__all__ = ['train', 'test', 'build_dict']
URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
def word_count(f, word_freq=None):
if word_freq is None:
word_freq = collections.defaultdict(int)
for l in f:
for w in l.strip().split():
word_freq[w] += 1
word_freq['<s>'] += 1
word_freq['<e>'] += 1
return word_freq
def build_dict(min_word_freq=50):
"""
Build a word dictionary from the corpus, Keys of the dictionary are words,
and values are zero-based IDs of these words.
"""
train_filename = './simple-examples/data/ptb.train.txt'
test_filename = './simple-examples/data/ptb.valid.txt'
with tarfile.open(
paddle.v2.dataset.common.download(
paddle.v2.dataset.imikolov.URL, 'imikolov',
paddle.v2.dataset.imikolov.MD5)) as tf:
trainf = tf.extractfile(train_filename)
testf = tf.extractfile(test_filename)
word_freq = word_count(testf, word_count(trainf))
if '<unk>' in word_freq:
# remove <unk> for now, since we will set it as last index
del word_freq['<unk>']
word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items())
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(zip(words, xrange(len(words))))
word_idx['<unk>'] = len(words)
return word_idx
def reader_creator(filename, reader_type, word_idx, n=-1):
def reader():
with tarfile.open(
paddle.v2.dataset.common.download(
paddle.v2.dataset.imikolov.URL, 'imikolov',
paddle.v2.dataset.imikolov.MD5)) as tf:
f = tf.extractfile(filename)
UNK = word_idx['<unk>']
for l in f:
if 'ngram' == reader_type:
assert n > -1, 'Invalid gram length'
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) < n: continue
l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1):
yield tuple(l[i - n:i])
elif 'seq' == reader_type:
l = l.strip().split()
l = [word_idx.get(w, UNK) for w in l]
src_seq = [word_idx['<s>']] + l
trg_seq = l + [word_idx['<e>']]
yield src_seq, trg_seq
return reader
def ngram_train(word_idx, n):
"""
ptb ngram type training set creator.
It returns a reader creator, each sample in the reader is a word ID
tuple.
:param word_idx: word dictionary
:type word_idx: dict
:param n: sliding window size
:type n: int
:return: Training reader creator
:rtype: callable
"""
return reader_creator('./simple-examples/data/ptb.train.txt', 'ngram',
word_idx, n)
def ngram_test(word_idx, n):
"""
ptb ngram test set creator.
It returns a reader creator, each sample in the reader is a word ID
tuple.
:param word_idx: word dictionary
:type word_idx: dict
:param n: sliding window size
:type n: int
:return: Test reader creator
:rtype: callable
"""
return reader_creator('./simple-examples/data/ptb.valid.txt', 'ngram',
word_idx, n)
def seq_train(word_idx):
"""
ptb sequence type training set creator.
It returns a reader creator, each sample in the reader is a word ID
pair.
:param word_idx: word dictionary
:type word_idx: dict
:return: Test reader creator
:rtype: callable
"""
return reader_creator('./simple-examples/data/ptb.train.txt', 'seq',
word_idx)
def seq_test(word_idx):
"""
ptb sequence type test set creator.
It returns a reader creator, each sample in the reader is a word ID
pair.
:param word_idx: word dictionary
:type word_idx: dict
:return: Test reader creator
:rtype: callable
"""
return reader_creator('./simple-examples/data/ptb.valid.txt', 'seq',
word_idx)
def fetch():
paddle.v2.dataset.common.download(URL, "imikolov", MD5)
...@@ -13,10 +13,37 @@ class TestMikolov(unittest.TestCase): ...@@ -13,10 +13,37 @@ class TestMikolov(unittest.TestCase):
n = 5 n = 5
self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n) self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)
first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\
'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\
'rake regatta rubens sim snack-food ssangyong swapo wachter'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.imikolov.train(
WORD_DICT, n=-1,
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_test(self): def test_test(self):
n = 5 n = 5
self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n) self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)
first_line = 'consumers may want to move their telephones a little '\
'closer to the tv set'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.imikolov.test(
WORD_DICT, n=-1,
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_total(self): def test_total(self):
_, idx = zip(*WORD_DICT.items()) _, idx = zip(*WORD_DICT.items())
self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1) self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)
......
import paddle.v2.dataset.ptb
import unittest
WORD_DICT = paddle.v2.dataset.ptb.build_dict()
class TestMikolov(unittest.TestCase):
def check_reader(self, reader, n):
for l in reader():
self.assertEqual(len(l), n)
def test_ngram_train(self):
n = 5
self.check_reader(paddle.v2.dataset.ptb.ngram_train(WORD_DICT, n), n)
def test_ngram_test(self):
n = 5
self.check_reader(paddle.v2.dataset.ptb.ngram_test(WORD_DICT, n), n)
def test_seq_train(self):
first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\
'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\
'rake regatta rubens sim snack-food ssangyong swapo wachter'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.ptb.seq_train(WORD_DICT)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_seq_test(self):
first_line = 'consumers may want to move their telephones a little '\
'closer to the tv set'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.ptb.seq_test(WORD_DICT)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_total(self):
_, idx = zip(*WORD_DICT.items())
self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册