# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import image_util
from paddle.utils.image_util import *
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
import random
import cv2
import six
from data_util import GeneratorEnqueuer


class Settings(object):
    """Configuration bundle consumed by the readers in this module.

    Besides storing the constructor arguments, the instance carries the
    fixed augmentation hyper-parameters (probabilities and deltas) that
    are passed on to ``image_util`` through the ``settings`` object.

    Args:
        dataset: Dataset identifier, stored as ``self.dataset``.
        data_dir: Root directory that image names are joined onto.
        label_file: Accepted for interface compatibility; it is not stored
            on the instance.
        resize_h: Target image height, stored as ``self.resize_height``.
        resize_w: Target image width, stored as ``self.resize_width``.
        mean_value: Per-channel mean. It is reshaped to ``(C, 1, 1)`` and
            cast to float32 so it broadcasts over a CHW image.
        apply_distort: Whether ``preprocess`` applies colour distortion
            in train mode.
        apply_expand: Whether ``preprocess`` applies image expansion in
            train mode.
        ap_version: Stored as ``self.ap_version``.
        toy: Stored as ``self.toy``.
    """

    def __init__(self,
                 dataset=None,
                 data_dir=None,
                 label_file=None,
                 resize_h=None,
                 resize_w=None,
                 mean_value=(104., 117., 123.),
                 apply_distort=True,
                 apply_expand=True,
                 ap_version='11point',
                 toy=0):
        # The default mean is an immutable tuple (same values as before) so
        # that no mutable object is shared between calls.
        self.dataset = dataset
        self.ap_version = ap_version
        self.toy = toy
        self.data_dir = data_dir
        self.apply_distort = apply_distort
        self.apply_expand = apply_expand
        self.resize_height = resize_h
        self.resize_width = resize_w
        # Shape (C, 1, 1) so the mean broadcasts against a CHW image.
        self.img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
            'float32')
        self.expand_prob = 0.5
        self.expand_max_ratio = 4
        self.hue_prob = 0.5
        self.hue_delta = 18
        self.contrast_prob = 0.5
        self.contrast_delta = 0.5
        self.saturation_prob = 0.5
        self.saturation_delta = 0.5
        self.brightness_prob = 0.5
        # brightness_delta is the value normalized by 256
        self.brightness_delta = 0.125
        self.scale = 0.007843  # 1 / 127.5
        self.data_anchor_sampling_prob = 0.5
        self.min_face_size = 8.0
70 71


def to_chw_bgr(image):
    """
    Convert an image array from HWC/RGB layout to CHW/BGR layout.

    Args:
        image (np.array): image data; a 3-D input is read as
            height x width x channel with the channels in RGB order.

    Returns:
        np.array: the array with the channel axis first and channels
            2, 1, 0 of the input selected in that order.
    """
    # HWC -> CHW: bring the channel axis to the front. Inputs of any other
    # rank reach the channel indexing below untouched.
    if image.ndim == 3:
        image = np.transpose(image, (2, 0, 1))
    # RGB -> BGR: pick the first three channels in reverse order.
    bgr_order = [2, 1, 0]
    return image[bgr_order, :, :]


