reader.py 11.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

B
baiyfbupt 已提交
15 16 17 18
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19 20 21 22 23 24 25
from PIL import Image
from PIL import ImageDraw
import numpy as np
import xml.etree.ElementTree
import os
import time
import copy
Q
qingqing01 已提交
26
import random
27
import cv2
B
baiyf 已提交
28
import six
29 30 31 32
import math
from itertools import islice
import paddle
import image_util
33 34 35 36 37 38 39


class Settings(object):
    """Configuration holder for the face-detection data reader.

    Args:
        dataset: Dataset identifier; stored unchanged as ``self.dataset``.
        data_dir: Directory that image names are joined onto when images
            are opened by the readers in this module.
        label_file: Accepted for interface compatibility; it is not stored
            on the instance.
        resize_h: Target image height, stored as ``self.resize_height``.
        resize_w: Target image width, stored as ``self.resize_width``.
        mean_value: Sequence of three per-channel means. It is converted to
            a float32 array of shape (3, 1, 1) so it broadcasts over a
            CHW image.
        apply_distort: Whether training preprocessing applies distortion.
        apply_expand: Whether training preprocessing applies expansion.
        ap_version: Stored unchanged as ``self.ap_version``.
        toy: Stored unchanged as ``self.toy``.
    """

    def __init__(self,
                 dataset=None,
                 data_dir=None,
                 label_file=None,
                 resize_h=None,
                 resize_w=None,
                 # A tuple instead of a list: default arguments are shared
                 # between calls, so they should be immutable.
                 mean_value=(104., 117., 123.),
                 apply_distort=True,
                 apply_expand=True,
                 ap_version='11point',
                 toy=0):
        self.dataset = dataset
        self.ap_version = ap_version
        self.toy = toy
        self.data_dir = data_dir
        self.apply_distort = apply_distort
        self.apply_expand = apply_expand
        self.resize_height = resize_h
        self.resize_width = resize_w
        # Shape (3, 1, 1) so the mean broadcasts across height and width.
        self.img_mean = np.array(mean_value)[:, np.newaxis, np.newaxis].astype(
            'float32')

        # Augmentation hyper-parameters. This object is handed to the
        # image_util helpers, which define how these values are used.
        self.expand_prob = 0.5
        self.expand_max_ratio = 4
        self.hue_prob = 0.5
        self.hue_delta = 18
        self.contrast_prob = 0.5
        self.contrast_delta = 0.5
        self.saturation_prob = 0.5
        self.saturation_delta = 0.5
        self.brightness_prob = 0.5
        # brightness_delta is the value normalized by 256
        self.brightness_delta = 0.125
        self.scale = 0.007843  # 1 / 127.5
        self.data_anchor_sampling_prob = 0.5
        self.min_face_size = 8.0
71 72


Q
qingqing01 已提交
73 74 75 76 77 78 79 80 81 82 83 84 85
def to_chw_bgr(image):
    """
    Reorder an image from HWC layout to CHW layout and from RGB to BGR.

    Args:
        image (np.array): an image; a 3-D input is treated as HWC with
            RGB channel order.

    Returns:
        np.array: channels 2, 1, 0 (in that order) of the CHW-arranged
        image.
    """
    if image.ndim == 3:
        # One permutation, equal to swapping axes (1, 2) and then (1, 0).
        image = np.transpose(image, (2, 0, 1))
    # Pick channels in reverse order: RGB -> BGR.
    bgr_order = [2, 1, 0]
    return image[bgr_order, :, :]
86 87 88


