提交 7275e0a8 编写于 作者: Y Yi Wang

In response to comments from Helin

上级 a2cec420
......@@ -32,3 +32,10 @@ def download(url, module_name, md5sum):
shutil.copyfileobj(r.raw, f)
return filename
def dict_add(a_dict, ele):
if ele in a_dict:
a_dict[ele] += 1
else:
a_dict[ele] = 1
......@@ -10,14 +10,8 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz'
MD5 = '30177ea32e27c525793142b6bf2c8e2d'
def add(a_dict, ele):
if ele in a_dict:
a_dict[ele] += 1
else:
a_dict[ele] = 1
def word_count(f, word_freq=None):
add = paddle.v2.dataset.common.dict_add
if word_freq == None:
word_freq = {}
......@@ -45,7 +39,7 @@ def build_dict(train_filename, test_filename):
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary))
word_idx = dict(zip(words, xrange(len(words))))
word_idx['<any>'] = len(words)
word_idx['<unk>'] = len(words)
return word_idx
......@@ -66,13 +60,13 @@ def reader_creator(filename, n):
paddle.v2.dataset.imikolov.MD5)) as tf:
f = tf.extractfile(filename)
ANY = word_idx['<any>']
UNK = word_idx['<unk>']
for l in f:
l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n:
l = [word_idx.get(w, ANY) for w in l]
l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1):
yield l[i - n:i]
yield tuple(l[i - n:i])
return reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册