提交 81fb41f0 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #2023 from pkuyym/develop

Add dataset PTB into paddle.dataset for language model task.
......@@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
class DataType(object):
NGRAM = 1
SEQ = 2
def word_count(f, word_freq=None):
if word_freq is None:
word_freq = collections.defaultdict(int)
......@@ -41,7 +46,7 @@ def word_count(f, word_freq=None):
return word_freq
def build_dict(typo_freq=50):
def build_dict(min_word_freq=50):
"""
Build a word dictionary from the corpus, Keys of the dictionary are words,
and values are zero-based IDs of these words.
......@@ -59,7 +64,7 @@ def build_dict(typo_freq=50):
# remove <unk> for now, since we will set it as last index
del word_freq['<unk>']
word_freq = filter(lambda x: x[1] > typo_freq, word_freq.items())
word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items())
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
......@@ -69,7 +74,7 @@ def build_dict(typo_freq=50):
return word_idx
def reader_creator(filename, word_idx, n):
def reader_creator(filename, word_idx, n, data_type):
def reader():
with tarfile.open(
paddle.v2.dataset.common.download(
......@@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n):
UNK = word_idx['<unk>']
for l in f:
if DataType.NGRAM == data_type:
assert n > -1, 'Invalid gram length'
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1):
yield tuple(l[i - n:i])
elif DataType.SEQ == data_type:
l = l.strip().split()
l = [word_idx.get(w, UNK) for w in l]
src_seq = [word_idx['<s>']] + l
trg_seq = l + [word_idx['<e>']]
if n > 0 and len(src_seq) > n: continue
yield src_seq, trg_seq
else:
assert False, 'Unknow data type'
return reader
def train(word_idx, n):
def train(word_idx, n, data_type=DataType.NGRAM):
"""
imikolov training set creator.
......@@ -97,15 +113,18 @@ def train(word_idx, n):
:param word_idx: word dictionary
:type word_idx: dict
:param n: sliding window size
:param n: sliding window size if type is ngram, otherwise max length of sequence
:type n: int
:param data_type: data type (ngram or sequence)
:type data_type: member variable of DataType (NGRAM or SEQ)
:return: Training reader creator
:rtype: callable
"""
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n)
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n,
data_type)
def test(word_idx, n):
def test(word_idx, n, data_type=DataType.NGRAM):
"""
imikolov test set creator.
......@@ -114,12 +133,15 @@ def test(word_idx, n):
:param word_idx: word dictionary
:type word_idx: dict
:param n: sliding window size
:param n: sliding window size if type is ngram, otherwise max length of sequence
:type n: int
:param data_type: data type (ngram or sequence)
:type data_type: member variable of DataType (NGRAM or SEQ)
:return: Test reader creator
:rtype: callable
"""
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n)
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n,
data_type)
def fetch():
......
......@@ -13,10 +13,37 @@ class TestMikolov(unittest.TestCase):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)
first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\
'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\
'rake regatta rubens sim snack-food ssangyong swapo wachter'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.imikolov.train(
WORD_DICT, n=-1,
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_test(self):
n = 5
self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)
first_line = 'consumers may want to move their telephones a little '\
'closer to the tv set'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.imikolov.test(
WORD_DICT, n=-1,
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_total(self):
_, idx = zip(*WORD_DICT.items())
self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册