flowers.py 6.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module downloads the dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parses the train/test/valid sets into paddle reader creators.

This set contains images of flowers belonging to 102 different categories. 
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.

The database was used in:

Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes. Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.

"""
import cPickle
import itertools
from common import download
import tarfile
import scipy.io as scio
36
from paddle.v2.image import *
37 38 39
import os
import numpy as np
import paddle.v2 as paddle
40
from multiprocessing import cpu_count
41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
# Names exported by ``from flowers import *``.
__all__ = ['train', 'test', 'valid']

# All three dataset files live under the same directory on the VGG server.
_FLOWERS_URL_PREFIX = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/'

# Image archive, label file and split (set id) file.
DATA_URL = _FLOWERS_URL_PREFIX + '102flowers.tgz'
LABEL_URL = _FLOWERS_URL_PREFIX + 'imagelabels.mat'
SETID_URL = _FLOWERS_URL_PREFIX + 'setid.mat'

# MD5 digests handed to download() together with the matching URL.
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'


def default_mapper(sample):
    '''
    Convert one raw sample into the form fed to the model input layer.

    The image bytes are decoded with paddle.image.load_image_bytes, passed
    through paddle.image.simple_transform(img, 256, 224, True), then
    flattened and cast to float32. The label is returned untouched.

    :param sample: pair of (encoded image bytes, label)
    :type sample: tuple
    :return: pair of (flattened float32 image array, label)
    :rtype: tuple
    '''
    raw_bytes, label = sample
    decoded = paddle.image.load_image_bytes(raw_bytes)
    # 256 / 224 / True are forwarded as-is; presumably resize size, crop
    # size and the is_train flag of simple_transform -- confirm against
    # the paddle.v2.image API.
    transformed = paddle.image.simple_transform(decoded, 256, 224, True)
    flat_pixels = transformed.flatten().astype('float32')
    return flat_pixels, label


def reader_creator(data_file,
                   label_file,
                   setid_file,
                   dataset_name,
                   mapper=default_mapper,
                   buffered_size=1024):
    '''
    1. read images from tar file and
        merge images into batch files in 102flowers.tgz_batch/
    2. get a reader to read sample from batch file

    :param data_file: downloaded data file
    :type data_file: string
    :param label_file: downloaded label file
    :type label_file: string
    :param setid_file: downloaded setid file containing information
                        about how to split dataset
    :type setid_file: string
    :param dataset_name: data set name (tstid|trnid|valid)
    :type dataset_name: string
    :param mapper: a function to map image bytes data to type
                    needed by model input layer
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: data reader
    :rtype: callable
    '''
    labels = scio.loadmat(label_file)['labels'][0]
    indexes = scio.loadmat(setid_file)[dataset_name][0]

    # Image number i is stored in the archive as jpg/image_<i>.jpg and its
    # label sits at position i - 1 of the label array.
    img2label = {}
    for i in indexes:
        img = "jpg/image_%05d.jpg" % i
        img2label[img] = labels[i - 1]

    # file_list is the path of a text file with one batch file path per line.
    file_list = batch_images_from_tar(data_file, dataset_name, img2label)

    def reader():
        # Read the whole list up front so the handle is closed right away
        # instead of staying open while the generator is suspended.
        with open(file_list) as list_file:
            batch_paths = [line.strip() for line in list_file if line.strip()]

        for batch_path in batch_paths:
            # Pickle data is binary, so the file must be opened in 'rb'.
            # NOTE(review): cPickle.load must only be used on trusted files;
            # these are expected to be the batches written locally by
            # batch_images_from_tar.
            with open(batch_path, 'rb') as f:
                batch = cPickle.load(f)
            for sample, label in itertools.izip(batch['data'],
                                                batch['label']):
                yield sample, int(label)

    # mapper is applied to every sample through xmap; cpu_count() is passed
    # as its third argument (presumably the worker count -- confirm).
    return paddle.reader.xmap(mapper, reader, cpu_count(), buffered_size)
109 110


111
def train(mapper=default_mapper, buffered_size=1024):
    '''
    Create the reader for the flowers training split ('trnid').

    The image archive, label file and set id file are obtained through
    download() and passed to reader_creator(). Every sample yielded by the
    reader is the result of ``mapper`` applied to (image bytes, label).
    With the default mapper the image goes through
    simple_transform(img, 256, 224, True) and is flattened to float32;
    the original documentation describes this as:
    1. resize to 256*256
    2. random crop to 224*224
    3. flatten
    Labels are expected to lie in [1, 102].
    NOTE(review): the previous docstring said pixels are in [0, 1]; no
    explicit scaling is visible in the default mapper -- confirm.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: train data reader
    :rtype: callable
    '''
    data_file = download(DATA_URL, 'flowers', DATA_MD5)
    label_file = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_file = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(
        data_file,
        label_file,
        setid_file,
        'trnid',
        mapper=mapper,
        buffered_size=buffered_size)
132 133


134
def test(mapper=default_mapper, buffered_size=1024):
    '''
    Create the reader for the flowers test split ('tstid').

    The image archive, label file and set id file are obtained through
    download() and passed to reader_creator(). Every sample yielded by the
    reader is the result of ``mapper`` applied to (image bytes, label).
    With the default mapper the image goes through
    simple_transform(img, 256, 224, True) and is flattened to float32;
    the original documentation describes this as:
    1. resize to 256*256
    2. random crop to 224*224
    3. flatten
    Labels are expected to lie in [1, 102].
    NOTE(review): the previous docstring said pixels are in [0, 1]; no
    explicit scaling is visible in the default mapper -- confirm.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: test data reader
    :rtype: callable
    '''
    data_file = download(DATA_URL, 'flowers', DATA_MD5)
    label_file = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_file = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(
        data_file,
        label_file,
        setid_file,
        'tstid',
        mapper=mapper,
        buffered_size=buffered_size)
155 156


157
def valid(mapper=default_mapper, buffered_size=1024):
    '''
    Create the reader for the flowers validation split ('valid').

    The image archive, label file and set id file are obtained through
    download() and passed to reader_creator(). Every sample yielded by the
    reader is the result of ``mapper`` applied to (image bytes, label).
    With the default mapper the image goes through
    simple_transform(img, 256, 224, True) and is flattened to float32;
    the original documentation describes this as:
    1. resize to 256*256
    2. random crop to 224*224
    3. flatten
    Labels are expected to lie in [1, 102].
    NOTE(review): the previous docstring said pixels are in [0, 1]; no
    explicit scaling is visible in the default mapper -- confirm.

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :return: validation data reader
    :rtype: callable
    '''
    data_file = download(DATA_URL, 'flowers', DATA_MD5)
    label_file = download(LABEL_URL, 'flowers', LABEL_MD5)
    setid_file = download(SETID_URL, 'flowers', SETID_MD5)
    return reader_creator(
        data_file,
        label_file,
        setid_file,
        'valid',
        mapper=mapper,
        buffered_size=buffered_size)
178 179 180 181 182 183


def fetch():
    '''
    Call download() for the image archive, the label file and the set id
    file, in that order. The values returned by download() are discarded.
    '''
    targets = (
        (DATA_URL, DATA_MD5),
        (LABEL_URL, LABEL_MD5),
        (SETID_URL, SETID_MD5), )
    for url, md5 in targets:
        download(url, 'flowers', md5)