diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 7021a6da05dec6be216534112c2df2586e73390f..2eb018b8d60e9a8bd0091836ab56c35b05786fca 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -66,13 +66,6 @@ def download(url, module_name, md5sum): return filename -def dict_add(a_dict, ele): - if ele in a_dict: - a_dict[ele] += 1 - else: - a_dict[ele] = 1 - - def fetch_all(): for module_name in filter(lambda x: not x.startswith("__"), dir(paddle.v2.dataset)): diff --git a/python/paddle/v2/dataset/imdb.py b/python/paddle/v2/dataset/imdb.py index 5284017ce08de8beb559f58fb6006639f40f5580..9a7ccff4d5cd2563053adb0aae95fc6d10ad2a50 100644 --- a/python/paddle/v2/dataset/imdb.py +++ b/python/paddle/v2/dataset/imdb.py @@ -18,6 +18,7 @@ TODO(yuyang18): Complete comments. """ import paddle.v2.dataset.common +import collections import tarfile import Queue import re @@ -48,10 +49,10 @@ def tokenize(pattern): def build_dict(pattern, cutoff): - word_freq = {} + word_freq = collections.defaultdict(int) for doc in tokenize(pattern): for word in doc: - paddle.v2.dataset.common.dict_add(word_freq, word) + word_freq[word] += 1 # Not sure if we should prune less-frequent words here. word_freq = filter(lambda x: x[1] > cutoff, word_freq.items()) diff --git a/python/paddle/v2/dataset/imikolov.py b/python/paddle/v2/dataset/imikolov.py index 2931d06e7eb65bde887c56a8bc20e7a9c5e4d4e4..5d7e0282b4db639e6590ade66241328d6ab8b5e3 100644 --- a/python/paddle/v2/dataset/imikolov.py +++ b/python/paddle/v2/dataset/imikolov.py @@ -17,6 +17,7 @@ imikolov's simple dataset: http://www.fit.vutbr.cz/~imikolov/rnnlm/ Complete comments. """ import paddle.v2.dataset.common +import collections import tarfile __all__ = ['train', 'test', 'build_dict'] @@ -26,15 +27,14 @@ MD5 = '30177ea32e27c525793142b6bf2c8e2d' def word_count(f, word_freq=None): - add = paddle.v2.dataset.common.dict_add - if word_freq == None: - word_freq = {} + if word_freq is None: + word_freq = collections.defaultdict(int) for l in f: for w in l.strip().split(): - add(word_freq, w) - add(word_freq, '') - add(word_freq, '') + word_freq[w] += 1 + word_freq[''] += 1 + word_freq[''] += 1 return word_freq