reader.py 11.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

B
baiyfbupt 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19 20 21 22 23 24 25
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
Q
qingqing01 已提交
26
import random
27
import cv2
B
baiyf 已提交
28
import six
29 30 31 32
import math
from itertools import islice
import paddle
import image_util
33 34 35 36 37 38 39


class Settings(object):
    """
    Container for dataset paths and data-augmentation hyper-parameters.

    Args:
        dataset: stored as ``self.dataset``; not interpreted in this class.
        data_dir (str): root directory that image names are joined onto.
        label_file: accepted for API compatibility; not stored.
        resize_h (int): target height used when resizing images.
        resize_w (int): target width used when resizing images.
        mean_value (sequence of float): per-channel mean subtracted from the
            image after it has been converted to CHW/BGR layout, so entry 0
            applies to the B channel. Defaults to (104., 117., 123.).
        apply_distort (bool): training applies photometric distortion.
        apply_expand (bool): training applies random canvas expansion.
        ap_version (str): stored as ``self.ap_version``.
        toy (int): stored as ``self.toy``; meaning is defined by callers.
    """

    def __init__(self,
                 dataset=None,
                 data_dir=None,
                 label_file=None,
                 resize_h=None,
                 resize_w=None,
                 mean_value=(104., 117., 123.),
                 apply_distort=True,
                 apply_expand=True,
                 ap_version='11point',
                 toy=0):
        self.dataset = dataset
        self.ap_version = ap_version
        self.toy = toy
        self.data_dir = data_dir
        self.apply_distort = apply_distort
        self.apply_expand = apply_expand
        self.resize_height = resize_h
        self.resize_width = resize_w
        # Shape (len(mean_value), 1, 1) so it broadcasts over a (C, H, W)
        # image array.
        self.img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
            'float32')
        # The following probabilities/deltas are presumably consumed by the
        # image_util augmentation helpers (not visible here) -- TODO confirm.
        self.expand_prob = 0.5
        self.expand_max_ratio = 4
        self.hue_prob = 0.5
        self.hue_delta = 18
        self.contrast_prob = 0.5
        self.contrast_delta = 0.5
        self.saturation_prob = 0.5
        self.saturation_delta = 0.5
        self.brightness_prob = 0.5
        # _brightness_delta is the normalized value by 256
        self.brightness_delta = 0.125
        self.scale = 0.007843  # 1 / 127.5
        # preprocess() takes the data-anchor sampling path when a uniform
        # draw exceeds this value.
        self.data_anchor_sampling_prob = 0.5
        # Passed to the crop helpers in preprocess().
        self.min_face_size = 8.0
71 72


Q
qingqing01 已提交
73 74 75 76 77 78 79 80 81 82 83 84 85
def to_chw_bgr(image):
    """
    Convert an HWC, RGB-ordered image array into CHW, BGR-ordered layout.

    Args:
        image (np.array): image array. A 3-D input is treated as HWC and
            moved to CHW; inputs of any other rank are not transposed.

    Returns:
        np.array: the array with entries 2, 1, 0 of its leading axis, in
        that order (RGB -> BGR for a CHW image).
    """
    if image.ndim == 3:
        # (H, W, C) -> (C, H, W)
        image = np.transpose(image, (2, 0, 1))
    # Reverse the channel order: RGB -> BGR.
    return image[[2, 1, 0], :, :]
86 87 88