def preprocess(img, bbox_labels, mode, settings, image_path):
    """Augment (train mode only), resize and normalize one image.

    In ``'train'`` mode the image is optionally distorted and expanded, a
    crop is drawn with one of two sampling strategies, and the result is
    mirrored horizontally with probability 0.5. In every mode the image is
    resized to ``settings.resize_width`` x ``settings.resize_height`` with
    a randomly chosen interpolation filter, converted to CHW/BGR float32,
    mean-subtracted and multiplied by ``settings.scale``.

    Args:
        img: PIL image to process.
        bbox_labels: Label rows for the image. Judging by the mirroring
            code, indices 1 and 3 hold normalized xmin and xmax
            (layout presumably ``label | xmin | ymin | xmax | ymax``).
        mode: ``'train'`` enables augmentation; any other value skips it.
        settings: ``Settings`` instance with sizes, mean, scale and the
            augmentation switches.
        image_path: Unused; kept for interface compatibility.

    Returns:
        tuple: ``(img, sampled_labels)`` - the normalized float32 CHW array
        and the label rows that go with it. When mirroring happens the
        rows are updated in place.
    """
    img_width, img_height = img.size
    sampled_labels = bbox_labels
    if mode == 'train':
        if settings.apply_distort:
            img = image_util.distort_image(img, settings)
        if settings.apply_expand:
            img, bbox_labels, img_width, img_height = image_util.expand_image(
                img, bbox_labels, img_width, img_height, settings)
            # Keep the fallback labels in step with the (possibly) expanded
            # image; otherwise the pre-expansion labels would be returned
            # whenever the samplers below yield no crop.
            sampled_labels = bbox_labels

        # sampling
        batch_sampler = []

        prob = np.random.uniform(0., 1.)
        if prob > settings.data_anchor_sampling_prob:
            scale_array = np.array([16, 32, 64, 128, 256, 512])
            batch_sampler.append(
                image_util.sampler(1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2,
                                   0.0, True))
            sampled_bbox = image_util.generate_batch_random_samples(
                batch_sampler, bbox_labels, img_width, img_height, scale_array,
                settings.resize_width, settings.resize_height)
            img = np.array(img)
            if len(sampled_bbox) > 0:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                img, sampled_labels = image_util.crop_image_sampling(
                    img, bbox_labels, sampled_bbox[idx], img_width, img_height,
                    settings.resize_width, settings.resize_height,
                    settings.min_face_size)

            img = img.astype('uint8')
            img = Image.fromarray(img)

        else:
            # hard-code here: one sampler with 1.0 in the third slot
            # followed by four identical ones with 0.3.
            batch_sampler.append(
                image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            for _ in range(4):
                batch_sampler.append(
                    image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0,
                                       1.0, 0.0, True))
            sampled_bbox = image_util.generate_batch_samples(
                batch_sampler, bbox_labels, img_width, img_height)

            img = np.array(img)
            if len(sampled_bbox) > 0:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                img, sampled_labels = image_util.crop_image(
                    img, bbox_labels, sampled_bbox[idx], img_width, img_height,
                    settings.resize_width, settings.resize_height,
                    settings.min_face_size)

            img = Image.fromarray(img)

    # The interpolation filter is drawn at random for every call,
    # including non-train modes.
    interp_mode = [
        Image.BILINEAR, Image.HAMMING, Image.NEAREST, Image.BICUBIC,
        Image.LANCZOS
    ]
    interp_indx = np.random.randint(0, len(interp_mode))

    img = img.resize(
        (settings.resize_width, settings.resize_height),
        resample=interp_mode[interp_indx])
    img = np.array(img)

    if mode == 'train':
        mirror = int(np.random.uniform(0, 2))
        if mirror == 1:
            # Flip along the width axis and mirror the normalized x
            # coordinates: the new xmin is 1 - old xmax and vice versa.
            img = img[:, ::-1, :]
            for box in sampled_labels:
                box[1], box[3] = 1 - box[3], 1 - box[1]

    img = to_chw_bgr(img)
    img = img.astype('float32')
    img -= settings.img_mean
    img = img * settings.scale
    return img, sampled_labels


def load_file_list(input_txt):
    """Parse an annotation list file into a dict of per-image entries.

    A line containing ``'--'`` starts a new image entry and is stored as
    the entry's first element. Each following line is appended to that
    entry: a line with at least four whitespace-separated fields is
    reduced to its first four values, each formatted through ``float``
    (e.g. ``'449 330 122 149 0 0'`` becomes ``'449.0 330.0 122.0 149.0'``),
    and any other line is stored unchanged. Blank lines are skipped.

    Args:
        input_txt (str): Path of the annotation list file.

    Returns:
        dict: Maps a running index (0, 1, 2, ...) to a list of strings,
        ``[header_line, other_line..., 'v0 v1 v2 v3', ...]``.

    Raises:
        ValueError: If a data line appears before the first header line.
    """
    file_dict = {}
    current_entry = None
    with open(input_txt, 'r') as f_dir:
        for raw_line in f_dir:
            line_txt = raw_line.strip('\n\t\r')
            if not line_txt:
                # A blank line carries no data; storing it would later be
                # parsed as a box by train_generator and fail.
                continue

            if '--' in line_txt:
                current_entry = [line_txt]
                file_dict[len(file_dict)] = current_entry
                continue

            if current_entry is None:
                raise ValueError(
                    'annotation line found before any image header '
                    '(a line containing "--") in %s' % input_txt)

            fields = line_txt.split()
            if len(fields) >= 4:
                # Keep only the first four values; train_generator reads
                # them as xmin, ymin, width and height.
                line_txt = ' '.join(str(float(value)) for value in fields[:4])
            current_entry.append(line_txt)

    return file_dict


