flowers.py 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module will download dataset from
16
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
17 18
and parse train/test set intopaddle reader creators.

19
This set contains images of flowers belonging to 102 different categories.
20 21 22 23 24 25
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.

The database was used in:

Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
26 27
 number of classes.Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
28 29 30 31
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.

"""
import itertools
import functools
from common import download
import tarfile
import six
from six.moves import cPickle as pickle
import scipy.io as scio

# The star imports below presumably supply load_image_bytes, simple_transform,
# batch_images_from_tar (paddle.dataset.image) and xmap_readers, map_readers
# (paddle.reader) used later in this module -- TODO confirm and make explicit.
from paddle.dataset.image import *
from paddle.reader import *
import os
import numpy as np
from multiprocessing import cpu_count

# Public API of this module (fetch() is intentionally not exported).
__all__ = ['train', 'test', 'valid']

# Download locations of the image archive, the per-image label file and the
# split-definition file, with the MD5 checksums passed to download()
# alongside each URL.
DATA_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/102flowers.tgz'
LABEL_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/imagelabels.mat'
SETID_URL = 'http://paddlemodels.cdn.bcebos.com/flowers/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'

# In the official readme, 'tstid' marks the test split and 'trnid' marks the
# train split. However, the official test split is larger than the train
# split, so this module swaps them: 'tstid' is used for training and 'trnid'
# for testing. These names are also the keys looked up in setid.mat.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'


def default_mapper(is_train, sample):
    '''
    Map image bytes data to the type needed by the model input layer.

    :param is_train: forwarded to simple_transform; True selects the
                     training-time transform, False the evaluation-time one
    :type is_train: bool
    :param sample: an (image_bytes, label) pair
    :type sample: tuple
    :return: (flattened float32 image, label); the label is passed
             through unchanged
    :rtype: tuple
    '''
    img, label = sample
    img = load_image_bytes(img)
    # 256 / 224 are presumably the resize and crop sizes (see the train()
    # docstring); `mean` holds one value per colour channel.
    img = simple_transform(
        img, 256, 224, is_train, mean=[103.94, 116.78, 123.68])
    return img.flatten().astype('float32'), label


# Default mapper for the training split: default_mapper with is_train=True.
train_mapper = functools.partial(default_mapper, True)
# Default mapper for the test and validation splits: is_train=False.
test_mapper = functools.partial(default_mapper, False)


def reader_creator(data_file,
                   label_file,
                   setid_file,
                   dataset_name,
                   mapper,
                   buffered_size=1024,
                   use_xmap=True):
    '''
    1. read images from tar file and
        merge images into batch files in 102flowers.tgz_batch/
    2. get a reader to read sample from batch file

    :param data_file: downloaded data file
    :type data_file: string
    :param label_file: downloaded label file
    :type label_file: string
    :param setid_file: downloaded setid file containing information
                        about how to split dataset
    :type setid_file: string
    :param dataset_name: data set name (tstid|trnid|valid); used as the
                         key into the setid file
    :type dataset_name: string
    :param mapper: a function to map image bytes data to type
                    needed by model input layer
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: if True, map samples in parallel with xmap_readers
                     using int($CPU_NUM) workers (default: cpu_count());
                     otherwise map them sequentially with map_readers
    :type use_xmap: bool
    :return: data reader yielding mapped (image, label) samples, with
             labels shifted down by one relative to the label file
    :rtype: callable
    '''
    labels = scio.loadmat(label_file)['labels'][0]
    indexes = scio.loadmat(setid_file)[dataset_name][0]
    # setid entries are 1-based image ids: id i names jpg/image_%05d.jpg
    # inside the tarball and its label is labels[i - 1].
    img2label = {"jpg/image_%05d.jpg" % i: labels[i - 1] for i in indexes}
    file_list = batch_images_from_tar(data_file, dataset_name, img2label)

    def reader():
        # Read the (small) list of batch-file paths up front so the handle
        # is closed instead of being held open across yields.
        with open(file_list) as list_file:
            batch_paths = [line.strip() for line in list_file]
        for path in batch_paths:
            if not path:
                continue
            # NOTE: pickle is acceptable only because these batch files are
            # produced locally by batch_images_from_tar; never point
            # file_list at untrusted data.
            with open(path, 'rb') as f:
                if six.PY2:
                    batch = pickle.load(f)
                else:
                    batch = pickle.load(f, encoding='bytes')
            # six.moves.zip is lazy on both Python 2 and 3
            # (itertools.izip does not exist on Python 3).
            for sample, label in six.moves.zip(batch['data'],
                                               batch['label']):
                # Shift labels down by one (1-based in the label file).
                yield sample, int(label) - 1

    if use_xmap:
        # CPU_NUM, when set, overrides the number of mapping workers.
        cpu_num = int(os.environ.get('CPU_NUM', cpu_count()))
        return xmap_readers(mapper, reader, cpu_num, buffered_size)
    return map_readers(mapper, reader)


def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
    '''
    Create flowers training set reader.
    It returns a reader, each sample in the reader is an
    (image, label) pair, where label is an int in [0, 101]
    (the 1-based category id from imagelabels.mat minus one).
    Samples come from the official 'tstid' split, which this module
    uses for training because it is the larger split.
    With the default mapper the image is translated from the
    original color image by steps:
    1. resize (size 256)
    2. random crop to 224*224
    3. flatten to a float32 vector
    The default mapper also passes a per-channel mean to simple_transform,
    so pixel values are presumably mean-centred rather than in [0, 1].

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: whether to map samples in parallel (xmap_readers)
                     instead of sequentially (map_readers)
    :type use_xmap: bool
    :return: train data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
        buffered_size, use_xmap)


def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
    '''
    Create flowers test set reader.
    It returns a reader, each sample in the reader is an
    (image, label) pair, where label is an int in [0, 101]
    (the 1-based category id from imagelabels.mat minus one).
    Samples come from the official 'trnid' split, which this module
    uses for testing because it is the smaller split.
    With the default mapper the image is translated from the
    original color image by steps:
    1. resize (size 256)
    2. crop to 224*224 using the non-training transform
       (is_train=False; presumably a center crop)
    3. flatten to a float32 vector
    The default mapper also passes a per-channel mean to simple_transform,
    so pixel values are presumably mean-centred rather than in [0, 1].

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: whether to map samples in parallel (xmap_readers)
                     instead of sequentially (map_readers)
    :type use_xmap: bool
    :return: test data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
        buffered_size, use_xmap)


def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
    '''
    Create flowers validation set reader.
    It returns a reader, each sample in the reader is an
    (image, label) pair, where label is an int in [0, 101]
    (the 1-based category id from imagelabels.mat minus one).
    Samples come from the official 'valid' split.
    With the default mapper the image is translated from the
    original color image by steps:
    1. resize (size 256)
    2. crop to 224*224 using the non-training transform
       (is_train=False; presumably a center crop)
    3. flatten to a float32 vector
    The default mapper also passes a per-channel mean to simple_transform,
    so pixel values are presumably mean-centred rather than in [0, 1].

    :param mapper:  a function to map sample.
    :type mapper: callable
    :param buffered_size: the size of buffer used to process images
    :type buffered_size: int
    :param use_xmap: whether to map samples in parallel (xmap_readers)
                     instead of sequentially (map_readers)
    :type use_xmap: bool
    :return: validation data reader
    :rtype: callable
    '''
    return reader_creator(
        download(DATA_URL, 'flowers', DATA_MD5),
        download(LABEL_URL, 'flowers', LABEL_MD5),
        download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
        buffered_size, use_xmap)


def fetch():
    '''
    Download the image archive, label file and split file into the
    'flowers' data directory, in that order, without building a reader.
    '''
    resources = (
        (DATA_URL, DATA_MD5),
        (LABEL_URL, LABEL_MD5),
        (SETID_URL, SETID_MD5),
    )
    for url, md5 in resources:
        download(url, 'flowers', md5)