def preprocess(img, bbox_labels, mode, settings, image_path):
    """
    Apply (optional) training augmentation, resize and normalize one image.

    Args:
        img (PIL.Image.Image): input image.
        bbox_labels (list): per-box entries ``[label, xmin, ymin, xmax, ymax]``
            with coordinates normalized by the image size (the layout built
            by ``train_generator``).
        mode (str): ``'train'`` enables distortion, expansion, random crop
            sampling and random horizontal mirroring; any other value only
            resizes and normalizes.
        settings (Settings): augmentation and normalization parameters.
        image_path (str): unused here; kept for interface compatibility.

    Returns:
        tuple: ``(img, sampled_labels)`` where ``img`` is a float32 CHW array
        in BGR order, mean-subtracted and multiplied by ``settings.scale``,
        and ``sampled_labels`` are the box labels for the returned image.
    """
    img_width, img_height = img.size
    sampled_labels = bbox_labels
    if mode == 'train':
        if settings.apply_distort:
            img = image_util.distort_image(img, settings)
        if settings.apply_expand:
            img, bbox_labels, img_width, img_height = image_util.expand_image(
                img, bbox_labels, img_width, img_height, settings)
            # expand_image hands back (possibly re-projected) labels for the
            # enlarged canvas. Keep sampled_labels pointing at them so that,
            # when no crop is sampled below, the returned labels still match
            # the returned image rather than the pre-expansion boxes.
            sampled_labels = bbox_labels

        # sampling
        batch_sampler = []
        # used for continuous evaluation: make the sampling reproducible
        if 'ce_mode' in os.environ:
            random.seed(0)
            np.random.seed(0)
        prob = np.random.uniform(0., 1.)
        if prob > settings.data_anchor_sampling_prob:
            # data-anchor sampling path
            scale_array = np.array([16, 32, 64, 128, 256, 512])
            batch_sampler.append(
                image_util.sampler(1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2,
                                   0.0, True))
            sampled_bbox = image_util.generate_batch_random_samples(
                batch_sampler, bbox_labels, img_width, img_height, scale_array,
                settings.resize_width, settings.resize_height)
            img = np.array(img)
            if len(sampled_bbox) > 0:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                img, sampled_labels = image_util.crop_image_sampling(
                    img, bbox_labels, sampled_bbox[idx], img_width, img_height,
                    settings.resize_width, settings.resize_height,
                    settings.min_face_size)
            img = Image.fromarray(img.astype('uint8'))
        else:
            # hard-code here: one sampler followed by four identical ones
            # that differ from it only in the third argument (0.3 vs 1.0)
            batch_sampler.append(
                image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            for _ in range(4):
                batch_sampler.append(
                    image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0,
                                       1.0, 0.0, True))
            sampled_bbox = image_util.generate_batch_samples(
                batch_sampler, bbox_labels, img_width, img_height)

            img = np.array(img)
            if len(sampled_bbox) > 0:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                img, sampled_labels = image_util.crop_image(
                    img, bbox_labels, sampled_bbox[idx], img_width, img_height,
                    settings.resize_width, settings.resize_height,
                    settings.min_face_size)
            img = Image.fromarray(img)

    # A resampling filter is drawn at random for the resize in every mode.
    interp_mode = [
        Image.BILINEAR, Image.HAMMING, Image.NEAREST, Image.BICUBIC,
        Image.LANCZOS
    ]
    interp_indx = np.random.randint(0, len(interp_mode))
    img = img.resize(
        (settings.resize_width, settings.resize_height),
        resample=interp_mode[interp_indx])
    img = np.array(img)

    if mode == 'train':
        # Random horizontal flip. Boxes are mirrored as x -> 1 - x, which
        # assumes normalized coordinates.
        mirror = int(np.random.uniform(0, 2))
        if mirror == 1:
            img = img[:, ::-1, :]
            for label in sampled_labels:
                label[1], label[3] = 1 - label[3], 1 - label[1]

    img = to_chw_bgr(img)
    img = img.astype('float32')
    img -= settings.img_mean
    img = img * settings.scale
    return img, sampled_labels


