未验证 提交 5abba732 编写于 作者: Q qingqing01 提交者: GitHub

Add infer scripts. (#966)

上级 b6c505b8
import os
import time
import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw
import paddle
import paddle.fluid as fluid
import reader
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, False, "Whether use PyramidBox model.")
add_arg('confs_threshold', float, 0.15, "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The data root path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('resize_h', int, 0, "The resized image height.")
add_arg('resize_w', int, 0, "The resized image height.")
# yapf: enable
def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
image = Image.open(image_path)
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
for dt in nms_out:
category_id, score, xmin, ymin, xmax, ymax = dt.tolist()
if score < confs_threshold:
continue
bbox = dt[2:]
xmin, ymin, xmax, ymax = bbox
(left, right, top, bottom) = (xmin * im_width, xmax * im_width,
ymin * im_height, ymax * im_height)
draw.line(
[(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=4,
fill='red')
image_name = image_path.split('/')[-1]
print("image with bbox drawed saved as {}".format(image_name))
image.save(image_name)
def infer(args, data_args):
num_classes = 2
infer_reader = reader.infer(data_args, args.image_path)
data = infer_reader()
if args.resize_h and args.resize_w:
image_shape = [3, args.resize_h, args.resize_w]
else:
image_shape = data.shape[1:]
fetches = []
network = PyramidBox(
image_shape,
num_classes,
sub_network=args.use_pyramidbox,
is_infer=True)
infer_program, nmsed_out = network.infer()
fetches = [nmsed_out]
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
model_dir = args.model_dir
if not os.path.exists(model_dir):
raise ValueError("The model path [%s] does not exist." % (model_dir))
def if_exist(var):
return os.path.exists(os.path.join(model_dir, var.name))
fluid.io.load_vars(exe, model_dir, predicate=if_exist)
feed = {'image': fluid.create_lod_tensor(data, [], place)}
predict, = exe.run(infer_program,
feed=feed,
fetch_list=fetches,
return_numpy=False)
predict = np.array(predict)
draw_bounding_box_on_image(args.image_path, predict, args.confs_threshold)
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_dir = 'data/WIDERFACE/WIDER_val/images/'
file_list = 'label/val_gt_widerface.res'
data_args = reader.Settings(
data_dir=data_dir,
resize_h=args.resize_h,
resize_w=args.resize_w,
mean_value=[104., 117., 123],
apply_distort=False,
apply_expand=False,
ap_version='11point')
infer(args, data_args=data_args)
......@@ -45,12 +45,17 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True):
class PyramidBox(object):
def __init__(self, data_shape, is_infer=False, sub_network=False):
def __init__(self,
data_shape,
num_classes,
is_infer=False,
sub_network=False):
self.data_shape = data_shape
self.min_sizes = [16., 32., 64., 128., 256., 512.]
self.steps = [4., 8., 16., 32., 64., 128.]
self.is_infer = is_infer
self.sub_network = sub_network
self.num_classes = num_classes
# the base network is VGG with atrous layers
self._input()
......@@ -59,6 +64,8 @@ class PyramidBox(object):
self._low_level_fpn()
self._cpm_module()
self._pyramidbox()
else:
self._vgg_ssd()
def feeds(self):
if self.is_infer:
......@@ -188,9 +195,10 @@ class PyramidBox(object):
"""
Get prior-boxes and pyramid-box
"""
self.ssh_conv3_norm = self._l2_norm_scale(self.ssh_conv3)
self.ssh_conv4_norm = self._l2_norm_scale(self.ssh_conv4)
self.ssh_conv5_norm = self._l2_norm_scale(self.ssh_conv5)
self.ssh_conv3_norm = self._l2_norm_scale(
self.ssh_conv3, init_scale=10.)
self.ssh_conv4_norm = self._l2_norm_scale(self.ssh_conv4, init_scale=8.)
self.ssh_conv5_norm = self._l2_norm_scale(self.ssh_conv5, init_scale=5.)
def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
......@@ -253,10 +261,10 @@ class PyramidBox(object):
self.prior_boxes = fluid.layers.concat(boxes)
self.box_vars = fluid.layers.concat(vars)
def vgg_ssd(self, num_classes, image_shape):
self.conv3_norm = self._l2_norm_scale(self.conv3)
self.conv4_norm = self._l2_norm_scale(self.conv4)
self.conv5_norm = self._l2_norm_scale(self.conv5)
def _vgg_ssd(self):
self.conv3_norm = self._l2_norm_scale(self.conv3, init_scale=10.)
self.conv4_norm = self._l2_norm_scale(self.conv4, init_scale=8.)
self.conv5_norm = self._l2_norm_scale(self.conv5, init_scale=5.)
mbox_locs, mbox_confs, box, box_var = fluid.layers.multi_box_head(
inputs=[
......@@ -264,23 +272,30 @@ class PyramidBox(object):
self.conv7, self.conv8
],
image=self.image,
num_classes=num_classes,
# min_ratio=20,
# max_ratio=90,
num_classes=self.num_classes,
min_sizes=[16.0, 32.0, 64.0, 128.0, 256.0, 512.0],
max_sizes=[[], [], [], [], [], []],
# max_sizes=[[], 150.0, 195.0, 240.0, 285.0, 300.0],
aspect_ratios=[[1.], [1.], [1.], [1.], [1.], [1.]],
steps=[4.0, 8.0, 16.0, 32.0, 64.0, 128.0],
base_size=image_shape[2],
base_size=self.data_shape[2],
offset=0.5,
flip=False)
# locs, confs, box, box_var = vgg_extra_net(num_classes, image, image_shape)
# nmsed_out = fluid.layers.detection_output(
# locs, confs, box, box_var, nms_threshold=args.nms_threshold)
loss = fluid.layers.ssd_loss(mbox_locs, mbox_confs, self.face_box,
self.gt_label, box, box_var)
self.face_mbox_loc = mbox_locs
self.face_mbox_conf = mbox_confs
self.prior_boxes = box
self.box_vars = box_var
def vgg_ssd_loss(self):
loss = fluid.layers.ssd_loss(
self.face_mbox_loc,
self.face_mbox_conf,
self.face_box,
self.gt_label,
self.prior_boxes,
self.box_vars,
overlap_threshold=0.35,
neg_overlap=0.35)
loss = fluid.layers.reduce_sum(loss)
return loss
......@@ -297,7 +312,7 @@ class PyramidBox(object):
total_loss = face_loss + head_loss
return face_loss, head_loss, total_loss
def test(self):
def infer(self):
test_program = fluid.default_main_program().clone(for_test=True)
with fluid.program_guard(test_program):
face_nmsed_out = fluid.layers.detection_output(
......@@ -306,24 +321,4 @@ class PyramidBox(object):
self.prior_boxes,
self.box_vars,
nms_threshold=0.45)
head_nmsed_out = fluid.layers.detection_output(
self.head_mbox_loc,
self.head_mbox_conf,
self.prior_boxes,
self.box_vars,
nms_threshold=0.45)
face_map_eval = fluid.evaluator.DetectionMAP(
face_nmsed_out,
self.gt_label,
self.face_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
head_map_eval = fluid.evaluator.DetectionMAP(
head_nmsed_out,
self.gt_label,
self.head_box,
class_num=2,
overlap_threshold=0.5,
ap_version='11point')
return test_program, face_map_eval, head_map_eval
return test_program, face_nmsed_out
......@@ -272,3 +272,29 @@ def pyramidbox(settings, file_list, mode, shuffle):
def train(settings, file_list, shuffle=True):
return pyramidbox(settings, file_list, 'train', shuffle)
def infer(settings, image_path):
def batch_reader():
img = Image.open(image_path)
if img.mode == 'L':
img = im.convert('RGB')
im_width, im_height = img.size
if settings.resize_w and settings.resize_h:
img = img.resize((settings.resize_w, settings.resize_h),
Image.ANTIALIAS)
img = np.array(img)
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = img.astype('float32')
img -= settings.img_mean
img = img * 0.007843
img = [img]
img = np.array(img)
return img
return batch_reader
......@@ -40,13 +40,13 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
image_shape = [3, data_args.resize_h, data_args.resize_w]
fetches = []
network = PyramidBox(image_shape, num_classes,
sub_network=args.use_pyramidbox)
if args.use_pyramidbox:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
face_loss, head_loss, loss = network.train()
fetches = [face_loss, head_loss]
else:
network = PyramidBox(image_shape, sub_network=args.use_pyramidbox)
loss = network.vgg_ssd(num_classes, image_shape)
loss = network.vgg_ssd_loss()
fetches = [loss]
epocs = 12880 / batch_size
......@@ -126,7 +126,7 @@ def train(args, data_args, learning_rate, batch_size, pretrained_model,
batch_id, fetch_vars[0], fetch_vars[1],
start_time - prev_start_time))
if pass_id % 10 == 0 or pass_id == num_passes - 1:
if pass_id % 1 == 0 or pass_id == num_passes - 1:
save_model(str(pass_id))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册