提交 e62a4d7a 编写于 作者: W wanghaoshuang

xmap: change multiprocess to multithread.

images reader: read the data without untarring the tarball file.
image.py: move batch function from reader to image.py
上级 2799b0ec
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
CIFAR dataset.
This module will download dataset from This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set intopaddle reader creators. and parse train/test set intopaddle reader creators.
...@@ -35,13 +33,11 @@ import itertools ...@@ -35,13 +33,11 @@ import itertools
from common import download from common import download
import tarfile import tarfile
import scipy.io as scio import scipy.io as scio
from image import * from paddle.v2.image import *
import os import os
from multiprocessing import Process
from multiprocessing import Pool
from multiprocessing import cpu_count
import numpy as np import numpy as np
import paddle.v2 as paddle import paddle.v2 as paddle
from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid'] __all__ = ['train', 'test', 'valid']
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
...@@ -52,33 +48,6 @@ LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' ...@@ -52,33 +48,6 @@ LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
def extract_file(tarFile):
'''
Extract tar file to tmp dir.
Example usage:
.. code-block:: python
tmp = extract_file("/home/work/test.tar.gz")
:param tarFile: target tar file
:type tarFile: string
:return: extracted dir. For example:
'/home/work/test/' while input is '/home/work/test.tar.gz'
:rtype: string
'''
base_dir = os.path.dirname(tarFile)
base_name = os.path.basename(tarFile)
if '.' in base_name:
base_name = base_name.split('.', 1)[0]
out_path = '/'.join([base_dir, base_name])
if not os.path.exists(out_path):
df = tarfile.open(tarFile, mode='r')
df.extractall(path=out_path)
df.close()
return out_path
def default_mapper(sample): def default_mapper(sample):
''' '''
map image bytes data to type needed by model input layer map image bytes data to type needed by model input layer
...@@ -92,12 +61,13 @@ def default_mapper(sample): ...@@ -92,12 +61,13 @@ def default_mapper(sample):
def reader_creator(data_file, def reader_creator(data_file,
label_file, label_file,
setid_file, setid_file,
flag, dataset_name,
mapper=default_mapper): mapper=default_mapper,
buffered_size=1024):
''' '''
1. extract 102flowers.tgz to 102flowers/ 1. read images from tar file and
2. merge images into batch files in 102flowers_batch/ merge images into batch files in 102flowers.tgz_batch/
3. get a reader to read sample from batch file 2. get a reader to read sample from batch file
:param data_file: downloaded data file :param data_file: downloaded data file
:type data_file: string :type data_file: string
...@@ -106,17 +76,23 @@ def reader_creator(data_file, ...@@ -106,17 +76,23 @@ def reader_creator(data_file,
:param setid_file: downloaded setid file containing information :param setid_file: downloaded setid file containing information
about how to split dataset about how to split dataset
:type setid_file: string :type setid_file: string
:param flag: data set name (tstid|trnid|valid) :param dataset_name: data set name (tstid|trnid|valid)
:type flag: string :type dataset_name: string
:param mapper: a function to map image bytes data to type :param mapper: a function to map image bytes data to type
needed by model input layer needed by model input layer
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: data reader :return: data reader
:rtype: callable :rtype: callable
''' '''
base_dir = os.path.dirname(data_file) labels = scio.loadmat(label_file)['labels'][0]
tmp_dir = extract_file(data_file) indexes = scio.loadmat(setid_file)[dataset_name][0]
file_list = create_batch(tmp_dir, label_file, setid_file, flag) img2label = {}
for i in indexes:
img = "jpg/image_%05d.jpg" % i
img2label[img] = labels[i - 1]
file_list = batch_images_from_tar(data_file, dataset_name, img2label)
def reader(): def reader():
for file in open(file_list): for file in open(file_list):
...@@ -129,66 +105,10 @@ def reader_creator(data_file, ...@@ -129,66 +105,10 @@ def reader_creator(data_file,
for sample, label in itertools.izip(data, batch['label']): for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) yield sample, int(label)
return paddle.reader.xmap(mapper, reader, cpu_count(), 1024 * 8) return paddle.reader.xmap(mapper, reader, cpu_count(), buffered_size)
def create_batch(data_dir, def train(mapper=default_mapper, buffered_size=1024):
label_file,
setid_file,
flag,
numPerBatch=1024,
nThread=16):
batch_dir = data_dir + "_batch"
labels = scio.loadmat(label_file)['labels'][0]
indexes = scio.loadmat(setid_file)[flag][0]
count = len(indexes)
out_path = "%s/%s" % (batch_dir, flag)
meta_file = "%s/%s.txt" % (batch_dir, flag)
if os.path.exists(out_path):
return meta_file
else:
os.makedirs(out_path)
def batch(file_out, start, end):
data = []
labellist = []
for index in indexes[start:end]:
img_name = "%s/jpg/image_%05d.jpg" % (data_dir, index)
with open(img_name, 'r') as f:
data.append(f.read())
labellist.append(labels[index - 1])
output = {}
output['label'] = labellist
output['data'] = data
cPickle.dump(
output, open(file_out, 'w'), protocol=cPickle.HIGHEST_PROTOCOL)
cur_id = 0
file_id = 0
while cur_id < count:
thread = []
for i in xrange(nThread):
end_id = min(cur_id + numPerBatch, count)
batch_file_name = "%s/batch_%05d" % (out_path, file_id)
w = Process(target=batch, args=(batch_file_name, cur_id, end_id))
w.daemon = True
thread.append(w)
cur_id = end_id
file_id += 1
if cur_id == count:
break
for t in thread:
t.start()
for t in thread:
t.join()
with open(meta_file, 'a') as meta:
for file in os.listdir(out_path):
meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n")
return meta_file
def train(mapper=default_mapper):
''' '''
Create flowers training set reader. Create flowers training set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -199,16 +119,19 @@ def train(mapper=default_mapper): ...@@ -199,16 +119,19 @@ def train(mapper=default_mapper):
3. flatten 3. flatten
:param mapper: a function to map sample. :param mapper: a function to map sample.
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: train data reader :return: train data reader
:rtype: callable :rtype: callable
''' '''
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid') download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper,
buffered_size)
def test(mapper=default_mapper): def test(mapper=default_mapper, buffered_size=1024):
''' '''
Create flowers test set reader. Create flowers test set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -219,16 +142,19 @@ def test(mapper=default_mapper): ...@@ -219,16 +142,19 @@ def test(mapper=default_mapper):
3. flatten 3. flatten
:param mapper: a function to map sample. :param mapper: a function to map sample.
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader :return: test data reader
:rtype: callable :rtype: callable
''' '''
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid') download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper,
buffered_size)
def valid(): def valid(mapper=default_mapper, buffered_size=1024):
''' '''
Create flowers validation set reader. Create flowers validation set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -237,19 +163,21 @@ def valid(): ...@@ -237,19 +163,21 @@ def valid():
1. resize to 256*256 1. resize to 256*256
2. random crop to 224*224 2. random crop to 224*224
3. flatten 3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader
:rtype: callable
''' '''
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid') download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper,
buffered_size)
def fetch(): def fetch():
download(DATA_URL, 'flowers', DATA_MD5) download(DATA_URL, 'flowers', DATA_MD5)
download(LABEL_URL, 'flowers', LABEL_MD5) download(LABEL_URL, 'flowers', LABEL_MD5)
download(SETID_URL, 'flowers', SETID_MD5) download(SETID_URL, 'flowers', SETID_MD5)
if __name__ == '__main__':
for i in test()():
pass
...@@ -5,10 +5,14 @@ except ImportError: ...@@ -5,10 +5,14 @@ except ImportError:
cv2 = None cv2 = None
from cv2 import resize from cv2 import resize
import os
import tarfile
import cPickle
__all__ = [ __all__ = [
"load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop", "load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop",
"random_crop", "left_right_flip", "simple_transform", "load_and_transform" "random_crop", "left_right_flip", "simple_transform", "load_and_transform",
"batch_images_from_tar"
] ]
""" """
This file contains some common interfaces for image preprocess. This file contains some common interfaces for image preprocess.
...@@ -28,6 +32,68 @@ the image layout as follows. ...@@ -28,6 +32,68 @@ the image layout as follows.
""" """
def batch_images_from_tar(data_file,
dataset_name,
img2label,
num_per_batch=1024):
"""
Read images from tar file and batch them into batch file.
param data_file: path of image tar file
type data_file: string
param dataset_name: 'train','test' or 'valid'
type dataset_name: string
param img2label: a dic with image file name as key
and image's label as value
type img2label: dic
param num_per_batch: image number per batch file
type num_per_batch: int
return: path of list file containing paths of batch file
rtype: string
"""
batch_dir = data_file + "_batch"
out_path = "%s/%s" % (batch_dir, dataset_name)
meta_file = "%s/%s.txt" % (batch_dir, dataset_name)
if os.path.exists(out_path):
return meta_file
else:
os.makedirs(out_path)
tf = tarfile.open(data_file)
mems = tf.getmembers()
data = []
labels = []
file_id = 0
for mem in mems:
if mem.name in img2label:
data.append(tf.extractfile(mem).read())
labels.append(img2label[mem.name])
if len(data) == num_per_batch:
output = {}
output['label'] = labels
output['data'] = data
cPickle.dump(
output,
open('%s/batch_%d' % (out_path, file_id), 'w'),
protocol=cPickle.HIGHEST_PROTOCOL)
file_id += 1
data = []
labels = []
if len(data) > 0:
output = {}
output['label'] = labels
output['data'] = data
cPickle.dump(
output,
open('%s/batch_%d' % (out_path, file_id), 'w'),
protocol=cPickle.HIGHEST_PROTOCOL)
with open(meta_file, 'a') as meta:
for file in os.listdir(out_path):
meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n")
return meta_file
def load_image_bytes(bytes, is_color=True): def load_image_bytes(bytes, is_color=True):
""" """
Load an color or gray image from bytes array. Load an color or gray image from bytes array.
...@@ -36,7 +102,7 @@ def load_image_bytes(bytes, is_color=True): ...@@ -36,7 +102,7 @@ def load_image_bytes(bytes, is_color=True):
.. code-block:: python .. code-block:: python
with open('cat.jpg') as f: with open('cat.jpg') as f:
im = load_image(f.read()) im = load_image_bytes(f.read())
:param bytes: the input image bytes array. :param bytes: the input image bytes array.
:type file: str :type file: str
......
...@@ -21,8 +21,6 @@ import itertools ...@@ -21,8 +21,6 @@ import itertools
import random import random
from Queue import Queue from Queue import Queue
from threading import Thread from threading import Thread
from multiprocessing import Queue as MQueue
from multiprocessing import Process
def map_readers(func, *readers): def map_readers(func, *readers):
...@@ -248,8 +246,8 @@ def xmap(mapper, reader, process_num, buffer_size): ...@@ -248,8 +246,8 @@ def xmap(mapper, reader, process_num, buffer_size):
:rtype: callable :rtype: callable
""" """
end = XmapEndSignal() end = XmapEndSignal()
in_queue = MQueue(buffer_size) in_queue = Queue(buffer_size)
out_queue = MQueue(buffer_size) out_queue = Queue(buffer_size)
# define a worker to read samples from reader to in_queue # define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue): def read_worker(reader, in_queue):
...@@ -276,7 +274,7 @@ def xmap(mapper, reader, process_num, buffer_size): ...@@ -276,7 +274,7 @@ def xmap(mapper, reader, process_num, buffer_size):
# start several handle_workers # start several handle_workers
workers = [] workers = []
for i in xrange(process_num): for i in xrange(process_num):
worker = Process( worker = Thread(
target=handle_worker, args=(in_queue, out_queue, mapper)) target=handle_worker, args=(in_queue, out_queue, mapper))
worker.daemon = True worker.daemon = True
workers.append(worker) workers.append(worker)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册