提交 81fb41f0 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #2023 from pkuyym/develop

Add dataset PTB into paddle.dataset for language model task.
...@@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' ...@@ -28,6 +28,11 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d' MD5 = '30177ea32e27c525793142b6bf2c8e2d'
class DataType(object):
NGRAM = 1
SEQ = 2
def word_count(f, word_freq=None): def word_count(f, word_freq=None):
if word_freq is None: if word_freq is None:
word_freq = collections.defaultdict(int) word_freq = collections.defaultdict(int)
...@@ -41,7 +46,7 @@ def word_count(f, word_freq=None): ...@@ -41,7 +46,7 @@ def word_count(f, word_freq=None):
return word_freq return word_freq
def build_dict(typo_freq=50): def build_dict(min_word_freq=50):
""" """
Build a word dictionary from the corpus, Keys of the dictionary are words, Build a word dictionary from the corpus, Keys of the dictionary are words,
and values are zero-based IDs of these words. and values are zero-based IDs of these words.
...@@ -59,7 +64,7 @@ def build_dict(typo_freq=50): ...@@ -59,7 +64,7 @@ def build_dict(typo_freq=50):
# remove <unk> for now, since we will set it as last index # remove <unk> for now, since we will set it as last index
del word_freq['<unk>'] del word_freq['<unk>']
word_freq = filter(lambda x: x[1] > typo_freq, word_freq.items()) word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items())
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted)) words, _ = list(zip(*word_freq_sorted))
...@@ -69,7 +74,7 @@ def build_dict(typo_freq=50): ...@@ -69,7 +74,7 @@ def build_dict(typo_freq=50):
return word_idx return word_idx
def reader_creator(filename, word_idx, n): def reader_creator(filename, word_idx, n, data_type):
def reader(): def reader():
with tarfile.open( with tarfile.open(
paddle.v2.dataset.common.download( paddle.v2.dataset.common.download(
...@@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n): ...@@ -79,16 +84,27 @@ def reader_creator(filename, word_idx, n):
UNK = word_idx['<unk>'] UNK = word_idx['<unk>']
for l in f: for l in f:
l = ['<s>'] + l.strip().split() + ['<e>'] if DataType.NGRAM == data_type:
if len(l) >= n: assert n > -1, 'Invalid gram length'
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1):
yield tuple(l[i - n:i])
elif DataType.SEQ == data_type:
l = l.strip().split()
l = [word_idx.get(w, UNK) for w in l] l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1): src_seq = [word_idx['<s>']] + l
yield tuple(l[i - n:i]) trg_seq = l + [word_idx['<e>']]
if n > 0 and len(src_seq) > n: continue
yield src_seq, trg_seq
else:
assert False, 'Unknow data type'
return reader return reader
def train(word_idx, n): def train(word_idx, n, data_type=DataType.NGRAM):
""" """
imikolov training set creator. imikolov training set creator.
...@@ -97,15 +113,18 @@ def train(word_idx, n): ...@@ -97,15 +113,18 @@ def train(word_idx, n):
:param word_idx: word dictionary :param word_idx: word dictionary
:type word_idx: dict :type word_idx: dict
:param n: sliding window size :param n: sliding window size if type is ngram, otherwise max length of sequence
:type n: int :type n: int
:param data_type: data type (ngram or sequence)
:type data_type: member variable of DataType (NGRAM or SEQ)
:return: Training reader creator :return: Training reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n) return reader_creator('./simple-examples/data/ptb.train.txt', word_idx, n,
data_type)
def test(word_idx, n): def test(word_idx, n, data_type=DataType.NGRAM):
""" """
imikolov test set creator. imikolov test set creator.
...@@ -114,12 +133,15 @@ def test(word_idx, n): ...@@ -114,12 +133,15 @@ def test(word_idx, n):
:param word_idx: word dictionary :param word_idx: word dictionary
:type word_idx: dict :type word_idx: dict
:param n: sliding window size :param n: sliding window size if type is ngram, otherwise max length of sequence
:type n: int :type n: int
:param data_type: data type (ngram or sequence)
:type data_type: member variable of DataType (NGRAM or SEQ)
:return: Test reader creator :return: Test reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n) return reader_creator('./simple-examples/data/ptb.valid.txt', word_idx, n,
data_type)
def fetch(): def fetch():
......
...@@ -13,10 +13,37 @@ class TestMikolov(unittest.TestCase): ...@@ -13,10 +13,37 @@ class TestMikolov(unittest.TestCase):
n = 5 n = 5
self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n) self.check_reader(paddle.v2.dataset.imikolov.train(WORD_DICT, n), n)
first_line = 'aer banknote berlitz calloway centrust cluett fromstein '\
'gitano guterman hydro-quebec ipo kia memotec mlx nahb punts '\
'rake regatta rubens sim snack-food ssangyong swapo wachter'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.imikolov.train(
WORD_DICT, n=-1,
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_test(self): def test_test(self):
n = 5 n = 5
self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n) self.check_reader(paddle.v2.dataset.imikolov.test(WORD_DICT, n), n)
first_line = 'consumers may want to move their telephones a little '\
'closer to the tv set'
first_line = [
WORD_DICT.get(ch, WORD_DICT['<unk>'])
for ch in first_line.split(' ')
]
for l in paddle.v2.dataset.imikolov.test(
WORD_DICT, n=-1,
data_type=paddle.v2.dataset.imikolov.DataType.SEQ)():
read_line = l[0][1:]
break
self.assertEqual(first_line, read_line)
def test_total(self): def test_total(self):
_, idx = zip(*WORD_DICT.items()) _, idx = zip(*WORD_DICT.items())
self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1) self.assertEqual(sorted(idx)[-1], len(WORD_DICT) - 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册