# infer.py: PyramidBox face-detection inference on the WIDER FACE validation set.
import argparse
import decimal
import functools
import os
import time

import numpy as np
from PIL import Image
from PIL import ImageDraw

import paddle
import paddle.fluid as fluid

import reader
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# How ``bool`` flags are parsed from the command line is decided by
# utility.add_arguments (not visible here). NOTE(review): confirm it does not
# use plain ``bool``, which would treat the string "False" as True.
# yapf: disable
add_arg('use_gpu',          bool,  True,   "Whether use GPU.")
add_arg('use_pyramidbox',   bool,  False,  "Whether use PyramidBox model.")
add_arg('confs_threshold',  float, 0.25,   "Confidence threshold to draw bbox.")
# NOTE(review): image_path is not read anywhere in this script; the images
# come from the hard-coded WIDER FACE list set up in __main__.
add_arg('image_path',       str,   '',     "The data root path.")
add_arg('model_dir',        str,   '',     "The model path.")
# yapf: enable


def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
    """Draw confident detections on an image and save it under ./infer_results.

    The image is saved to ``./infer_results/<image_class>/<image_name>``.
    ``<image_class>`` is the parent directory name of ``image_path``. The
    directory is created if it is missing.

    Args:
        image_path: '/'-separated path of the source image.
        nms_out: iterable of ``(xmin, ymin, xmax, ymax, score)`` rows.
        confs_threshold: rows whose ``score`` is below this are not drawn.
    """
    image = Image.open(image_path)
    draw = ImageDraw.Draw(image)
    for xmin, ymin, xmax, ymax, score in nms_out:
        if score < confs_threshold:
            continue
        # Closed polyline around the box: top-left -> bottom-left ->
        # bottom-right -> top-right -> back to top-left.
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
             (xmin, ymin)],
            width=4,
            fill='red')
    image_name = image_path.split('/')[-1]
    image_class = image_path.split('/')[-2]
    out_dir = os.path.join('./infer_results', image_class)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    # Join the str components directly. ``str.encode`` yields ``bytes`` on
    # Python 3, and ``bytes + str`` raises TypeError.
    image.save(os.path.join(out_dir, image_name))
    # Report only after the save actually succeeded.
    print("image with bbox drawed saved as {}".format(image_name))


def write_to_txt(image_path, f, nms_out):
    """Write one image's detections to ``f`` in WIDER FACE result format.

    Layout written::

        <image_class>/<image_name>
        <number of rows>
        <xmin> <ymin> <width> <height> <score>     (one line per row)

    Args:
        image_path: '/'-separated path whose last two components are
            ``<image_class>/<image_name>``.
        f: writable text file-like object.
        nms_out: 2-D array of ``(xmin, ymin, xmax, ymax, score)`` rows.
    """
    image_name = image_path.split('/')[-1]
    image_class = image_path.split('/')[-2]
    # Keep everything as ``str``: the previous ``.encode('utf-8')`` produced
    # ``bytes`` on Python 3 and the concatenation with '/' raised TypeError.
    f.write('{:s}\n'.format(image_class + '/' + image_name))
    f.write('{:d}\n'.format(nms_out.shape[0]))
    for xmin, ymin, xmax, ymax, score in nms_out:
        # Width/height use the inclusive-pixel (+1) convention.
        f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(
            xmin, ymin, xmax - xmin + 1, ymax - ymin + 1, score))
    # splitext instead of [:-4] so extensions of any length are stripped.
    print("image infer result saved {}".format(
        os.path.splitext(image_name)[0]))


def get_round(x, loc):
    """Truncate ``x`` toward zero to at most ``loc`` decimal places.

    Truncation works on the decimal text of ``x`` (``str(x)``), so e.g.
    ``get_round(1.239, 2) == 1.23`` without binary floating-point rounding
    surprises. Values that already have at most ``loc`` decimals, integers,
    and non-finite values are returned unchanged.

    The previous string-splitting version had three problems:
    - It returned ``None`` for values without a '.' in their text.
    - It crashed on scientific notation such as '1.5e-05'.
    - It ignored ``loc`` in its "needs truncation" test.

    Args:
        x: number to truncate.
        loc: number of decimal places to keep.

    Returns:
        ``x`` itself, or the truncated value as a ``float``.
    """
    try:
        value = decimal.Decimal(str(x))
    except decimal.InvalidOperation:
        return x
    # A Decimal's exponent is minus its number of decimal places.
    if not value.is_finite() or value.as_tuple().exponent >= -loc:
        return x
    step = decimal.Decimal(1).scaleb(-loc)
    return float(value.quantize(step, rounding=decimal.ROUND_DOWN))