def expand_bboxes(bboxes,
                  expand_left=2.,
                  expand_up=2.,
                  expand_right=2.,
                  expand_down=2.):
    """
    Grow every box outward and clamp the result to the [0, 1] range.

    Each side is pushed out by the box width (left/right) or height
    (up/down) divided by the matching ``expand_*`` factor, so the default
    factor of 2 adds half the box size on each side.

    Args:
        bboxes: iterable of boxes indexed as [xmin, ymin, xmax, ymax].
        expand_left, expand_up, expand_right, expand_down: divisors applied
            to the box width/height for the respective side.

    Returns:
        list: one ``[xmin, ymin, xmax, ymax]`` list per input box.
    """

    def _grow(box):
        left, top, right, bottom = box[0], box[1], box[2], box[3]
        width = right - left
        height = bottom - top
        return [
            max(left - width / expand_left, 0.),
            max(top - height / expand_up, 0.),
            min(right + width / expand_right, 1.),
            min(bottom + height / expand_down, 1.),
        ]

    return [_grow(box) for box in bboxes]


def _parse_bbox_labels(box_lines, im_width, im_height):
    """Turn 'xmin ymin w h' text lines into normalized label rows."""
    # layout: label | xmin | ymin | xmax | ymax
    bbox_labels = []
    for box_line in box_lines:
        fields = box_line.split(' ')
        xmin = float(fields[0])
        ymin = float(fields[1])
        w = float(fields[2])
        h = float(fields[3])

        # Filter out wrong labels
        if w < 0 or h < 0:
            continue
        xmax = xmin + w
        ymax = ymin + h

        bbox_labels.append([
            1,
            xmin / im_width,
            ymin / im_height,
            xmax / im_width,
            ymax / im_height,
        ])
    return bbox_labels


def train_generator(settings, file_list, batch_size, shuffle=True):
    """Endlessly yield training batches built from an annotation list.

    Args:
        settings: ``Settings`` instance; ``data_dir`` is joined with each
            image name and the whole object is handed to ``preprocess``.
        file_list (str): Path of the annotation file read by
            ``load_file_list``.
        batch_size (int): Number of samples per yielded batch.
        shuffle (bool): Reshuffle the sample order at the start of every
            pass over the data.

    Yields:
        list: ``batch_size`` tuples ``(im, face_box, head_box, label)``.
        ``im`` is the float32 array returned by ``preprocess``,
        ``face_box`` is columns 1-4 of the sampled labels, ``head_box``
        is ``expand_bboxes(face_box)`` and ``label`` is a list of ones.
        Samples left without labels after preprocessing are skipped. An
        incomplete batch at the end of a pass is discarded.
    """
    file_dict = load_file_list(file_list)
    # Shuffle a real list of entries: np.random.shuffle on the dict itself
    # only worked because its keys happen to be 0..n-1.
    samples = [file_dict[key] for key in sorted(file_dict)]
    while True:
        if shuffle:
            np.random.shuffle(samples)
        batch_out = []
        for sample in samples:
            image_name = sample[0]
            image_path = os.path.join(settings.data_dir, image_name)
            im = Image.open(image_path)
            if im.mode == 'L':
                im = im.convert('RGB')
            im_width, im_height = im.size

            # Entry 0 is the image name and entry 1 is skipped; the
            # remaining entries are box lines.
            bbox_labels = _parse_bbox_labels(sample[2:], im_width, im_height)
            im, sample_labels = preprocess(im, bbox_labels, "train", settings,
                                           image_path)
            sample_labels = np.array(sample_labels)
            if len(sample_labels) == 0:
                continue

            im = im.astype('float32')
            face_box = sample_labels[:, 1:5]
            head_box = expand_bboxes(face_box)
            label = [1] * len(face_box)
            batch_out.append((im, face_box, head_box, label))
            if len(batch_out) == batch_size:
                yield batch_out
                batch_out = []


