diff --git a/python/paddle/v2/dataset/imdb.py b/python/paddle/v2/dataset/imdb.py index cfc1c886e1389c15e3f803c341b6f62dd7b4bf41..21ed7f7a5ce279f5bc65e5b008f14a1b0ff97343 100644 --- a/python/paddle/v2/dataset/imdb.py +++ b/python/paddle/v2/dataset/imdb.py @@ -23,10 +23,9 @@ Besides, this module also provides API for building dictionary. import paddle.v2.dataset.common import collections import tarfile -import Queue import re import string -import threading +import random __all__ = ['build_dict', 'train', 'test', 'convert'] @@ -74,47 +73,21 @@ def build_dict(pattern, cutoff): return word_idx -def reader_creator(pos_pattern, neg_pattern, word_idx, buffer_size): +def reader_creator(pos_pattern, neg_pattern, word_idx): UNK = word_idx[''] + INS = [] - qs = [Queue.Queue(maxsize=buffer_size), Queue.Queue(maxsize=buffer_size)] - - def load(pattern, queue): + def load(pattern, out, label): for doc in tokenize(pattern): - queue.put(doc) - queue.put(None) + out.append(([word_idx.get(w, UNK) for w in doc], label)) + + load(pos_pattern, INS, 0) + load(neg_pattern, INS, 1) + random.shuffle(INS) def reader(): - # Creates two threads that loads positive and negative samples - # into qs. - t0 = threading.Thread( - target=load, args=( - pos_pattern, - qs[0], )) - t0.daemon = True - t0.start() - - t1 = threading.Thread( - target=load, args=( - neg_pattern, - qs[1], )) - t1.daemon = True - t1.start() - - # Read alternatively from qs[0] and qs[1]. - i = 0 - doc = qs[i].get() - while doc != None: - yield [word_idx.get(w, UNK) for w in doc], i % 2 - i += 1 - doc = qs[i % 2].get() - - # If any queue is empty, reads from the other queue. - i += 1 - doc = qs[i % 2].get() - while doc != None: - yield [word_idx.get(w, UNK) for w in doc], i % 2 - doc = qs[i % 2].get() + for doc, label in INS: + yield doc, label return reader @@ -133,7 +106,7 @@ def train(word_idx): """ return reader_creator( re.compile("aclImdb/train/pos/.*\.txt$"), - re.compile("aclImdb/train/neg/.*\.txt$"), word_idx, 1000) + re.compile("aclImdb/train/neg/.*\.txt$"), word_idx) def test(word_idx): @@ -150,7 +123,7 @@ def test(word_idx): """ return reader_creator( re.compile("aclImdb/test/pos/.*\.txt$"), - re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000) + re.compile("aclImdb/test/neg/.*\.txt$"), word_idx) def word_dict():