Commit e62a4d7a authored by wanghaoshuang

xmap: change multiprocess to multithread.

images reader: read the data without untarring the tarball file.
image.py: move batch function from reader to image.py
Parent 2799b0ec
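For users of the flowers dataset, the visible effect of this commit is that the readers now batch images straight out of the tarball and map samples in worker threads. A minimal usage sketch (the module path paddle.v2.dataset.flowers and the loop below are illustrative assumptions, not part of this diff):

    import paddle.v2.dataset.flowers as flowers

    # train()/test()/valid() now take a mapper and a buffered_size,
    # both forwarded to paddle.reader.xmap (defaults as in the diff).
    reader = flowers.test(buffered_size=1024)
    for sample, label in reader():
        # sample is the mapped image data, label an int class id
        break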
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Flowers dataset.
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set into paddle reader creators.
@@ -35,13 +33,11 @@ import itertools
from common import download
import tarfile
import scipy.io as scio
from image import *
from paddle.v2.image import *
import os
from multiprocessing import Process
from multiprocessing import Pool
from multiprocessing import cpu_count
import numpy as np
import paddle.v2 as paddle
from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid']
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
@@ -52,33 +48,6 @@ LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
def extract_file(tarFile):
'''
Extract tar file to tmp dir.
Example usage:
.. code-block:: python
tmp = extract_file("/home/work/test.tar.gz")
:param tarFile: target tar file
:type tarFile: string
:return: extracted dir. For example:
'/home/work/test/' while input is '/home/work/test.tar.gz'
:rtype: string
'''
base_dir = os.path.dirname(tarFile)
base_name = os.path.basename(tarFile)
if '.' in base_name:
base_name = base_name.split('.', 1)[0]
out_path = '/'.join([base_dir, base_name])
if not os.path.exists(out_path):
df = tarfile.open(tarFile, mode='r')
df.extractall(path=out_path)
df.close()
return out_path
def default_mapper(sample):
'''
map image bytes data to type needed by model input layer
@@ -92,12 +61,13 @@ def default_mapper(sample):
def reader_creator(data_file,
label_file,
setid_file,
flag,
mapper=default_mapper):
dataset_name,
mapper=default_mapper,
buffered_size=1024):
'''
1. extract 102flowers.tgz to 102flowers/
2. merge images into batch files in 102flowers_batch/
3. get a reader to read sample from batch file
1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/
2. get a reader to read sample from batch file
:param data_file: downloaded data file
:type data_file: string
@@ -106,17 +76,23 @@ def reader_creator(data_file,
:param setid_file: downloaded setid file containing information
about how to split dataset
:type setid_file: string
:param flag: data set name (tstid|trnid|valid)
:type flag: string
:param dataset_name: data set name (tstid|trnid|valid)
:type dataset_name: string
:param mapper: a function to map image bytes data to type
needed by model input layer
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: data reader
:rtype: callable
'''
base_dir = os.path.dirname(data_file)
tmp_dir = extract_file(data_file)
file_list = create_batch(tmp_dir, label_file, setid_file, flag)
labels = scio.loadmat(label_file)['labels'][0]
indexes = scio.loadmat(setid_file)[dataset_name][0]
img2label = {}
for i in indexes:
img = "jpg/image_%05d.jpg" % i
img2label[img] = labels[i - 1]
file_list = batch_images_from_tar(data_file, dataset_name, img2label)
def reader():
for file in open(file_list):
@@ -129,66 +105,10 @@ def reader_creator(data_file,
for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label)
return paddle.reader.xmap(mapper, reader, cpu_count(), 1024 * 8)
return paddle.reader.xmap(mapper, reader, cpu_count(), buffered_size)
def create_batch(data_dir,
label_file,
setid_file,
flag,
numPerBatch=1024,
nThread=16):
batch_dir = data_dir + "_batch"
labels = scio.loadmat(label_file)['labels'][0]
indexes = scio.loadmat(setid_file)[flag][0]
count = len(indexes)
out_path = "%s/%s" % (batch_dir, flag)
meta_file = "%s/%s.txt" % (batch_dir, flag)
if os.path.exists(out_path):
return meta_file
else:
os.makedirs(out_path)
def batch(file_out, start, end):
data = []
labellist = []
for index in indexes[start:end]:
img_name = "%s/jpg/image_%05d.jpg" % (data_dir, index)
with open(img_name, 'r') as f:
data.append(f.read())
labellist.append(labels[index - 1])
output = {}
output['label'] = labellist
output['data'] = data
cPickle.dump(
output, open(file_out, 'w'), protocol=cPickle.HIGHEST_PROTOCOL)
cur_id = 0
file_id = 0
while cur_id < count:
thread = []
for i in xrange(nThread):
end_id = min(cur_id + numPerBatch, count)
batch_file_name = "%s/batch_%05d" % (out_path, file_id)
w = Process(target=batch, args=(batch_file_name, cur_id, end_id))
w.daemon = True
thread.append(w)
cur_id = end_id
file_id += 1
if cur_id == count:
break
for t in thread:
t.start()
for t in thread:
t.join()
with open(meta_file, 'a') as meta:
for file in os.listdir(out_path):
meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n")
return meta_file
def train(mapper=default_mapper):
def train(mapper=default_mapper, buffered_size=1024):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
@@ -199,16 +119,19 @@ def train(mapper=default_mapper):
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: train data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid')
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper,
buffered_size)
def test(mapper=default_mapper):
def test(mapper=default_mapper, buffered_size=1024):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
@@ -219,16 +142,19 @@ def test(mapper=default_mapper):
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid')
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper,
buffered_size)
def valid():
def valid(mapper=default_mapper, buffered_size=1024):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
@@ -237,19 +163,21 @@ def valid():
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid')
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper,
buffered_size)
def fetch():
download(DATA_URL, 'flowers', DATA_MD5)
download(LABEL_URL, 'flowers', LABEL_MD5)
download(SETID_URL, 'flowers', SETID_MD5)
if __name__ == '__main__':
for i in test()():
pass
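Each batch file written by the new pipeline is a cPickle'd dict holding a 'data' list of raw image bytes and a parallel 'label' list, which is what the reader above unpacks with itertools.izip. A hedged sketch of inspecting one batch file directly (the batch_0 path is hypothetical):

    import cPickle

    # Load one pickled batch produced by batch_images_from_tar.
    with open('102flowers.tgz_batch/tstid/batch_0') as f:
        batch = cPickle.load(f)
    for img_bytes, label in zip(batch['data'], batch['label']):
        pass  # img_bytes holds the raw JPEG, label the class id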
@@ -5,10 +5,14 @@ except ImportError:
cv2 = None
from cv2 import resize
import os
import tarfile
import cPickle
__all__ = [
"load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop",
"random_crop", "left_right_flip", "simple_transform", "load_and_transform"
"random_crop", "left_right_flip", "simple_transform", "load_and_transform",
"batch_images_from_tar"
]
"""
This file contains some common interfaces for image preprocessing.
@@ -28,6 +32,68 @@ the image layout as follows.
"""
def batch_images_from_tar(data_file,
dataset_name,
img2label,
num_per_batch=1024):
"""
Read images from a tar file and batch them into pickled batch files.
:param data_file: path of the image tar file
:type data_file: string
:param dataset_name: 'train', 'test' or 'valid'
:type dataset_name: string
:param img2label: a dict with image file name as key
and image's label as value
:type img2label: dict
:param num_per_batch: number of images per batch file
:type num_per_batch: int
:return: path of the list file containing paths of the batch files
:rtype: string
"""
batch_dir = data_file + "_batch"
out_path = "%s/%s" % (batch_dir, dataset_name)
meta_file = "%s/%s.txt" % (batch_dir, dataset_name)
if os.path.exists(out_path):
return meta_file
else:
os.makedirs(out_path)
tf = tarfile.open(data_file)
mems = tf.getmembers()
data = []
labels = []
file_id = 0
for mem in mems:
if mem.name in img2label:
data.append(tf.extractfile(mem).read())
labels.append(img2label[mem.name])
if len(data) == num_per_batch:
output = {}
output['label'] = labels
output['data'] = data
cPickle.dump(
output,
open('%s/batch_%d' % (out_path, file_id), 'w'),
protocol=cPickle.HIGHEST_PROTOCOL)
file_id += 1
data = []
labels = []
if len(data) > 0:
output = {}
output['label'] = labels
output['data'] = data
cPickle.dump(
output,
open('%s/batch_%d' % (out_path, file_id), 'w'),
protocol=cPickle.HIGHEST_PROTOCOL)
with open(meta_file, 'a') as meta:
for file in os.listdir(out_path):
meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n")
return meta_file
def load_image_bytes(bytes, is_color=True):
"""
Load a color or gray image from a bytes array.
@@ -36,7 +102,7 @@ def load_image_bytes(bytes, is_color=True):
.. code-block:: python
with open('cat.jpg') as f:
im = load_image(f.read())
im = load_image_bytes(f.read())
:param bytes: the input image bytes array.
:type bytes: str
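For reference, a sketch of driving the new batch_images_from_tar helper directly, mirroring how flowers.reader_creator builds its img2label dict above (the imagelabels.mat and setid.mat file names are assumed from the flowers dataset; treat the snippet as illustrative):

    import scipy.io as scio
    from paddle.v2.image import batch_images_from_tar

    labels = scio.loadmat('imagelabels.mat')['labels'][0]
    indexes = scio.loadmat('setid.mat')['trnid'][0]
    img2label = dict(
        ("jpg/image_%05d.jpg" % i, labels[i - 1]) for i in indexes)

    # Writes cPickle batch files under 102flowers.tgz_batch/trnid/ and
    # returns the path of a meta file listing them, one per line.
    meta_file = batch_images_from_tar('102flowers.tgz', 'trnid', img2label)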
@@ -21,8 +21,6 @@ import itertools
import random
from Queue import Queue
from threading import Thread
from multiprocessing import Queue as MQueue
from multiprocessing import Process
def map_readers(func, *readers):
@@ -248,8 +246,8 @@ def xmap(mapper, reader, process_num, buffer_size):
:rtype: callable
"""
end = XmapEndSignal()
in_queue = MQueue(buffer_size)
out_queue = MQueue(buffer_size)
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
@@ -276,7 +274,7 @@ def xmap(mapper, reader, process_num, buffer_size):
# start several handle_workers
workers = []
for i in xrange(process_num):
worker = Process(
worker = Thread(
target=handle_worker, args=(in_queue, out_queue, mapper))
worker.daemon = True
workers.append(worker)
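The last hunk is the heart of the multiprocess-to-multithread switch: Queue.Queue and threading.Thread replace their multiprocessing counterparts, so samples no longer have to be pickled across process boundaries. A self-contained sketch of the same fan-out pattern (an illustration of the technique, not the library code; a plain sentinel object stands in for XmapEndSignal):

    from Queue import Queue
    from threading import Thread

    def xmap_sketch(mapper, reader, thread_num, buffer_size):
        """Return a reader that maps samples in thread_num worker threads."""
        end = object()  # sentinel standing in for XmapEndSignal

        def new_reader():
            in_q = Queue(buffer_size)
            out_q = Queue(buffer_size)

            def read_worker():
                for sample in reader():
                    in_q.put(sample)
                in_q.put(end)

            def handle_worker():
                sample = in_q.get()
                while sample is not end:
                    out_q.put(mapper(sample))
                    sample = in_q.get()
                in_q.put(end)   # propagate the sentinel to sibling workers
                out_q.put(end)  # one end marker per worker

            t = Thread(target=read_worker)
            t.daemon = True
            t.start()
            for _ in xrange(thread_num):
                w = Thread(target=handle_worker)
                w.daemon = True
                w.start()

            # Yield mapped samples until every worker has signalled the end.
            finished = 0
            while finished < thread_num:
                sample = out_q.get()
                if sample is end:
                    finished += 1
                else:
                    yield sample

        return new_reader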