def train(settings,
          file_list,
          batch_size,
          shuffle=True,
          use_multiprocessing=True,
          num_workers=8,
          max_queue=24):
    """Build a reader that serves training batches from background workers.

    Args:
        settings: ``Settings`` instance forwarded to ``train_generator``.
        file_list (str): Annotation file forwarded to ``train_generator``.
        batch_size (int): Batch size forwarded to ``train_generator``.
        shuffle (bool): Forwarded to ``train_generator``.
        use_multiprocessing (bool): Forwarded to ``GeneratorEnqueuer``.
        num_workers (int): Passed to ``enqueuer.start`` as ``workers``.
        max_queue (int): Passed to ``enqueuer.start`` as
            ``max_queue_size``.

    Returns:
        callable: A zero-argument generator function. Each item it yields
        is one object taken from the enqueuer's queue. The enqueuer is
        stopped when the generator is closed or fails.
    """

    def reader():
        # Bind the name before the try block so the cleanup in ``finally``
        # cannot raise UnboundLocalError (and hide the real error) when
        # constructing the enqueuer fails.
        enqueuer = None
        try:
            enqueuer = GeneratorEnqueuer(
                train_generator(settings, file_list, batch_size, shuffle),
                use_multiprocessing=use_multiprocessing)
            enqueuer.start(max_queue_size=max_queue, workers=num_workers)
            generator_output = None
            while True:
                # Poll the queue until an item is available.
                while enqueuer.is_running():
                    if not enqueuer.queue.empty():
                        generator_output = enqueuer.queue.get()
                        break
                    else:
                        time.sleep(0.01)
                # NOTE(review): if the enqueuer is no longer running, the
                # loop above is skipped and None is yielded - confirm that
                # consumers expect this.
                yield generator_output
                generator_output = None
        finally:
            if enqueuer is not None:
                enqueuer.stop()

    return reader


def test(settings, file_list):
    """Build a reader that yields the images named in an annotation list.

    Args:
        settings: ``Settings`` instance; only ``data_dir`` is read.
        file_list (str): Annotation file read by ``load_file_list``.

    Returns:
        callable: A zero-argument generator function yielding
        ``(im, image_path)`` for every entry. ``im`` is the opened PIL
        image, with mode ``'L'`` images converted to ``'RGB'``.
    """
    file_dict = load_file_list(file_list)

    def reader():
        for entry in file_dict.values():
            # The first element of every entry is the image name.
            image_name = entry[0]
            image_path = os.path.join(settings.data_dir, image_name)
            im = Image.open(image_path)
            if im.mode == 'L':
                im = im.convert('RGB')
            yield im, image_path

    return reader


def infer(settings, image_path):
    """Build a reader that loads and normalizes a single image.

    Args:
        settings: ``Settings`` instance providing ``resize_width``,
            ``resize_height``, ``img_mean`` and ``scale``.
        image_path (str): Path of the image to load.

    Returns:
        callable: A zero-argument function returning a float32 array with
        a leading batch axis of size one. The image is resized only when
        both resize dimensions are set, then converted to CHW/BGR,
        mean-subtracted and multiplied by ``settings.scale``.
    """

    def batch_reader():
        img = Image.open(image_path)
        if img.mode == 'L':
            # Convert the opened image itself; the previous code referred
            # to an undefined name here and failed on grayscale input.
            img = img.convert('RGB')
        if settings.resize_width and settings.resize_height:
            # LANCZOS is the filter formerly exposed as ANTIALIAS, a name
            # that newer Pillow releases no longer provide.
            img = img.resize((settings.resize_width, settings.resize_height),
                             Image.LANCZOS)
        img = np.array(img)
        img = to_chw_bgr(img)
        img = img.astype('float32')
        img -= settings.img_mean
        img = img * settings.scale
        return np.array([img])

    return batch_reader