def bbox_vote(det, iou_threshold=0.3, max_dets=750):
    """Merge overlapping detections by score-weighted box voting.

    Boxes are processed in descending score order. Each round does this:
    - Take the highest-scoring remaining box.
    - Collect every remaining box whose IoU with it is >= ``iou_threshold``,
      and remove those boxes.
    - If the group has two or more boxes, emit one box. Its coordinates are
      the score-weighted mean of the group; its score is the group maximum.

    A box that overlaps nothing is discarded, unless it is the very last box
    left, in which case it is kept unchanged. This quirk is inherited from
    the original implementation.

    Args:
        det: (N, 5) array of ``xmin, ymin, xmax, ymax, score`` rows.
        iou_threshold: minimum IoU for a box to join the current group.
        max_dets: maximum number of rows returned.

    Returns:
        (M, 5) float array with M <= ``max_dets``. For empty input, a single
        placeholder row ``[10, 10, 20, 20, 0.002]`` is returned.
    """
    order = det[:, 4].ravel().argsort()[::-1]
    det = det[order, :]
    if det.shape[0] == 0:
        return np.array([[10, 10, 20, 20, 0.002]])[:max_dets, :]

    dets = None  # accumulated output rows; set on the first emitted group
    while det.shape[0] > 0:
        # IoU of the current top box det[0] against every remaining box,
        # using the inclusive-pixel (+1) width/height convention.
        area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
        xx1 = np.maximum(det[0, 0], det[:, 0])
        yy1 = np.maximum(det[0, 1], det[:, 1])
        xx2 = np.minimum(det[0, 2], det[:, 2])
        yy2 = np.minimum(det[0, 3], det[:, 3])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        iou = inter / (area[0] + area - inter)

        merge_mask = iou >= iou_threshold
        # Always consume the top box itself. For a degenerate box
        # (xmax < xmin or ymax < ymin) the self-IoU is 0 or NaN, so the
        # original loop never removed it and spun forever.
        merge_mask[0] = True
        merge_index = np.where(merge_mask)[0]
        det_accu = det[merge_index, :]
        det = np.delete(det, merge_index, 0)

        if merge_index.shape[0] <= 1:
            # Isolated box: kept only when it is the last one remaining.
            if det.shape[0] == 0:
                dets = det_accu if dets is None else np.vstack(
                    (dets, det_accu))
            continue

        weights = det_accu[:, 4:5]
        merged = np.zeros((1, 5))
        merged[:, 0:4] = np.sum(det_accu[:, 0:4] * weights,
                                axis=0) / np.sum(weights)
        merged[:, 4] = np.max(det_accu[:, 4])
        dets = merged if dets is None else np.vstack((dets, merged))
    return dets[:max_dets, :]


def image_preprocess(image):
    """Convert an image into a normalized (1, 3, H, W) float32 batch.

    Steps:
    - HWC -> CHW.
    - RGB -> BGR.
    - Subtract the per-channel BGR mean (104, 117, 123).
    - Scale by 0.007843 (about 1 / 127.5).

    Args:
        image: PIL image, or anything ``np.array`` turns into an (H, W, C)
            array with at least 3 channels or an (H, W) grayscale array.

    Returns:
        C-contiguous float32 array of shape (1, 3, H, W).
    """
    img = np.array(image)
    if img.ndim == 2:
        # Grayscale (H, W): replicate to three channels. Previously the
        # channel indexing below raised IndexError on 2-D input.
        img = np.stack([img] * 3, axis=-1)
    # HWC to CHW
    img = img.transpose((2, 0, 1))
    # RGB to BGR (a fourth, alpha channel is dropped if present).
    img = img[[2, 1, 0], :, :].astype('float32')
    img -= np.array([104., 117., 123.],
                    dtype='float32')[:, np.newaxis, np.newaxis]
    img = img * 0.007843
    return np.ascontiguousarray(img[np.newaxis, ...])


