提交 2799b0ec 编写于 作者: W wanghaoshuang@baidu.com 提交者: wanghaoshuang

Add flowers dataset for image classification model

上级 b15b2637
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
CIFAR dataset.
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set intopaddle reader creators.
This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.
The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes.Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import cPickle
import itertools
from common import download
import tarfile
import scipy.io as scio
from image import *
import os
from multiprocessing import Process
from multiprocessing import Pool
from multiprocessing import cpu_count
import numpy as np
import paddle.v2 as paddle
__all__ = ['train', 'test', 'valid']
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
def extract_file(tarFile):
'''
Extract tar file to tmp dir.
Example usage:
.. code-block:: python
tmp = extract_file("/home/work/test.tar.gz")
:param tarFile: target tar file
:type tarFile: string
:return: extracted dir. For example:
'/home/work/test/' while input is '/home/work/test.tar.gz'
:rtype: string
'''
base_dir = os.path.dirname(tarFile)
base_name = os.path.basename(tarFile)
if '.' in base_name:
base_name = base_name.split('.', 1)[0]
out_path = '/'.join([base_dir, base_name])
if not os.path.exists(out_path):
df = tarfile.open(tarFile, mode='r')
df.extractall(path=out_path)
df.close()
return out_path
def default_mapper(sample):
'''
map image bytes data to type needed by model input layer
'''
img, label = sample
img = paddle.image.load_image_bytes(img)
img = paddle.image.simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label
def reader_creator(data_file,
label_file,
setid_file,
flag,
mapper=default_mapper):
'''
1. extract 102flowers.tgz to 102flowers/
2. merge images into batch files in 102flowers_batch/
3. get a reader to read sample from batch file
:param data_file: downloaded data file
:type data_file: string
:param label_file: downloaded label file
:type label_file: string
:param setid_file: downloaded setid file containing information
about how to split dataset
:type setid_file: string
:param flag: data set name (tstid|trnid|valid)
:type flag: string
:param mapper: a function to map image bytes data to type
needed by model input layer
:type mapper: callable
:return: data reader
:rtype: callable
'''
base_dir = os.path.dirname(data_file)
tmp_dir = extract_file(data_file)
file_list = create_batch(tmp_dir, label_file, setid_file, flag)
def reader():
for file in open(file_list):
file = file.strip()
batch = None
with open(file, 'r') as f:
batch = cPickle.load(f)
data = batch['data']
labels = batch['label']
for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label)
return paddle.reader.xmap(mapper, reader, cpu_count(), 1024 * 8)
def create_batch(data_dir,
label_file,
setid_file,
flag,
numPerBatch=1024,
nThread=16):
batch_dir = data_dir + "_batch"
labels = scio.loadmat(label_file)['labels'][0]
indexes = scio.loadmat(setid_file)[flag][0]
count = len(indexes)
out_path = "%s/%s" % (batch_dir, flag)
meta_file = "%s/%s.txt" % (batch_dir, flag)
if os.path.exists(out_path):
return meta_file
else:
os.makedirs(out_path)
def batch(file_out, start, end):
data = []
labellist = []
for index in indexes[start:end]:
img_name = "%s/jpg/image_%05d.jpg" % (data_dir, index)
with open(img_name, 'r') as f:
data.append(f.read())
labellist.append(labels[index - 1])
output = {}
output['label'] = labellist
output['data'] = data
cPickle.dump(
output, open(file_out, 'w'), protocol=cPickle.HIGHEST_PROTOCOL)
cur_id = 0
file_id = 0
while cur_id < count:
thread = []
for i in xrange(nThread):
end_id = min(cur_id + numPerBatch, count)
batch_file_name = "%s/batch_%05d" % (out_path, file_id)
w = Process(target=batch, args=(batch_file_name, cur_id, end_id))
w.daemon = True
thread.append(w)
cur_id = end_id
file_id += 1
if cur_id == count:
break
for t in thread:
t.start()
for t in thread:
t.join()
with open(meta_file, 'a') as meta:
for file in os.listdir(out_path):
meta.write(os.path.abspath("%s/%s" % (out_path, file)) + "\n")
return meta_file
def train(mapper=default_mapper):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:return: train data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid')
def test(mapper=default_mapper):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:return: test data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid')
def valid():
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid')
def fetch():
download(DATA_URL, 'flowers', DATA_MD5)
download(LABEL_URL, 'flowers', LABEL_MD5)
download(SETID_URL, 'flowers', SETID_MD5)
if __name__ == '__main__':
for i in test()():
pass
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.dataset.flowers
import unittest
class TestFlowers(unittest.TestCase):
def check_reader(self, reader):
sum = 0
label = 0
size = 224 * 224 * 3
for l in reader():
self.assertEqual(l[0].size, size)
if l[1] > label:
label = l[1]
sum += 1
return sum, label
def test_train(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102)
def test_test(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102)
def test_valid(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.valid())
self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102)
if __name__ == '__main__':
unittest.main()
import numpy as np
try:
import cv2
except:
print(
"import cv2 error, please install opencv-python: pip install opencv-python"
)
except ImportError:
cv2 = None
from cv2 import resize
__all__ = [
"load_image", "resize_short", "to_chw", "center_crop", "random_crop",
"left_right_flip", "simple_transform", "load_and_transform"
"load_image_bytes", "load_image", "resize_short", "to_chw", "center_crop",
"random_crop", "left_right_flip", "simple_transform", "load_and_transform"
]
"""
This file contains some common interfaces for image preprocess.
......@@ -28,6 +28,28 @@ the image layout as follows.
"""
def load_image_bytes(bytes, is_color=True):
"""
Load an color or gray image from bytes array.
Example usage:
.. code-block:: python
with open('cat.jpg') as f:
im = load_image(f.read())
:param bytes: the input image bytes array.
:type file: str
:param is_color: If set is_color True, it will load and
return a color image. Otherwise, it will
load and return a gray image.
"""
flag = 1 if is_color else 0
file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8)
img = cv2.imdecode(file_bytes, flag)
return img
def load_image(file, is_color=True):
"""
Load an color or gray image from the file path.
......@@ -76,7 +98,7 @@ def resize_short(im, size):
h_new = size * h / w
else:
w_new = size * w / h
im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
im = resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
return im
......
......@@ -14,13 +14,15 @@
__all__ = [
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
'ComposeNotAligned', 'firstn'
'ComposeNotAligned', 'firstn', 'xmap'
]
import itertools
import random
from Queue import Queue
from threading import Thread
from multiprocessing import Queue as MQueue
from multiprocessing import Process
def map_readers(func, *readers):
......@@ -224,3 +226,74 @@ def firstn(reader, n):
yield item
return firstn_reader
class XmapEndSignal():
pass
def xmap(mapper, reader, process_num, buffer_size):
"""
Use multiprocess to map samples from reader by a mapper defined by user.
And this function contains a buffered decorator.
:param mapper: a function to map sample.
:type mapper: callable
:param reader: the data reader to read from
:type reader: callable
:param process_num: process number to handle original sample
:type process_num: int
:param buffer_size: max buffer size
:type buffer_size: int
:return: the decarated reader
:rtype: callable
"""
end = XmapEndSignal()
in_queue = MQueue(buffer_size)
out_queue = MQueue(buffer_size)
# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
for i in reader():
in_queue.put(i)
in_queue.put(end)
# start a read worker in a thread
t = Thread(target=read_worker, args=(reader, in_queue))
t.daemon = True
t.start()
# define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue
def handle_worker(in_queue, out_queue, mapper):
sample = in_queue.get()
while not isinstance(sample, XmapEndSignal):
r = mapper(sample)
out_queue.put(r)
sample = in_queue.get()
in_queue.put(end)
out_queue.put(end)
# start several handle_workers
workers = []
for i in xrange(process_num):
worker = Process(
target=handle_worker, args=(in_queue, out_queue, mapper))
worker.daemon = True
workers.append(worker)
for w in workers:
w.start()
def xreader():
sample = out_queue.get()
while not isinstance(sample, XmapEndSignal):
yield sample
sample = out_queue.get()
finish = 1
while finish < process_num:
sample = out_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
yield sample
return xreader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册