From e62a4d7abe5287fd5fdc3464ef81a5c682a49589 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Fri, 2 Jun 2017 10:56:15 +0800 Subject: [PATCH] xmap: change multiprocess to multithread. images reader: read the data without untarring the tarball file. image.py: move batch function from reader to image.py --- python/paddle/v2/dataset/flowers.py | 150 +++++++-------------------- python/paddle/v2/image.py | 70 ++++++++++++- python/paddle/v2/reader/decorator.py | 8 +- 3 files changed, 110 insertions(+), 118 deletions(-) diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index 3d38b5dab9..d9a39b11df 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """ -CIFAR dataset. - This module will download dataset from http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html and parse train/test set intopaddle reader creators. @@ -35,13 +33,11 @@ import itertools from common import download import tarfile import scipy.io as scio -from image import * +from paddle.v2.image import * import os -from multiprocessing import Process -from multiprocessing import Pool -from multiprocessing import cpu_count import numpy as np import paddle.v2 as paddle +from multiprocessing import cpu_count __all__ = ['train', 'test', 'valid'] DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' @@ -52,33 +48,6 @@ LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' -def extract_file(tarFile): - ''' - Extract tar file to tmp dir. - - Example usage: - - .. code-block:: python - tmp = extract_file("/home/work/test.tar.gz") - - :param tarFile: target tar file - :type tarFile: string - :return: extracted dir. For example: - '/home/work/test/' while input is '/home/work/test.tar.gz' - :rtype: string - ''' - base_dir = os.path.dirname(tarFile) - base_name = os.path.basename(tarFile) - if '.' in base_name: - base_name = base_name.split('.', 1)[0] - out_path = '/'.join([base_dir, base_name]) - if not os.path.exists(out_path): - df = tarfile.open(tarFile, mode='r') - df.extractall(path=out_path) - df.close() - return out_path - - def default_mapper(sample): ''' map image bytes data to type needed by model input layer @@ -92,12 +61,13 @@ def default_mapper(sample): def reader_creator(data_file, label_file, setid_file, - flag, - mapper=default_mapper): + dataset_name, + mapper=default_mapper, + buffered_size=1024): ''' - 1. extract 102flowers.tgz to 102flowers/ - 2. merge images into batch files in 102flowers_batch/ - 3. get a reader to read sample from batch file + 1. read images from tar file and + merge images into batch files in 102flowers.tgz_batch/ + 2. get a reader to read sample from batch file :param data_file: downloaded data file :type data_file: string @@ -106,17 +76,23 @@ def reader_creator(data_file, :param setid_file: downloaded setid file containing information about how to split dataset :type setid_file: string - :param flag: data set name (tstid|trnid|valid) - :type flag: string + :param dataset_name: data set name (tstid|trnid|valid) + :type dataset_name: string :param mapper: a function to map image bytes data to type needed by model input layer :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int :return: data reader :rtype: callable ''' - base_dir = os.path.dirname(data_file) - tmp_dir = extract_file(data_file) - file_list = create_batch(tmp_dir, label_file, setid_file, flag) + labels = scio.loadmat(label_file)['labels'][0] + indexes = scio.loadmat(setid_file)[dataset_name][0] + img2label = {} + for i in indexes: + img = "jpg/image_%05d.jpg" % i + img2label[img] = labels[i - 1] + file_list = batch_images_from_tar(data_file, dataset_name, img2label) def reader(): for file in open(file_list): @@ -129,66 +105,10 @@ def reader_creator(data_file, for sample, label in itertools.izip(data, batch['label']): yield sample, int(label) - return paddle.reader.xmap(mapper, reader, cpu_count(), 1024 * 8) + return paddle.reader.xmap(mapper, reader, cpu_count(), buffered_size) -def create_batch(data_dir, - label_file, - setid_file, - flag, - numPerBatch=1024, - nThread=16): - batch_dir = data_dir + "_batch" - labels = scio.loadmat(label_file)['labels'][0] - indexes = scio.loadmat(setid_file)[flag][0] - count = len(indexes) - out_path = "%s/%s" % (batch_dir, flag) - meta_file = "%s/%s.txt" % (batch_dir, flag) - - if os.path.exists(out_path): - return meta_file - else: - os.makedirs(out_path) - - def batch(file_out, start, end): - data = [] - labellist = [] - for index in indexes[start:end]: - img_name = "%s/jpg/image_%05d.jpg" % (data_dir, index) - with open(img_name, 'r') as f: - data.append(f.read()) - labellist.append(labels[index - 1]) - output = {} - output['label'] = labellist - output['data'] = data - cPickle.dump( - output, open(file_out, 'w'), protocol=cPickle.HIGHEST_PROTOCOL) - - cur_id = 0 - file_id = 0 - while cur_id < count: - thread = [] - for i in xrange(nThread): - end_id = min(cur_id + numPerBatch, count) - batch_file_name = "%s/batch_%05d" % (out_path, file_id) - w = Process(target=batch, args=(batch_file_name, cur_id, end_id)) - w.daemon = True - thread.append(w) - cur_id = end_id - file_id += 1 - if cur_id == count: - break - for t in thread: - t.start() - for t in thread: - t.join() - with open(meta_file, 'a') as meta: - for file in os.listdir(out_path): - meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n") - return meta_file - - -def train(mapper=default_mapper): +def train(mapper=default_mapper, buffered_size=1024): ''' Create flowers training set reader. It returns a reader, each sample in the reader is @@ -199,16 +119,19 @@ def train(mapper=default_mapper): 3. flatten :param mapper: a function to map sample. :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int :return: train data reader :rtype: callable ''' return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'trnid') + download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, + buffered_size) -def test(mapper=default_mapper): +def test(mapper=default_mapper, buffered_size=1024): ''' Create flowers test set reader. It returns a reader, each sample in the reader is @@ -219,16 +142,19 @@ def test(mapper=default_mapper): 3. flatten :param mapper: a function to map sample. :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int :return: test data reader :rtype: callable ''' return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'tstid') + download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, + buffered_size) -def valid(): +def valid(mapper=default_mapper, buffered_size=1024): ''' Create flowers validation set reader. It returns a reader, each sample in the reader is @@ -237,19 +163,21 @@ def valid(): 1. resize to 256*256 2. random crop to 224*224 3. flatten + :param mapper: a function to map sample. + :type mapper: callable + :param buffered_size: the size of buffer used to process images + :type buffered_size: int + :return: test data reader + :rtype: callable ''' return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'valid') + download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, + buffered_size) def fetch(): download(DATA_URL, 'flowers', DATA_MD5) download(LABEL_URL, 'flowers', LABEL_MD5) download(SETID_URL, 'flowers', SETID_MD5) - - -if __name__ == '__main__': - for i in test()(): - pass diff --git a/python/paddle/v2/image.py b/python/paddle/v2/image.py index cb5725de68..56031e8734 100644 --- a/python/paddle/v2/image.py +++ b/python/paddle/v2/image.py @@ -5,10 +5,14 @@ except ImportError: cv2 = None from cv2 import resize +import os +import tarfile +import cPickle __all__ = [ "load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop", - "random_crop", "left_right_flip", "simple_transform", "load_and_transform" + "random_crop", "left_right_flip", "simple_transform", "load_and_transform", + "batch_images_from_tar" ] """ This file contains some common interfaces for image preprocess. @@ -28,6 +32,68 @@ the image layout as follows. """ +def batch_images_from_tar(data_file, + dataset_name, + img2label, + num_per_batch=1024): + """ + Read images from tar file and batch them into batch file. + param data_file: path of image tar file + type data_file: string + param dataset_name: 'train','test' or 'valid' + type dataset_name: string + param img2label: a dic with image file name as key + and image's label as value + type img2label: dic + param num_per_batch: image number per batch file + type num_per_batch: int + return: path of list file containing paths of batch file + rtype: string + """ + batch_dir = data_file + "_batch" + out_path = "%s/%s" % (batch_dir, dataset_name) + meta_file = "%s/%s.txt" % (batch_dir, dataset_name) + + if os.path.exists(out_path): + return meta_file + else: + os.makedirs(out_path) + + tf = tarfile.open(data_file) + mems = tf.getmembers() + data = [] + labels = [] + file_id = 0 + for mem in mems: + if mem.name in img2label: + data.append(tf.extractfile(mem).read()) + labels.append(img2label[mem.name]) + if len(data) == num_per_batch: + output = {} + output['label'] = labels + output['data'] = data + cPickle.dump( + output, + open('%s/batch_%d' % (out_path, file_id), 'w'), + protocol=cPickle.HIGHEST_PROTOCOL) + file_id += 1 + data = [] + labels = [] + if len(data) > 0: + output = {} + output['label'] = labels + output['data'] = data + cPickle.dump( + output, + open('%s/batch_%d' % (out_path, file_id), 'w'), + protocol=cPickle.HIGHEST_PROTOCOL) + + with open(meta_file, 'a') as meta: + for file in os.listdir(out_path): + meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n") + return meta_file + + def load_image_bytes(bytes, is_color=True): """ Load an color or gray image from bytes array. @@ -36,7 +102,7 @@ def load_image_bytes(bytes, is_color=True): .. code-block:: python with open('cat.jpg') as f: - im = load_image(f.read()) + im = load_image_bytes(f.read()) :param bytes: the input image bytes array. :type file: str diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py index f06792314f..1b5df21b3d 100644 --- a/python/paddle/v2/reader/decorator.py +++ b/python/paddle/v2/reader/decorator.py @@ -21,8 +21,6 @@ import itertools import random from Queue import Queue from threading import Thread -from multiprocessing import Queue as MQueue -from multiprocessing import Process def map_readers(func, *readers): @@ -248,8 +246,8 @@ def xmap(mapper, reader, process_num, buffer_size): :rtype: callable """ end = XmapEndSignal() - in_queue = MQueue(buffer_size) - out_queue = MQueue(buffer_size) + in_queue = Queue(buffer_size) + out_queue = Queue(buffer_size) # define a worker to read samples from reader to in_queue def read_worker(reader, in_queue): @@ -276,7 +274,7 @@ def xmap(mapper, reader, process_num, buffer_size): # start several handle_workers workers = [] for i in xrange(process_num): - worker = Process( + worker = Thread( target=handle_worker, args=(in_queue, out_queue, mapper)) worker.daemon = True workers.append(worker) -- GitLab