提交 e053c8c9 编写于 作者: B baiyfbupt

code clean

上级 de4504d6
...@@ -3,7 +3,7 @@ import time ...@@ -3,7 +3,7 @@ import time
import numpy as np import numpy as np
import argparse import argparse
import functools import functools
import datetime import cv2
from PIL import Image from PIL import Image
from PIL import ImageDraw from PIL import ImageDraw
...@@ -12,7 +12,6 @@ import paddle.fluid as fluid ...@@ -12,7 +12,6 @@ import paddle.fluid as fluid
import reader import reader
from pyramidbox import PyramidBox from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments from utility import add_arguments, print_arguments
from paddle.fluid.framework import Program, Parameter, default_main_program, Variable
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
...@@ -21,8 +20,6 @@ add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.") ...@@ -21,8 +20,6 @@ add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('confs_threshold', float, 0.25, "Confidence threshold to draw bbox.") add_arg('confs_threshold', float, 0.25, "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The data root path.") add_arg('image_path', str, '', "The data root path.")
add_arg('model_dir', str, '', "The model path.") add_arg('model_dir', str, '', "The model path.")
add_arg('resize_h', int, 0, "The resized image height.")
add_arg('resize_w', int, 0, "The resized image height.")
# yapf: enable # yapf: enable
...@@ -115,19 +112,7 @@ def bbox_vote(det): ...@@ -115,19 +112,7 @@ def bbox_vote(det):
return dets return dets
def detect_face(image, image_shape, raw_image, shrink): def image_preprocess(image):
num_classes = 2
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
if shrink != 1:
image = image.resize((int(image_shape[2] * shrink),
int(image_shape[1] * shrink)), Image.ANTIALIAS)
image_shape = [
image_shape[0], int(image_shape[1] * shrink),
int(image_shape[2] * shrink)
]
print "image_shape:", image_shape
img = np.array(image) img = np.array(image)
# HWC to CHW # HWC to CHW
if len(img.shape) == 3: if len(img.shape) == 3:
...@@ -141,47 +126,54 @@ def detect_face(image, image_shape, raw_image, shrink): ...@@ -141,47 +126,54 @@ def detect_face(image, image_shape, raw_image, shrink):
img = img * 0.007843 img = img * 0.007843
img = [img] img = [img]
img = np.array(img) img = np.array(img)
return img
def detect_face(image, shrink):
image_shape = [3, image.size[1], image.size[0]]
num_classes = 2
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
if shrink != 1:
image = image.resize((int(image_shape[2] * shrink),
int(image_shape[1] * shrink)), Image.ANTIALIAS)
image_shape = [
image_shape[0], int(image_shape[1] * shrink),
int(image_shape[2] * shrink)
]
print "image_shape:", image_shape
img = image_preprocess(image)
scope = fluid.core.Scope() scope = fluid.core.Scope()
model_program = fluid.Program() main_program = fluid.Program()
startup_program = fluid.Program() startup_program = fluid.Program()
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
with fluid.program_guard(model_program, startup_program): with fluid.program_guard(main_program, startup_program):
fetches = [] fetches = []
network = PyramidBox( network = PyramidBox(
image_shape, image_shape,
num_classes, num_classes,
sub_network=args.use_pyramidbox, sub_network=args.use_pyramidbox,
is_infer=True) is_infer=True)
infer_program, nmsed_out = network.infer() infer_program, nmsed_out = network.infer(main_program)
fetches = [nmsed_out] fetches = [nmsed_out]
feeder = fluid.DataFeeder(
place=place, feed_list=network.feeds())
fluid.io.load_persistables( fluid.io.load_persistables(
exe, args.model_dir, main_program=model_program) exe, args.model_dir, main_program=main_program)
#fluid.io.load_vars(exe, args.model_dir, predicate=if_exist)
detection, = exe.run(infer_program, detection, = exe.run(infer_program,
feed=feeder.feed([img]), feed={'image': img},
fetch_list=fetches, fetch_list=fetches,
return_numpy=False) return_numpy=False)
detection = np.array(detection) detection = np.array(detection)
 # layout: xmin, ymin, xmax, ymax, score # layout: xmin, ymin, xmax, ymax, score