180
def load_file_list(input_txt):
    """
    Parse an annotation list file into one group of lines per image.

    A line containing ``'--'`` starts a new group (it is the image name).
    Inside a group, lines longer than 6 characters are treated as box lines:
    only their first four whitespace-separated fields are kept, each
    re-formatted through ``float`` and joined with single spaces. Shorter
    lines (e.g. a box count) are kept verbatim. Blank lines are ignored.

    Args:
        input_txt (str): path of the annotation file.

    Returns:
        list[list[str]]: one list per image, ``[image_name, ...]`` followed
        by the remaining lines of that group, in file order.

    Raises:
        ValueError: if a non-blank line appears before the first image line,
            or a box line has fewer than four fields.
    """
    with open(input_txt, 'r') as f_dir:
        lines_input_txt = f_dir.readlines()

    groups = []
    for line_no, raw_line in enumerate(lines_input_txt, 1):
        line_txt = raw_line.strip('\n\t\r')
        if not line_txt.strip():
            # Blank lines carry no annotation; keeping them would make the
            # training reader try to parse an empty box line.
            continue
        if '--' in line_txt:
            groups.append([line_txt])
            continue
        if not groups:
            raise ValueError(
                '%s:%d: annotation line found before any image name line' %
                (input_txt, line_no))
        if len(line_txt) > 6:
            x1_min, y1_min, x2_max, y2_max = [
                float(v) for v in line_txt.split()[:4]
            ]
            line_txt = ' '.join(
                str(v) for v in (x1_min, y1_min, x2_max, y2_max))
        groups[-1].append(line_txt)

    return groups
207 208


209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232
def expand_bboxes(bboxes,
                  expand_left=2.,
                  expand_up=2.,
                  expand_right=2.,
                  expand_down=2.):
    """
    Grow each box outward and clip it to the [0, 1] range.

    Each side moves out by the box's width (left/right) or height (up/down)
    divided by the matching ``expand_*`` factor. With the defaults, half the
    width/height is added on every side, doubling the box before clipping.

    Args:
        bboxes: iterable of boxes whose first four entries are
            ``xmin, ymin, xmax, ymax``. Results are clipped to [0, 1], so
            normalized coordinates are expected.
        expand_left, expand_up, expand_right, expand_down (float): divisors
            applied to the width/height for the corresponding side.

    Returns:
        list: one ``[xmin, ymin, xmax, ymax]`` list per input box.
    """
    def _grow(box):
        # Expand a single box and clip it to [0, 1].
        x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
        width, height = x1 - x0, y1 - y0
        return [
            max(x0 - width / expand_left, 0.),
            max(y0 - height / expand_up, 0.),
            min(x1 + width / expand_right, 1.),
            min(y1 + height / expand_down, 1.),
        ]

    return [_grow(box) for box in bboxes]


233
def train_generator(settings, file_list, batch_size, shuffle=True):
    """
    Build a reader that yields augmented training batches.

    Args:
        settings (Settings): dataset paths and augmentation parameters.
        file_list (list): groups as returned by ``load_file_list``. Element 0
            is the image name relative to ``settings.data_dir``; elements
            from index 2 on are ``"x y w h"`` box lines in pixels.
        batch_size (int): number of samples per yielded batch.
        shuffle (bool): shuffle ``file_list`` in place at the start of each
            pass (skipped when the ``ce_mode`` environment variable is set).

    Returns:
        callable: a no-argument generator function. Each yielded batch is a
        list of ``batch_size`` tuples ``(image, face_box, head_box, label)``.
        A trailing partial batch is not yielded.
    """
    def reader():
        if shuffle and 'ce_mode' not in os.environ:
            np.random.shuffle(file_list)
        batch_out = []
        for item in file_list:
            image_name = item[0]
            image_path = os.path.join(settings.data_dir, image_name)
            im = Image.open(image_path)
            if im.mode != 'RGB':
                # Grayscale, palette, LA, CMYK, ... images would otherwise
                # not yield the 3-channel RGB array the pipeline indexes.
                im = im.convert('RGB')
            im_width, im_height = im.size

            # layout: label | xmin | ymin | xmax | ymax (normalized)
            bbox_labels = []
            for box_line in item[2:]:
                temp_info_box = box_line.split(' ')
                xmin = float(temp_info_box[0])
                ymin = float(temp_info_box[1])
                w = float(temp_info_box[2])
                h = float(temp_info_box[3])

                # Filter out wrong labels
                if w < 0 or h < 0:
                    continue
                xmax = xmin + w
                ymax = ymin + h

                bbox_labels.append([
                    1,
                    xmin / im_width,
                    ymin / im_height,
                    xmax / im_width,
                    ymax / im_height,
                ])
            im, sample_labels = preprocess(im, bbox_labels, "train", settings,
                                           image_path)
            sample_labels = np.array(sample_labels)
            if len(sample_labels) == 0:
                continue

            im = im.astype('float32')
            face_box = sample_labels[:, 1:5]
            head_box = expand_bboxes(face_box)
            label = [1] * len(face_box)
            batch_out.append((im, face_box, head_box, label))
            if len(batch_out) == batch_size:
                yield batch_out
                batch_out = []

    return reader