def preprocess(img, bbox_labels, mode, settings, image_path):
    """Augment (in training), resize and normalize one image with its labels.

    Args:
        img: PIL image (its ``size`` and ``resize`` are used).
        bbox_labels: Box labels for ``img``. Each entry is indexed so that
            positions 1 and 3 are the horizontal extents, which are flipped
            as ``1 - x`` when the image is mirrored (i.e. normalized
            coordinates).
        mode: ``'train'`` enables distortion, expansion, crop sampling and
            random mirroring; any other value skips them.
        settings: Settings-like object providing ``apply_distort``,
            ``apply_expand``, ``data_anchor_sampling_prob``,
            ``resize_width``, ``resize_height``, ``min_face_size``,
            ``img_mean`` and ``scale``.
        image_path: Unused; kept for interface compatibility.

    Returns:
        tuple: ``(img, sampled_labels)`` where ``img`` is a float32 array
        produced by ``to_chw_bgr`` with ``settings.img_mean`` subtracted and
        multiplied by ``settings.scale``, and ``sampled_labels`` are the
        labels matching that image. When mirroring happens the label
        entries are modified in place.
    """
    img_width, img_height = img.size
    sampled_labels = bbox_labels
    if mode == 'train':
        if settings.apply_distort:
            img = image_util.distort_image(img, settings)
        if settings.apply_expand:
            img, bbox_labels, img_width, img_height = image_util.expand_image(
                img, bbox_labels, img_width, img_height, settings)
            # Re-bind the fallback labels: if no crop is sampled below, the
            # labels returned must be the ones expand_image returned for
            # the (possibly expanded) image, not the pre-expansion ones.
            sampled_labels = bbox_labels

        # sampling
        batch_sampler = []

        prob = np.random.uniform(0., 1.)
        if prob > settings.data_anchor_sampling_prob:
            scale_array = np.array([16, 32, 64, 128, 256, 512])
            batch_sampler.append(
                image_util.sampler(1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2,
                                   0.0, True))
            sampled_bbox = image_util.generate_batch_random_samples(
                batch_sampler, bbox_labels, img_width, img_height, scale_array,
                settings.resize_width, settings.resize_height)
            img = np.array(img)
            if len(sampled_bbox) > 0:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                img, sampled_labels = image_util.crop_image_sampling(
                    img, bbox_labels, sampled_bbox[idx], img_width, img_height,
                    settings.resize_width, settings.resize_height,
                    settings.min_face_size)

            # Image.fromarray below needs an 8-bit array.
            img = img.astype('uint8')
            img = Image.fromarray(img)

        else:
            # hard-code here
            batch_sampler.append(
                image_util.sampler(1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0,
                                   0.0, True))
            # Four further samplers with identical parameters.
            for _ in range(4):
                batch_sampler.append(
                    image_util.sampler(1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0,
                                       1.0, 0.0, True))
            sampled_bbox = image_util.generate_batch_samples(
                batch_sampler, bbox_labels, img_width, img_height)

            img = np.array(img)
            if len(sampled_bbox) > 0:
                idx = int(np.random.uniform(0, len(sampled_bbox)))
                img, sampled_labels = image_util.crop_image(
                    img, bbox_labels, sampled_bbox[idx], img_width, img_height,
                    settings.resize_width, settings.resize_height,
                    settings.min_face_size)

            img = Image.fromarray(img)

    # A resampling filter is drawn at random for every call, in any mode.
    interp_mode = [
        Image.BILINEAR, Image.HAMMING, Image.NEAREST, Image.BICUBIC,
        Image.LANCZOS
    ]
    interp_indx = np.random.randint(0, len(interp_mode))

    img = img.resize(
        (settings.resize_width, settings.resize_height),
        resample=interp_mode[interp_indx])
    img = np.array(img)

    if mode == 'train':
        mirror = int(np.random.uniform(0, 2))
        if mirror == 1:
            # Horizontal flip: reverse the width axis and swap/flip the
            # horizontal extents of every label.
            img = img[:, ::-1, :]
            for label in sampled_labels:
                label[1], label[3] = 1 - label[3], 1 - label[1]

    img = to_chw_bgr(img)
    img = img.astype('float32')
    img -= settings.img_mean
    img = img * settings.scale
    return img, sampled_labels


