diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index a5ffe25a116e9be039bdebaaaad435685e23d372..fcf4437ffaf329f52cc5bc997eff45dee200873c 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -32,3 +32,10 @@ def download(url, module_name, md5sum): shutil.copyfileobj(r.raw, f) return filename + + +def dict_add(a_dict, ele): + if ele in a_dict: + a_dict[ele] += 1 + else: + a_dict[ele] = 1 diff --git a/python/paddle/v2/dataset/imikolov.py b/python/paddle/v2/dataset/imikolov.py index d9518dd27e9bbe02e9f47bae8af544fcece0a2c3..b3791ddad66e588356338150fccadbcc8fa113ca 100644 --- a/python/paddle/v2/dataset/imikolov.py +++ b/python/paddle/v2/dataset/imikolov.py @@ -10,14 +10,8 @@ URL = 'http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz' MD5 = '30177ea32e27c525793142b6bf2c8e2d' -def add(a_dict, ele): - if ele in a_dict: - a_dict[ele] += 1 - else: - a_dict[ele] = 1 - - def word_count(f, word_freq=None): + add = paddle.v2.dataset.common.dict_add if word_freq == None: word_freq = {} @@ -45,7 +39,7 @@ def build_dict(train_filename, test_filename): dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) words, _ = list(zip(*dictionary)) word_idx = dict(zip(words, xrange(len(words)))) - word_idx[''] = len(words) + word_idx[''] = len(words) return word_idx @@ -66,13 +60,13 @@ def reader_creator(filename, n): paddle.v2.dataset.imikolov.MD5)) as tf: f = tf.extractfile(filename) - ANY = word_idx[''] + UNK = word_idx[''] for l in f: l = [''] + l.strip().split() + [''] if len(l) >= n: - l = [word_idx.get(w, ANY) for w in l] + l = [word_idx.get(w, UNK) for w in l] for i in range(n, len(l) + 1): - yield l[i - n:i] + yield tuple(l[i - n:i]) return reader