From 2799b0ec50de669709d7e95ae82b7512426d5387 Mon Sep 17 00:00:00 2001
From: "wanghaoshuang@baidu.com"
Date: Wed, 24 May 2017 00:14:07 +0800
Subject: [PATCH] Add flowers dataset for image classification model

---
 python/paddle/v2/dataset/flowers.py           | 255 ++++++++++++++++++
 .../paddle/v2/dataset/tests/flowers_test.py   |  51 ++++
 python/paddle/v2/image.py                     |  36 ++-
 python/paddle/v2/reader/decorator.py          |  75 +++++-
 4 files changed, 409 insertions(+), 8 deletions(-)
 create mode 100644 python/paddle/v2/dataset/flowers.py
 create mode 100644 python/paddle/v2/dataset/tests/flowers_test.py

diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py
new file mode 100644
index 00000000000..3d38b5dab97
--- /dev/null
+++ b/python/paddle/v2/dataset/flowers.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Flowers dataset.
+
+This module will download the dataset from
+http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
+and parse the train/test set into paddle reader creators.
+
+This set contains images of flowers belonging to 102 different categories.
+The images were acquired by searching the web and taking pictures. There are a
+minimum of 40 images for each category.
+
+The database was used in:
+
+Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
+number of classes. Proceedings of the Indian Conference on Computer Vision,
+Graphics and Image Processing (2008)
+http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
+
+"""
+import cPickle
+import itertools
+from common import download
+import tarfile
+import scipy.io as scio
+from image import *
+import os
+from multiprocessing import Process
+from multiprocessing import Pool
+from multiprocessing import cpu_count
+import numpy as np
+import paddle.v2 as paddle
+__all__ = ['train', 'test', 'valid']
+
+DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
+LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
+SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
+DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
+LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
+SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
+
+
+def extract_file(tarFile):
+    '''
+    Extract a tar file into a directory next to it.
+
+    Example usage:
+
+    .. code-block:: python
+        tmp = extract_file("/home/work/test.tar.gz")
+
+    :param tarFile: target tar file
+    :type tarFile: string
+    :return: path of the extracted directory. For example:
+             '/home/work/test/' when the input is '/home/work/test.tar.gz'
+    :rtype: string
+    '''
+    base_dir = os.path.dirname(tarFile)
+    base_name = os.path.basename(tarFile)
+    if '.' in base_name:
+        base_name = base_name.split('.', 1)[0]
+    out_path = '/'.join([base_dir, base_name])
+    if not os.path.exists(out_path):
+        df = tarfile.open(tarFile, mode='r')
+        df.extractall(path=out_path)
+        df.close()
+    return out_path
+
+
+def default_mapper(sample):
+    '''
+    Map raw image bytes into the type needed by the model input layer.
+    '''
+    img, label = sample
+    img = paddle.image.load_image_bytes(img)
+    img = paddle.image.simple_transform(img, 256, 224, True)
+    return img.flatten().astype('float32'), label
+
+
+def reader_creator(data_file,
+                   label_file,
+                   setid_file,
+                   flag,
+                   mapper=default_mapper):
+    '''
+    1. extract 102flowers.tgz to 102flowers/
+    2. merge images into batch files in 102flowers_batch/
+    3. get a reader to read samples from the batch files
+
+    :param data_file: downloaded data file
+    :type data_file: string
+    :param label_file: downloaded label file
+    :type label_file: string
+    :param setid_file: downloaded setid file containing information
+                       about how to split the dataset
+    :type setid_file: string
+    :param flag: data set name (tstid|trnid|valid)
+    :type flag: string
+    :param mapper: a function to map raw image bytes into the type
+                   needed by the model input layer
+    :type mapper: callable
+    :return: data reader
+    :rtype: callable
+    '''
+    base_dir = os.path.dirname(data_file)
+    tmp_dir = extract_file(data_file)
+    file_list = create_batch(tmp_dir, label_file, setid_file, flag)
+
+    def reader():
+        for file in open(file_list):
+            file = file.strip()
+            batch = None
+            with open(file, 'rb') as f:
+                batch = cPickle.load(f)
+            data = batch['data']
+            labels = batch['label']
+            for sample, label in itertools.izip(data, labels):
+                yield sample, int(label)
+
+    return paddle.reader.xmap(mapper, reader, cpu_count(), 1024 * 8)
+
+
+def create_batch(data_dir,
+                 label_file,
+                 setid_file,
+                 flag,
+                 numPerBatch=1024,
+                 nThread=16):
+    batch_dir = data_dir + "_batch"
+    labels = scio.loadmat(label_file)['labels'][0]
+    indexes = scio.loadmat(setid_file)[flag][0]
+    count = len(indexes)
+    out_path = "%s/%s" % (batch_dir, flag)
+    meta_file = "%s/%s.txt" % (batch_dir, flag)
+
+    if os.path.exists(out_path):
+        return meta_file
+    else:
+        os.makedirs(out_path)
+
+    def batch(file_out, start, end):
+        data = []
+        labellist = []
+        for index in indexes[start:end]:
+            img_name = "%s/jpg/image_%05d.jpg" % (data_dir, index)
+            with open(img_name, 'rb') as f:
+                data.append(f.read())
+            labellist.append(labels[index - 1])
+        output = {}
+        output['label'] = labellist
+        output['data'] = data
+        cPickle.dump(
+            output, open(file_out, 'wb'), protocol=cPickle.HIGHEST_PROTOCOL)
+
+    cur_id = 0
+    file_id = 0
+    while cur_id < count:
+        thread = []
+        for i in xrange(nThread):
+            end_id = min(cur_id + numPerBatch, count)
+            batch_file_name = "%s/batch_%05d" % (out_path, file_id)
+            w = Process(target=batch, args=(batch_file_name, cur_id, end_id))
+            w.daemon = True
+            thread.append(w)
+            cur_id = end_id
+            file_id += 1
+            if cur_id == count:
+                break
+        for t in thread:
+            t.start()
+        for t in thread:
+            t.join()
+    with open(meta_file, 'a') as meta:
+        for file in os.listdir(out_path):
+            meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n")
+    return meta_file
+
+
+def train(mapper=default_mapper):
+    '''
+    Create flowers training set reader.
+    It returns a reader in which each sample is
+    image pixels in [0, 1] and a label in [1, 102],
+    produced from the original color image by these steps:
+    1. resize to 256*256
+    2. random crop to 224*224
+    3. flatten
+    :param mapper: a function to map a sample into model input data.
+    :type mapper: callable
+    :return: train data reader
+    :rtype: callable
+    '''
+    return reader_creator(
+        download(DATA_URL, 'flowers', DATA_MD5),
+        download(LABEL_URL, 'flowers', LABEL_MD5),
+        download(SETID_URL, 'flowers', SETID_MD5), 'trnid')
+
+
+def test(mapper=default_mapper):
+    '''
+    Create flowers test set reader.
+    It returns a reader in which each sample is
+    image pixels in [0, 1] and a label in [1, 102],
+    produced from the original color image by these steps:
+    1. resize to 256*256
+    2. random crop to 224*224
+    3. flatten
+    :param mapper: a function to map a sample into model input data.
+    :type mapper: callable
+    :return: test data reader
+    :rtype: callable
+    '''
+    return reader_creator(
+        download(DATA_URL, 'flowers', DATA_MD5),
+        download(LABEL_URL, 'flowers', LABEL_MD5),
+        download(SETID_URL, 'flowers', SETID_MD5), 'tstid')
+
+
+def valid():
+    '''
+    Create flowers validation set reader.
+    It returns a reader in which each sample is
+    image pixels in [0, 1] and a label in [1, 102],
+    produced from the original color image by these steps:
+    1. resize to 256*256
+    2. random crop to 224*224
+    3. flatten
+    '''
+    return reader_creator(
+        download(DATA_URL, 'flowers', DATA_MD5),
+        download(LABEL_URL, 'flowers', LABEL_MD5),
+        download(SETID_URL, 'flowers', SETID_MD5), 'valid')
+
+
+def fetch():
+    download(DATA_URL, 'flowers', DATA_MD5)
+    download(LABEL_URL, 'flowers', LABEL_MD5)
+    download(SETID_URL, 'flowers', SETID_MD5)
+
+
+if __name__ == '__main__':
+    for i in test()():
+        pass
diff --git a/python/paddle/v2/dataset/tests/flowers_test.py b/python/paddle/v2/dataset/tests/flowers_test.py
new file mode 100644
index 00000000000..cc0626f4fea
--- /dev/null
+++ b/python/paddle/v2/dataset/tests/flowers_test.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.v2.dataset.flowers
+import unittest
+
+
+class TestFlowers(unittest.TestCase):
+    def check_reader(self, reader):
+        sum = 0
+        label = 0
+        size = 224 * 224 * 3
+        for l in reader():
+            self.assertEqual(l[0].size, size)
+            if l[1] > label:
+                label = l[1]
+            sum += 1
+        return sum, label
+
+    def test_train(self):
+        instances, max_label_value = self.check_reader(
+            paddle.v2.dataset.flowers.train())
+        self.assertEqual(instances, 1020)
+        self.assertEqual(max_label_value, 102)
+
+    def test_test(self):
+        instances, max_label_value = self.check_reader(
+            paddle.v2.dataset.flowers.test())
+        self.assertEqual(instances, 6149)
+        self.assertEqual(max_label_value, 102)
+
+    def test_valid(self):
+        instances, max_label_value = self.check_reader(
+            paddle.v2.dataset.flowers.valid())
+        self.assertEqual(instances, 1020)
+        self.assertEqual(max_label_value, 102)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/image.py b/python/paddle/v2/image.py
index 85ad6984ba0..cb5725de686 100644
--- a/python/paddle/v2/image.py
+++ b/python/paddle/v2/image.py
@@ -1,14 +1,14 @@
 import numpy as np
 try:
     import cv2
-except:
-    print(
-        "import cv2 error, please install opencv-python: pip install opencv-python"
-    )
+except ImportError:
+    cv2 = None
+
+from cv2 import resize
 
 __all__ = [
-    "load_image", "resize_short", "to_chw", "center_crop", "random_crop",
-    "left_right_flip", "simple_transform", "load_and_transform"
+    "load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop",
+    "random_crop", "left_right_flip", "simple_transform", "load_and_transform"
 ]
 """
 This file contains some common interfaces for image preprocess.
@@ -28,6 +28,28 @@ the image layout as follows.
 """
 
 
+def load_image_bytes(bytes, is_color=True):
+    """
+    Load a color or gray image from a bytes array.
+
+    Example usage:
+
+    .. code-block:: python
+        with open('cat.jpg') as f:
+            im = load_image_bytes(f.read())
+
+    :param bytes: the input image bytes array.
+    :type bytes: str
+    :param is_color: If is_color is True, it will load and
+                     return a color image. Otherwise, it will
+                     load and return a gray image.
+    """
+    flag = 1 if is_color else 0
+    file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8)
+    img = cv2.imdecode(file_bytes, flag)
+    return img
+
+
 def load_image(file, is_color=True):
     """
     Load an color or gray image from the file path.
@@ -76,7 +98,7 @@ def resize_short(im, size):
         h_new = size * h / w
     else:
         w_new = size * w / h
-    im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
+    im = resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
     return im
 
 
diff --git a/python/paddle/v2/reader/decorator.py b/python/paddle/v2/reader/decorator.py
index 104ce9a0411..f06792314fd 100644
--- a/python/paddle/v2/reader/decorator.py
+++ b/python/paddle/v2/reader/decorator.py
@@ -14,13 +14,15 @@
 
 __all__ = [
     'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
-    'ComposeNotAligned', 'firstn'
+    'ComposeNotAligned', 'firstn', 'xmap'
 ]
 
 import itertools
 import random
 from Queue import Queue
 from threading import Thread
+from multiprocessing import Queue as MQueue
+from multiprocessing import Process
 
 
 def map_readers(func, *readers):
@@ -224,3 +226,74 @@ def firstn(reader, n):
             yield item
 
     return firstn_reader
+
+
+class XmapEndSignal():
+    pass
+
+
+def xmap(mapper, reader, process_num, buffer_size):
+    """
+    Use multiple processes to map samples from the reader with a
+    user-defined mapper. The mapped samples are buffered in bounded queues.
+    :param mapper: a function to map a sample.
+    :type mapper: callable
+    :param reader: the data reader to read from
+    :type reader: callable
+    :param process_num: number of worker processes used to map samples
+    :type process_num: int
+    :param buffer_size: max size of the input and output queues
+    :type buffer_size: int
+    :return: the decorated reader
+    :rtype: callable
+    """
+    end = XmapEndSignal()
+    in_queue = MQueue(buffer_size)
+    out_queue = MQueue(buffer_size)
+
+    # define a worker to read samples from reader to in_queue
+    def read_worker(reader, in_queue):
+        for i in reader():
+            in_queue.put(i)
+        in_queue.put(end)
+
+    # start a read worker in a thread
+    t = Thread(target=read_worker, args=(reader, in_queue))
+    t.daemon = True
+    t.start()
+
+    # define a worker to handle samples from in_queue by mapper
+    # and put mapped samples into out_queue
+    def handle_worker(in_queue, out_queue, mapper):
+        sample = in_queue.get()
+        while not isinstance(sample, XmapEndSignal):
+            r = mapper(sample)
+            out_queue.put(r)
+            sample = in_queue.get()
+        in_queue.put(end)
+        out_queue.put(end)
+
+    # start several handle_workers
+    workers = []
+    for i in xrange(process_num):
+        worker = Process(
+            target=handle_worker, args=(in_queue, out_queue, mapper))
+        worker.daemon = True
+        workers.append(worker)
+    for w in workers:
+        w.start()
+
+    def xreader():
+        sample = out_queue.get()
+        while not isinstance(sample, XmapEndSignal):
+            yield sample
+            sample = out_queue.get()
+        finish = 1
+        while finish < process_num:
+            sample = out_queue.get()
+            if isinstance(sample, XmapEndSignal):
+                finish += 1
+            else:
+                yield sample
+
+    return xreader
-- 
GitLab
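
Usage note (not part of the patch): a minimal sketch, in the same Python 2 style as the code above, of how the new flowers reader creators are meant to be consumed. It assumes paddle.v2 is installed and that the dataset archive can be downloaded to the local cache on first use.

import paddle.v2.dataset.flowers as flowers

# flowers.train() downloads the archive (several hundred MB) and builds
# batch files on first use, then returns a callable reader.
train_reader = flowers.train()

for pixels, label in train_reader():
    # pixels: flattened float32 array of size 3 * 224 * 224 (see the test above)
    # label:  integer class id in [1, 102]
    print pixels.shape, label
    break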
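Similarly, a toy sketch of the xmap decorator added to decorator.py. The numbers reader and square mapper below are illustrative only, and the sort is needed because the worker processes return mapped samples in no particular order.

from paddle.v2.reader.decorator import xmap


def numbers():
    # trivial reader: yields the integers 0..99
    for i in xrange(100):
        yield i


def square(sample):
    # trivial mapper, executed inside the worker processes
    return sample * sample


if __name__ == '__main__':
    # 4 worker processes, input/output queues bounded at 32 samples
    parallel_reader = xmap(square, numbers, 4, 32)
    results = sorted(parallel_reader())
    assert results == [i * i for i in xrange(100)]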