未验证 提交 82ac38e3 编写于 作者: B Bai Yifan 提交者: GitHub

Refine eval.py and update pyramidbox parameter init method (#1091)

* refine eval and update pyramidbox parameter init
上级 53fbde83
......@@ -145,7 +145,6 @@ class PyramidBox(object):
upsampling = fluid.layers.resize_bilinear(
conv1, out_shape=up_to.shape[2:])
b_attr = ParamAttr(learning_rate=2., regularizer=L2Decay(0.))
conv2 = fluid.layers.conv2d(
up_to, ch, 1, act='relu', bias_attr=b_attr)
if self.is_infer:
......@@ -220,10 +219,13 @@ class PyramidBox(object):
def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
new_shape = [
compile_shape = [
trans.shape[0], np.prod(trans.shape[1:]) / last_dim, last_dim
]
return fluid.layers.reshape(trans, shape=new_shape)
run_shape = fluid.layers.assign(
np.array([0, -1, last_dim]).astype("int32"))
return fluid.layers.reshape(
trans, shape=compile_shape, actual_shape=run_shape)
face_locs, face_confs = [], []
head_locs, head_confs = [], []
......@@ -288,10 +290,13 @@ class PyramidBox(object):
def permute_and_reshape(input, last_dim):
trans = fluid.layers.transpose(input, perm=[0, 2, 3, 1])
new_shape = [
compile_shape = [
trans.shape[0], np.prod(trans.shape[1:]) / last_dim, last_dim
]
return fluid.layers.reshape(trans, shape=new_shape)
run_shape = fluid.layers.assign(
np.array([0, -1, last_dim]).astype("int32"))
return fluid.layers.reshape(
trans, shape=compile_shape, actual_shape=run_shape)
locs, confs = [], []
boxes, vars = [], []
......
......@@ -16,14 +16,14 @@ add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('parallel', bool, True, "Whether use multi-GPU/threads or not.")
add_arg('learning_rate', float, 0.001, "The start learning rate.")
add_arg('batch_size', int, 12, "Minibatch size.")
add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('num_passes', int, 160, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('model_save_dir', str, 'output', "The path to save model.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image width.")
add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, './vgg_ilsvrc_16_fc_reduced/', "The init model path.")
#yapf: enable
......
......@@ -8,33 +8,35 @@ from PIL import Image
import paddle.fluid as fluid
import reader
from pyramidbox import PyramidBox
from visualize import draw_bboxes
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether use GPU or not.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('data_dir', str, 'data/WIDER_val/images/', "The validation dataset path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('pred_dir', str, 'pred', "The path to save the evaluation results.")
add_arg('file_list', str, 'data/wider_face_split/wider_face_val_bbx_gt.txt', "The validation dataset path.")
add_arg('use_gpu', bool, True, "Whether use GPU or not.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('data_dir', str, 'data/WIDER_val/images/', "The validation dataset path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('pred_dir', str, 'pred', "The path to save the evaluation results.")
add_arg('file_list', str, 'data/wider_face_split/wider_face_val_bbx_gt.txt', "The validation dataset path.")
add_arg('infer', bool, False, "Whether do infer or eval.")
add_arg('confs_threshold', float, 0.15, "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The image used to inference and visualize.")
# yapf: enable
def infer(args, config):
batch_size = 1
model_dir = args.model_dir
data_dir = args.data_dir
file_list = args.file_list
pred_dir = args.pred_dir
if not os.path.exists(model_dir):
raise ValueError("The model path [%s] does not exist." % (model_dir))
test_reader = reader.test(config, file_list)
for image, image_path in test_reader():
if args.infer:
image_path = args.image_path
image = Image.open(image_path)
if image.mode == 'L':
image = img.convert('RGB')
shrink, max_shrink = get_shrink(image.size[1], image.size[0])
det0 = detect_face(image, shrink)
......@@ -44,9 +46,24 @@ def infer(args, config):
det = np.row_stack((det0, det1, det2, det3, det4))
dets = bbox_vote(det)
save_widerface_bboxes(image_path, dets, pred_dir)
keep_index = np.where(dets[:, 4] >= args.confs_threshold)[0]
dets = dets[keep_index, :]
draw_bboxes(image_path, dets[:, 0:4])
else:
test_reader = reader.test(config, args.file_list)
for image, image_path in test_reader():
shrink, max_shrink = get_shrink(image.size[1], image.size[0])
det0 = detect_face(image, shrink)
det1 = flip_test(image, shrink)
[det2, det3] = multi_scale_test(image, max_shrink)
det4 = multi_scale_test_pyramid(image, max_shrink)
det = np.row_stack((det0, det1, det2, det3, det4))
dets = bbox_vote(det)
print("Finish evaluation.")
save_widerface_bboxes(image_path, dets, pred_dir)
print("Finish evaluation.")
def save_widerface_bboxes(image_path, bboxes_scores, output_dir):
......@@ -97,25 +114,11 @@ def detect_face(image, shrink):
img = [img]
img = np.array(img)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
network = PyramidBox(
image_shape, sub_network=args.use_pyramidbox, is_infer=True)
infer_program, nmsed_out = network.infer(main_program)
fetches = [nmsed_out]
fluid.io.load_persistables(
exe, args.model_dir, main_program=main_program)
detection, = exe.run(infer_program,
feed={'image': img},
fetch_list=fetches,
return_numpy=False)
detection = np.array(detection)
detection, = exe.run(infer_program,
feed={'image': img},
fetch_list=fetches,
return_numpy=False)
detection = np.array(detection)
# layout: xmin, ymin, xmax. ymax, score
if detection.shape == (1, ):
print("No face detected")
......@@ -290,4 +293,18 @@ if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
config = reader.Settings(data_dir=args.data_dir)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
main_program = fluid.Program()
startup_program = fluid.Program()
image_shape = [3, 1024, 1024]
with fluid.program_guard(main_program, startup_program):
network = PyramidBox(
image_shape, sub_network=args.use_pyramidbox, is_infer=True)
infer_program, nmsed_out = network.infer(main_program)
fetches = [nmsed_out]
fluid.io.load_persistables(
exe, args.model_dir, main_program=main_program)
infer(args, config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册