det_conf = detection[:, 1] det_conf = detection[:, 1]
if args.resize_h != 0 and args.resize_w != 0: det_xmin = image_shape[2] * detection[:, 2] / shrink
det_xmin = raw_image.size[0] * detection[:, 2] det_ymin = image_shape[1] * detection[:, 3] / shrink
det_ymin = raw_image.size[1] * detection[:, 3] det_xmax = image_shape[2] * detection[:, 4] / shrink
det_xmax = raw_image.size[0] * detection[:, 4] det_ymax = image_shape[1] * detection[:, 5] / shrink
det_ymax = raw_image.size[1] * detection[:, 5]
else:
det_xmin = image_shape[2] * detection[:, 2] / shrink
det_ymin = image_shape[1] * detection[:, 3] / shrink
det_xmax = image_shape[2] * detection[:, 4] / shrink
det_ymax = image_shape[1] * detection[:, 5] / shrink
det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf)) det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
keep_index = np.where(det[:, 4] >= 0)[0] keep_index = np.where(det[:, 4] >= 0)[0]
...@@ -189,40 +181,37 @@ def detect_face(image, image_shape, raw_image, shrink): ...@@ -189,40 +181,37 @@ def detect_face(image, image_shape, raw_image, shrink):
return det return det
def flip_test(image, image_shape, raw_image, shrink): def flip_test(image, shrink):
image = image.transpose(Image.FLIP_LEFT_RIGHT) image = image.transpose(Image.FLIP_LEFT_RIGHT)
det_f = detect_face(image, image_shape, raw_image, shrink) det_f = detect_face(image, shrink)
det_t = np.zeros(det_f.shape) det_t = np.zeros(det_f.shape)
det_t[:, 0] = raw_image.size[0] - det_f[:, 2] det_t[:, 0] = image.size[0] - det_f[:, 2]
det_t[:, 1] = det_f[:, 1] det_t[:, 1] = det_f[:, 1]
det_t[:, 2] = raw_image.size[0] - det_f[:, 0] det_t[:, 2] = image.size[0] - det_f[:, 0]
det_t[:, 3] = det_f[:, 3] det_t[:, 3] = det_f[:, 3]
det_t[:, 4] = det_f[:, 4] det_t[:, 4] = det_f[:, 4]
return det_t return det_t
def multi_scale_test(image, image_shape, raw_image, max_im_shrink): def multi_scale_test(image, max_shrink):
# shrink detecting and shrink only detect big face # shrink detecting and shrink only detect big face
st = 0.5 if max_im_shrink >= 0.75 else 0.5 * max_im_shrink st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink
det_s = detect_face(image, image_shape, raw_image, st) det_s = detect_face(image, st)
index = np.where( index = np.where(
np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1) np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1)
> 30)[0] > 30)[0]
det_s = det_s[index, :] det_s = det_s[index, :]
# enlarge one times # enlarge one times
bt = min(2, max_im_shrink) if max_im_shrink > 1 else ( bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2
st + max_im_shrink) / 2 det_b = detect_face(image, bt)
det_b = detect_face(image, image_shape, raw_image, bt)
# enlarge small iamge x times for small face # enlarge small image x times for small face
if max_im_shrink > 2: if max_shrink > 2:
bt *= 2 bt *= 2
while bt < max_im_shrink: while bt < max_shrink:
det_b = np.row_stack( det_b = np.row_stack((det_b, detect_face(image, bt)))
(det_b, detect_face(image, image_shape, raw_image, bt)))
bt *= 2 bt *= 2
det_b = np.row_stack( det_b = np.row_stack((det_b, detect_face(image, max_shrink)))
(det_b, detect_face(image, image_shape, raw_image, max_im_shrink)))
# enlarge only detect small face # enlarge only detect small face
if bt > 1: if bt > 1:
...@@ -239,30 +228,28 @@ def multi_scale_test(image, image_shape, raw_image, max_im_shrink): ...@@ -239,30 +228,28 @@ def multi_scale_test(image, image_shape, raw_image, max_im_shrink):
def get_im_shrink(image_shape): def get_im_shrink(image_shape):
max_im_shrink_v1 = (0x7fffffff / 577.0 / max_shrink_v1 = (0x7fffffff / 577.0 /
(image_shape[1] * image_shape[2]))**0.5 (image_shape[1] * image_shape[2]))**0.5
max_im_shrink_v2 = ( max_shrink_v2 = (
(678 * 1024 * 2.0 * 2.0) / (image_shape[1] * image_shape[2]))**0.5 (678 * 1024 * 2.0 * 2.0) / (image_shape[1] * image_shape[2]))**0.5
max_im_shrink = get_round(min(max_im_shrink_v1, max_im_shrink_v2), 2) - 0.3 max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3
if max_im_shrink >= 1.5 and max_im_shrink < 2: if max_shrink >= 1.5 and max_shrink < 2:
max_im_shrink = max_im_shrink - 0.1 max_shrink = max_shrink - 0.1
elif max_im_shrink >= 2 and max_im_shrink < 3: elif max_shrink >= 2 and max_shrink < 3:
max_im_shrink = max_im_shrink - 0.2 max_shrink = max_shrink - 0.2
elif max_im_shrink >= 3 and max_im_shrink < 4: elif max_shrink >= 3 and max_shrink < 4:
max_im_shrink = max_im_shrink - 0.3 max_shrink = max_shrink - 0.3
elif max_im_shrink >= 4 and max_im_shrink < 5: elif max_shrink >= 4 and max_shrink < 5:
max_im_shrink = max_im_shrink - 0.4 max_shrink = max_shrink - 0.4
elif max_im_shrink >= 5 and max_im_shrink < 6: elif max_shrink >= 5:
max_im_shrink = max_im_shrink - 0.5 max_shrink = max_shrink - 0.5
elif max_im_shrink >= 6:
max_im_shrink = max_im_shrink - 0.5 print 'max_shrink = ', max_shrink
shrink = max_shrink if max_shrink < 1 else 1
print 'max_im_shrink = ', max_im_shrink
shrink = max_im_shrink if max_im_shrink < 1 else 1
print "shrink = ", shrink print "shrink = ", shrink
return shrink, max_im_shrink return shrink, max_shrink
def infer(args, batch_size, data_args): def infer(args, batch_size, data_args):
...@@ -277,33 +264,25 @@ def infer(args, batch_size, data_args): ...@@ -277,33 +264,25 @@ def infer(args, batch_size, data_args):
image = img[0][0] image = img[0][0]
image_path = img[0][1] image_path = img[0][1]
raw_image = Image.open(image_path) image_shape = [3, image.size[1], image.size[0]]
if args.resize_h != 0 and args.resize_w != 0: shrink, max_shrink = get_im_shrink(image_shape)
image_shape = [3, args.resize_h, args.resize_w]
else:
image_shape = [3, image.size[1], image.size[0]]
shrink, max_im_shrink = get_im_shrink(image_shape)
det0 = detect_face(image, image_shape, raw_image, shrink)
det1 = flip_test(image, image_shape, raw_image, shrink)
[det2, det3] = multi_scale_test(image, image_shape, raw_image,
max_im_shrink)
det0 = detect_face(image, shrink)
det1 = flip_test(image, shrink)
[det2, det3] = multi_scale_test(image, max_shrink)
det = np.row_stack((det0, det1, det2, det3)) det = np.row_stack((det0, det1, det2, det3))
dets = bbox_vote(det) dets = bbox_vote(det)
image_name = image_path.split('/')[-1] image_name = image_path.split('/')[-1]
image_class = image_path.split('/')[-2] image_class = image_path.split('/')[-2]
if not os.path.exists('./infer_results/' + image_class.encode('utf-8')): if not os.path.exists('./infer_results/' + image_class.encode('utf-8')):
os.makedirs('./infer_results/' + image_class.encode('utf-8')) os.makedirs('./infer_results/' + image_class.encode('utf-8'))
f = open('./infer_results/' + image_class.encode('utf-8') + '/' + f = open('./infer_results/' + image_class.encode('utf-8') + '/' +
image_name.encode('utf-8')[:-4] + '.txt', 'w') image_name.encode('utf-8')[:-4] + '.txt', 'w')
write_to_txt(image_path, f, dets) # write_to_txt(image_path, f, dets)
#draw_bounding_box_on_image(image_path, dets, args.confs_threshold) # draw_bounding_box_on_image(image_path, dets, args.confs_threshold)
print "Done" print "Done"
...@@ -316,8 +295,6 @@ if __name__ == '__main__': ...@@ -316,8 +295,6 @@ if __name__ == '__main__':
data_args = reader.Settings( data_args = reader.Settings(
data_dir=data_dir, data_dir=data_dir,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[104., 117., 123], mean_value=[104., 117., 123],
apply_distort=False, apply_distort=False,
apply_expand=False, apply_expand=False,
......
...@@ -269,9 +269,6 @@ def pyramidbox(settings, file_list, mode, shuffle): ...@@ -269,9 +269,6 @@ def pyramidbox(settings, file_list, mode, shuffle):
yield im, boxes, expand_bboxes(boxes), lbls, difficults yield im, boxes, expand_bboxes(boxes), lbls, difficults
if mode == 'test': if mode == 'test':
if settings.resize_w and settings.resize_h:
im = im.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS)
yield im, image_path yield im, image_path
return reader return reader
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册