未验证 提交 c3fd2c28 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #7002 from qingqing01/imdb_data

 Speed up the data reader for the IMDB dataset.
...@@ -23,10 +23,9 @@ Besides, this module also provides API for building dictionary. ...@@ -23,10 +23,9 @@ Besides, this module also provides API for building dictionary.
import paddle.v2.dataset.common import paddle.v2.dataset.common
import collections import collections
import tarfile import tarfile
import Queue
import re import re
import string import string
import threading import random
__all__ = ['build_dict', 'train', 'test', 'convert'] __all__ = ['build_dict', 'train', 'test', 'convert']
...@@ -74,47 +73,21 @@ def build_dict(pattern, cutoff): ...@@ -74,47 +73,21 @@ def build_dict(pattern, cutoff):
return word_idx return word_idx
def reader_creator(pos_pattern, neg_pattern, word_idx, buffer_size): def reader_creator(pos_pattern, neg_pattern, word_idx):
UNK = word_idx['<unk>'] UNK = word_idx['<unk>']
INS = []
qs = [Queue.Queue(maxsize=buffer_size), Queue.Queue(maxsize=buffer_size)] def load(pattern, out, label):
def load(pattern, queue):
for doc in tokenize(pattern): for doc in tokenize(pattern):
queue.put(doc) out.append(([word_idx.get(w, UNK) for w in doc], label))
queue.put(None)
load(pos_pattern, INS, 0)
load(neg_pattern, INS, 1)
random.shuffle(INS)
def reader(): def reader():
# Creates two threads that loads positive and negative samples for doc, label in INS:
# into qs. yield doc, label
t0 = threading.Thread(
target=load, args=(
pos_pattern,
qs[0], ))
t0.daemon = True
t0.start()
t1 = threading.Thread(
target=load, args=(
neg_pattern,
qs[1], ))
t1.daemon = True
t1.start()
# Read alternatively from qs[0] and qs[1].
i = 0
doc = qs[i].get()
while doc != None:
yield [word_idx.get(w, UNK) for w in doc], i % 2
i += 1
doc = qs[i % 2].get()
# If any queue is empty, reads from the other queue.
i += 1
doc = qs[i % 2].get()
while doc != None:
yield [word_idx.get(w, UNK) for w in doc], i % 2
doc = qs[i % 2].get()
return reader return reader
...@@ -133,7 +106,7 @@ def train(word_idx): ...@@ -133,7 +106,7 @@ def train(word_idx):
""" """
return reader_creator( return reader_creator(
re.compile("aclImdb/train/pos/.*\.txt$"), re.compile("aclImdb/train/pos/.*\.txt$"),
re.compile("aclImdb/train/neg/.*\.txt$"), word_idx, 1000) re.compile("aclImdb/train/neg/.*\.txt$"), word_idx)
def test(word_idx): def test(word_idx):
...@@ -150,7 +123,7 @@ def test(word_idx): ...@@ -150,7 +123,7 @@ def test(word_idx):
""" """
return reader_creator( return reader_creator(
re.compile("aclImdb/test/pos/.*\.txt$"), re.compile("aclImdb/test/pos/.*\.txt$"),
re.compile("aclImdb/test/neg/.*\.txt$"), word_idx, 1000) re.compile("aclImdb/test/neg/.*\.txt$"), word_idx)
def word_dict(): def word_dict():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册