提交 4b5a4322 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #1763 from reyoung/feature/remove_unnecessary_code_in_dataset

Remove unnecessary code to generate freq_dict.
...@@ -66,13 +66,6 @@ def download(url, module_name, md5sum): ...@@ -66,13 +66,6 @@ def download(url, module_name, md5sum):
return filename return filename
def dict_add(a_dict, ele):
if ele in a_dict:
a_dict[ele] += 1
else:
a_dict[ele] = 1
def fetch_all(): def fetch_all():
for module_name in filter(lambda x: not x.startswith("__"), for module_name in filter(lambda x: not x.startswith("__"),
dir(paddle.v2.dataset)): dir(paddle.v2.dataset)):
......
...@@ -18,6 +18,7 @@ TODO(yuyang18): Complete comments. ...@@ -18,6 +18,7 @@ TODO(yuyang18): Complete comments.
""" """
import paddle.v2.dataset.common import paddle.v2.dataset.common
import collections
import tarfile import tarfile
import Queue import Queue
import re import re
...@@ -48,10 +49,10 @@ def tokenize(pattern): ...@@ -48,10 +49,10 @@ def tokenize(pattern):
def build_dict(pattern, cutoff): def build_dict(pattern, cutoff):
word_freq = {} word_freq = collections.defaultdict(int)
for doc in tokenize(pattern): for doc in tokenize(pattern):
for word in doc: for word in doc:
paddle.v2.dataset.common.dict_add(word_freq, word) word_freq[word] += 1
# Not sure if we should prune less-frequent words here. # Not sure if we should prune less-frequent words here.
word_freq = filter(lambda x: x[1] > cutoff, word_freq.items()) word_freq = filter(lambda x: x[1] > cutoff, word_freq.items())
......
...@@ -17,6 +17,7 @@ imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/ ...@@ -17,6 +17,7 @@ imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/
Complete comments. Complete comments.
""" """
import paddle.v2.dataset.common import paddle.v2.dataset.common
import collections
import tarfile import tarfile
__all__ = ['train', 'test', 'build_dict'] __all__ = ['train', 'test', 'build_dict']
...@@ -26,15 +27,14 @@ MD5 = '30177ea32e27c525793142b6bf2c8e2d' ...@@ -26,15 +27,14 @@ MD5 = '30177ea32e27c525793142b6bf2c8e2d'
def word_count(f, word_freq=None): def word_count(f, word_freq=None):
add = paddle.v2.dataset.common.dict_add if word_freq is None:
if word_freq == None: word_freq = collections.defaultdict(int)
word_freq = {}
for l in f: for l in f:
for w in l.strip().split(): for w in l.strip().split():
add(word_freq, w) word_freq[w] += 1
add(word_freq, '<s>') word_freq['<s>'] += 1
add(word_freq, '<e>') word_freq['<e>'] += 1
return word_freq return word_freq
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册