284 285


286 287 288 289 290 291
def train(settings,
          file_list,
          batch_size,
          shuffle=True,
          use_multiprocess=True,
          num_workers=8):
    """
    Create the training reader, optionally split across worker processes.

    Args:
        settings (Settings): dataset paths and augmentation parameters.
        file_list (str): path of the annotation file read by load_file_list.
        batch_size (int): samples per batch.
        shuffle (bool): forwarded to train_generator.
        use_multiprocess (bool): if True, split the parsed groups into at
            most ``num_workers`` contiguous chunks and combine one
            train_generator per chunk with
            ``paddle.reader.multiprocess_reader``.
        num_workers (int): maximum number of chunks/readers.

    Returns:
        callable: the reader.
    """
    file_lists = load_file_list(file_list)
    if use_multiprocess:
        # Ceiling of true division so at most num_workers chunks are made
        # (floor division produced extra chunks, and a chunk size of 0 when
        # there were fewer groups than workers, which made range() fail).
        chunk_size = int(math.ceil(len(file_lists) / float(num_workers)))
        chunk_size = max(1, chunk_size)
        split_lists = [
            file_lists[i:i + chunk_size]
            for i in range(0, len(file_lists), chunk_size)
        ]
        readers = [
            train_generator(settings, chunk, batch_size, shuffle)
            for chunk in split_lists
        ]
        return paddle.reader.multiprocess_reader(readers, False)
    return train_generator(settings, file_lists, batch_size, shuffle)
305 306


307
def test(settings, file_list):
    """
    Create a reader over the evaluation images listed in ``file_list``.

    Args:
        settings (Settings): provides ``data_dir``.
        file_list (str): path of the annotation file read by load_file_list.

    Returns:
        callable: a no-argument generator function yielding
        ``(image, image_path)``, where ``image`` is an RGB ``PIL.Image``.
    """
    file_lists = load_file_list(file_list)

    def reader():
        for group in file_lists:
            image_path = os.path.join(settings.data_dir, group[0])
            im = Image.open(image_path)
            if im.mode != 'RGB':
                # Normalize every non-RGB mode (not only 'L') so consumers
                # always receive a 3-channel image.
                im = im.convert('RGB')
            yield im, image_path

    return reader
B
baiyfbupt 已提交
320 321


Q
qingqing01 已提交
322 323 324 325
def infer(settings, image_path):
    """
    Build a loader that returns one preprocessed image as a batch of one.

    Args:
        settings (Settings): resize target, ``img_mean`` and ``scale``.
        image_path (str): path of the image to load.

    Returns:
        callable: a no-argument function returning a float32 array of shape
        ``(1, 3, H, W)`` in BGR channel order, mean-subtracted and scaled.
    """
    def batch_reader():
        img = Image.open(image_path)
        if img.mode != 'RGB':
            img = img.convert('RGB')
        if settings.resize_width and settings.resize_height:
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
            # filter under its current name.
            img = img.resize((settings.resize_width, settings.resize_height),
                             Image.LANCZOS)
        img = np.array(img)
        img = to_chw_bgr(img)
        img = img.astype('float32')
        img -= settings.img_mean
        img = img * settings.scale
        return np.array([img])

    return batch_reader