def detect_face(image, shrink):
    """Detect faces in ``image`` after resizing it by ``shrink``.

    Builds a fresh PyramidBox inference program for the resized shape,
    loads persistables from ``args.model_dir`` and runs the program once.
    Reads the module-level ``args`` assigned in ``__main__``.

    Args:
        image: PIL image.
        shrink: resize factor applied to both sides (1 means no resize).

    Returns:
        (N, 5) float array of ``xmin, ymin, xmax, ymax, score`` rows, in the
        pixel coordinates of the original (un-resized) ``image``. Returns an
        empty (0, 5) array when the network yields no detection rows.
    """
    image_shape = [3, image.size[1], image.size[0]]  # [C, H, W]
    num_classes = 2
    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)

    if shrink != 1:
        resized_h = int(image_shape[1] * shrink)
        resized_w = int(image_shape[2] * shrink)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
        # filter under its current name.
        resample = (Image.LANCZOS
                    if hasattr(Image, 'LANCZOS') else Image.ANTIALIAS)
        image = image.resize((resized_w, resized_h), resample)
        image_shape = [image_shape[0], resized_h, resized_w]
    # Single-argument print(): same output on Python 2 and 3.
    print("image_shape: {}".format(image_shape))
    img = image_preprocess(image)

    # A separate scope and programs per call, because the network is built
    # for this specific input shape.
    scope = fluid.core.Scope()
    main_program = fluid.Program()
    startup_program = fluid.Program()

    with fluid.scope_guard(scope):
        with fluid.unique_name.guard():
            with fluid.program_guard(main_program, startup_program):
                network = PyramidBox(
                    image_shape,
                    num_classes,
                    sub_network=args.use_pyramidbox,
                    is_infer=True)
                infer_program, nmsed_out = network.infer(main_program)
                fluid.io.load_persistables(
                    exe, args.model_dir, main_program=main_program)
                detection, = exe.run(infer_program,
                                     feed={'image': img},
                                     fetch_list=[nmsed_out],
                                     return_numpy=False)
                detection = np.array(detection)

    # NOTE(review): per the fluid multiclass_nms docs, when nothing is
    # detected the output holds a single -1 value instead of (N, 6) rows.
    # The column indexing below would raise IndexError on it.
    if detection.ndim != 2 or detection.shape[1] < 6:
        return np.empty((0, 5))

    # Columns read below:
    # - 1 is the score.
    # - 2..5 are xmin, ymin, xmax, ymax relative to the resized image.
    # - 0 is presumably the class label.
    # Scaling by the resized size and dividing by ``shrink`` maps the boxes
    # back to the original image.
    det_conf = detection[:, 1]
    det_xmin = image_shape[2] * detection[:, 2] / shrink
    det_ymin = image_shape[1] * detection[:, 3] / shrink
    det_xmax = image_shape[2] * detection[:, 4] / shrink
    det_ymax = image_shape[1] * detection[:, 5] / shrink

    det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
    # Drop rows with a negative score.
    return det[det[:, 4] >= 0]


def flip_test(image, shrink):
    """Run detection on the horizontally mirrored image and un-mirror boxes.

    Args:
        image: PIL image.
        shrink: resize factor forwarded to detect_face().

    Returns:
        (N, 5) array of ``xmin, ymin, xmax, ymax, score`` rows, expressed in
        the coordinate frame of the un-flipped ``image``.
    """
    width = image.size[0]  # PIL size is (width, height)
    mirrored = detect_face(image.transpose(Image.FLIP_LEFT_RIGHT), shrink)
    # Reflect along x: the mirrored box's right edge becomes the left edge
    # and vice versa. y coordinates and score are unchanged.
    return np.column_stack((width - mirrored[:, 2], mirrored[:, 1],
                            width - mirrored[:, 0], mirrored[:, 3],
                            mirrored[:, 4]))


def multi_scale_test(image, max_shrink):
    """Run detection at a reduced scale and at one or more enlarged scales.

    The reduced pass keeps only big faces (longer side > 30 px). The
    enlarged passes keep only small faces (shorter side < 100 px). If the
    "enlarged" scale ends up <= 1, those passes are filtered like the
    reduced pass instead.

    Args:
        image: PIL image.
        max_shrink: largest resize factor allowed (see get_im_shrink()).

    Returns:
        (det_s, det_b): (N, 5) arrays of ``xmin, ymin, xmax, ymax, score``
        rows from the reduced and the enlarged passes respectively.
    """

    def _sides(det):
        # Inclusive-pixel (+1) width and height of every box.
        return det[:, 2] - det[:, 0] + 1, det[:, 3] - det[:, 1] + 1

    # Shrunk pass: only trusted for big faces.
    st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink
    det_s = detect_face(image, st)
    det_s = det_s[np.maximum(*_sides(det_s)) > 30]

    # Enlarged pass.
    bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2
    det_b = detect_face(image, bt)

    # When larger scales are allowed, also run at successive doublings and
    # finally at max_shrink itself, to pick up small faces.
    # np.vstack replaces np.row_stack, which is deprecated in NumPy 2.0.
    if max_shrink > 2:
        bt *= 2
        while bt < max_shrink:
            det_b = np.vstack((det_b, detect_face(image, bt)))
            bt *= 2
        det_b = np.vstack((det_b, detect_face(image, max_shrink)))

    if bt > 1:
        # Really enlarged: keep only small faces.
        det_b = det_b[np.minimum(*_sides(det_b)) < 100]
    else:
        # Not actually enlarged: keep only big faces, like the shrunk pass.
        det_b = det_b[np.maximum(*_sides(det_b)) > 30]
    return det_s, det_b