177
def load_file_list(input_txt):
    """Parse an annotation list file into one record per image.

    A line containing ``'--'`` starts a new record and is stored as the
    record's first element. Every following line belongs to that record:

    * lines longer than 6 characters are treated as box lines; their first
      four space-separated fields are converted to float and re-joined
      with single spaces (any further fields are dropped);
    * shorter lines are stored verbatim.

    Blank lines are skipped so that, for example, a trailing empty line
    does not end up as a bogus entry in the last record.

    Args:
        input_txt (str): Path of the annotation list file.

    Returns:
        list: One list of strings per image, in file order.

    Raises:
        ValueError: If a data line appears before any ``'--'`` header line,
            or a box field is not a valid float.
        IndexError: If a box line has fewer than four fields.
    """
    records = []
    with open(input_txt, 'r') as f_dir:
        for raw_line in f_dir:
            line_txt = raw_line.strip('\n\t\r')
            if not line_txt.strip():
                continue

            if '--' in line_txt:
                records.append([line_txt])
                continue

            if not records:
                raise ValueError(
                    'annotation line %r appears before any image header '
                    'line in %s' % (line_txt, input_txt))

            if len(line_txt) > 6:
                fields = line_txt.split(' ')
                line_txt = ' '.join(str(float(fields[k])) for k in range(4))
            records[-1].append(line_txt)

    return records
204 205


206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
def expand_bboxes(bboxes,
                  expand_left=2.,
                  expand_up=2.,
                  expand_right=2.,
                  expand_down=2.):
    """
    Grow every box outward and clip the result to the [0, 1] range.

    Each side moves out by the box width (left/right) or height (up/down)
    divided by the matching ``expand_*`` factor, so the default factors of
    2 add half a width/height on every side.

    Args:
        bboxes: iterable of boxes indexed as [xmin, ymin, xmax, ymax].
        expand_left, expand_up, expand_right, expand_down: divisors applied
            to the width/height for the respective side.

    Returns:
        list: one [xmin, ymin, xmax, ymax] list per input box.
    """

    def _grow(box):
        left, top, right, bottom = box[0], box[1], box[2], box[3]
        width = right - left
        height = bottom - top
        return [
            max(left - width / expand_left, 0.),
            max(top - height / expand_up, 0.),
            min(right + width / expand_right, 1.),
            min(bottom + height / expand_down, 1.),
        ]

    return [_grow(box) for box in bboxes]


230
def train_generator(settings, file_list, batch_size, shuffle=True):
    """Build a reader that yields training batches from ``file_list``.

    Args:
        settings: Settings-like object; ``data_dir`` is joined with each
            image name and the object is passed on to ``preprocess``.
        file_list: List of records; element 0 of a record is the image
            name and elements from index 2 onward are box lines of the
            form ``"xmin ymin w h"``. Shuffled in place when ``shuffle``
            is true.
        batch_size: Number of samples per yielded batch.
        shuffle: Whether to shuffle ``file_list`` each time the reader
            starts.

    Returns:
        callable: A generator function. Each yielded batch is a list of
        ``(image, face_box, head_box, label)`` tuples. Samples left over
        at the end that do not fill a batch are not yielded.
    """

    def reader():
        if shuffle:
            np.random.shuffle(file_list)

        pending = []
        for record in file_list:
            image_path = os.path.join(settings.data_dir, record[0])
            im = Image.open(image_path)
            if im.mode == 'L':
                im = im.convert('RGB')
            im_width, im_height = im.size

            # layout: label | xmin | ymin | xmax | ymax
            bbox_labels = []
            for box_line in record[2:]:
                fields = box_line.split(' ')
                xmin = float(fields[0])
                ymin = float(fields[1])
                w = float(fields[2])
                h = float(fields[3])

                # Filter out wrong labels
                if w < 0 or h < 0:
                    continue

                # Corner coordinates normalized by the image size.
                bbox_labels.append([
                    1,
                    xmin / im_width,
                    ymin / im_height,
                    (xmin + w) / im_width,
                    (ymin + h) / im_height,
                ])

            im, sample_labels = preprocess(im, bbox_labels, "train", settings,
                                           image_path)
            sample_labels = np.array(sample_labels)
            if len(sample_labels) == 0:
                continue

            im = im.astype('float32')
            face_box = sample_labels[:, 1:5]
            head_box = expand_bboxes(face_box)
            label = [1] * len(face_box)

            pending.append((im, face_box, head_box, label))
            if len(pending) == batch_size:
                yield pending
                pending = []

    return reader