def get_im_shrink(image_shape):
    """Choose the base and maximum test scales for an image.

    Args:
        image_shape: ``[channels, height, width]``.

    Returns:
        (shrink, max_shrink).
        - ``shrink`` is ``max_shrink`` capped at 1. It is used for the plain
          and flipped passes.
        - ``max_shrink`` is the largest resize factor allowed for the
          enlarged passes.
    """
    num_pixels = image_shape[1] * image_shape[2]
    # Two upper bounds on the resize factor; the smaller one wins.
    # 0x7fffffff is INT32_MAX. NOTE(review): the 577 and 678*1024*4
    # constants look empirical (presumably tensor-size / memory limits);
    # confirm their origin.
    max_shrink_v1 = (0x7fffffff / 577.0 / num_pixels)**0.5
    max_shrink_v2 = ((678 * 1024 * 2.0 * 2.0) / num_pixels)**0.5
    max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3

    # Back off a little more the larger the scale gets.
    if 1.5 <= max_shrink < 2:
        max_shrink = max_shrink - 0.1
    elif 2 <= max_shrink < 3:
        max_shrink = max_shrink - 0.2
    elif 3 <= max_shrink < 4:
        max_shrink = max_shrink - 0.3
    elif 4 <= max_shrink < 5:
        max_shrink = max_shrink - 0.4
    elif max_shrink >= 5:
        max_shrink = max_shrink - 0.5

    # Single-argument print() calls reproduce the Python 2 statements'
    # output, including the double space, and also parse on Python 3.
    print("max_shrink =  {}".format(max_shrink))
    shrink = max_shrink if max_shrink < 1 else 1
    print("shrink =  {}".format(shrink))

    return shrink, max_shrink


def infer(args, batch_size, data_args):
    """Run test-time-augmented face detection over the evaluation list.

    For every sample from ``reader.test(data_args, file_list)``:
    - Run the plain, horizontally flipped and multi-scale passes.
    - Merge all boxes with bbox_vote().
    - Write the result to ``./infer_results/<image_class>/<image_stem>.txt``.

    NOTE: ``file_list`` is read from module globals (assigned in
    ``__main__``). detect_face() likewise reads the global ``args``.

    Raises:
        ValueError: if ``args.model_dir`` does not exist.
    """
    if not os.path.exists(args.model_dir):
        raise ValueError("The model path [%s] does not exist." %
                         (args.model_dir))

    infer_reader = paddle.batch(
        reader.test(data_args, file_list), batch_size=batch_size)

    for batch in infer_reader():
        # Handle every sample in the batch. The original read only batch[0],
        # silently skipping images whenever batch_size > 1.
        for sample in batch:
            image, image_path = sample[0], sample[1]
            # image.size: [width, height]
            image_shape = [3, image.size[1], image.size[0]]
            shrink, max_shrink = get_im_shrink(image_shape)

            det0 = detect_face(image, shrink)
            det1 = flip_test(image, shrink)
            det2, det3 = multi_scale_test(image, max_shrink)
            dets = bbox_vote(np.vstack((det0, det1, det2, det3)))

            image_name = image_path.split('/')[-1]
            image_class = image_path.split('/')[-2]
            # Plain str path joins. ``.encode('utf-8')`` yields bytes on
            # Python 3, and bytes + str raises TypeError.
            out_dir = os.path.join('./infer_results', image_class)
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            out_path = os.path.join(
                out_dir, os.path.splitext(image_name)[0] + '.txt')
            # ``with`` closes and flushes the file. The original leaked one
            # open handle per image.
            with open(out_path, 'w') as f:
                write_to_txt(image_path, f, dets)
            # To also save visualizations, enable:
            # draw_bounding_box_on_image(image_path, dets, args.confs_threshold)
    print("Done")


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)

    # WIDER FACE validation split. ``file_list`` must stay a module-level
    # name because infer() reads it as a global.
    file_list = 'label/val_gt_widerface.res'
    data_args = reader.Settings(
        data_dir='data/WIDERFACE/WIDER_val/images/',
        mean_value=[104., 117., 123],
        apply_distort=False,
        apply_expand=False,
        ap_version='11point')

    infer(args, batch_size=1, data_args=data_args)