281 282


283 284 285 286 287 288 289 290
def train(settings, file_list, batch_size, shuffle=True, num_workers=8):
    """Create a multi-process training reader.

    The records loaded from ``file_list`` are split into at most
    ``num_workers`` contiguous chunks of (nearly) equal size; one
    ``train_generator`` reader is created per chunk and the readers are
    combined with ``paddle.reader.multiprocess_reader``.

    Args:
        settings: Settings-like object forwarded to ``train_generator``.
        file_list (str): Path of the annotation list file.
        batch_size: Number of samples per batch in each worker reader.
        shuffle: Forwarded to ``train_generator``.
        num_workers: Maximum number of worker readers to create.

    Returns:
        The reader returned by ``paddle.reader.multiprocess_reader``.

    Raises:
        ValueError: If ``file_list`` contains no records.
    """
    file_lists = load_file_list(file_list)
    if not file_lists:
        raise ValueError('no records found in %s' % file_list)

    # True ceiling division. Flooring first could give a chunk size of 0
    # (fewer records than workers, so range() gets a zero step) or more
    # chunks than num_workers.
    n = int(math.ceil(len(file_lists) / float(num_workers)))
    split_lists = [file_lists[i:i + n] for i in range(0, len(file_lists), n)]
    readers = [
        train_generator(settings, chunk, batch_size, shuffle)
        for chunk in split_lists
    ]
    # NOTE(review): the second positional argument is presumably
    # ``use_pipe`` -- confirm against the installed Paddle version.
    return paddle.reader.multiprocess_reader(readers, False)
291 292


293
def test(settings, file_list):
    """Build a reader that yields evaluation images one at a time.

    The annotation file is parsed once, when this function is called.

    Args:
        settings: Settings-like object; ``data_dir`` is joined with each
            image name.
        file_list (str): Path of the annotation list file.

    Returns:
        callable: A generator function yielding ``(image, image_path)``
        pairs, where ``image`` is the opened PIL image (mode 'L' images
        are converted to 'RGB').
    """
    records = load_file_list(file_list)

    def reader():
        for record in records:
            path = os.path.join(settings.data_dir, record[0])
            picture = Image.open(path)
            yield (picture.convert('RGB')
                   if picture.mode == 'L' else picture), path

    return reader
B
baiyfbupt 已提交
306 307


Q
qingqing01 已提交
308 309 310 311 312 313
def infer(settings, image_path):
    """Build a reader that loads and normalizes a single image.

    Args:
        settings: Settings-like object providing ``resize_width``,
            ``resize_height``, ``img_mean`` and ``scale``.
        image_path (str): Path of the image to load.

    Returns:
        callable: A function that, when called, returns an array holding
        one preprocessed image: a float32 array from ``to_chw_bgr`` with
        ``settings.img_mean`` subtracted and multiplied by
        ``settings.scale``, wrapped in a leading batch axis of length 1.
    """

    def batch_reader():
        img = Image.open(image_path)
        if img.mode == 'L':
            # Convert the opened image itself; the original code used an
            # undefined name here and raised NameError on grayscale input.
            img = img.convert('RGB')
        # Resize only when both target dimensions are configured.
        if settings.resize_width and settings.resize_height:
            # LANCZOS is the filter formerly also exposed as ANTIALIAS,
            # an alias removed in Pillow 10.
            img = img.resize((settings.resize_width, settings.resize_height),
                             Image.LANCZOS)
        img = np.array(img)
        img = to_chw_bgr(img)
        img = img.astype('float32')
        img -= settings.img_mean
        img = img * settings.scale
        return np.array([img])

    return batch_reader