未验证 提交 e9d5a403 编写于 作者: Q qingqing01 提交者: GitHub

Refine evaluation code. (#1024)

* Refine evaluation code.
* Clean code.
上级 72051882
model/
pretrained/
data/
label/
data/WIDER_train
data/WIDER_val
data/wider_face_split
vgg_ilsvrc_16_fc_reduced*
*.swp
*.log
log*
output*
infer_results*
pred
eval_tools
# Download and unpack wider_face_split.zip into the directory of this script.

# Resolve the directory that contains this script so the archive lands there
# regardless of the caller's working directory.
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
# Abort when the directory cannot be entered; otherwise the archive would be
# downloaded into whatever directory the caller happened to be in.
cd "$DIR" || exit 1
echo "Downloading..."
# Stop on a failed download instead of running unzip on a missing file.
wget http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/bbx_annotation/wider_face_split.zip || exit 1
echo "Extracting..."
# The archive is deleted only after a successful extraction.
unzip wider_face_split.zip && rm -f wider_face_split.zip
......@@ -131,12 +131,13 @@ def data_anchor_sampling(sampler, bbox_labels, image_width, image_height,
rand_idx_size = range_size + 1
else:
# np.random.randint range: [low, high)
rng_rand_size = np.random.randint(0, range_size)
rand_idx_size = rng_rand_size % range_size
scale_choose = random.uniform(scale_array[rand_idx_size] / 2.0,
2.0 * scale_array[rand_idx_size])
rng_rand_size = np.random.randint(0, range_size + 1)
rand_idx_size = rng_rand_size % (range_size + 1)
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = min(2.0 * scale_array[rand_idx_size],
2 * math.sqrt(wid * hei))
scale_choose = random.uniform(min_resize_val, max_resize_val)
sample_bbox_size = wid * resize_width / scale_choose
w_off_orig = 0.0
......@@ -389,9 +390,19 @@ def crop_image_sampling(img, bbox_labels, sample_bbox, image_width,
roi_width = cross_width
roi_height = cross_height
roi_y1 = int(roi_ymin)
roi_y2 = int(roi_ymin + roi_height)
roi_x1 = int(roi_xmin)
roi_x2 = int(roi_xmin + roi_width)
cross_y1 = int(cross_ymin)
cross_y2 = int(cross_ymin + cross_height)
cross_x1 = int(cross_xmin)
cross_x2 = int(cross_xmin + cross_width)
sample_img = np.zeros((height, width, 3))
sample_img[int(roi_ymin) : int(roi_ymin + roi_height), int(roi_xmin) : int(roi_xmin + roi_width)] = \
img[int(cross_ymin) : int(cross_ymin + cross_height), int(cross_xmin) : int(cross_xmin + cross_width)]
sample_img[roi_y1 : roi_y2, roi_x1 : roi_x2] = \
img[cross_y1 : cross_y2, cross_x1 : cross_x2]
sample_img = cv2.resize(
sample_img, (resize_width, resize_height), interpolation=cv2.INTER_AREA)
......
......@@ -52,7 +52,7 @@ def conv_block(input, groups, filters, ksizes, strides=None, with_pool=True):
class PyramidBox(object):
def __init__(self,
data_shape,
num_classes,
num_classes=None,
use_transposed_conv2d=True,
is_infer=False,
sub_network=False):
......@@ -414,5 +414,5 @@ class PyramidBox(object):
nms_threshold=0.3,
nms_top_k=5000,
keep_top_k=750,
score_threshold=0.05)
score_threshold=0.01)
return test_program, face_nmsed_out
......@@ -59,30 +59,25 @@ class Settings(object):
self.saturation_delta = 0.5
self.brightness_prob = 0.5
# _brightness_delta is the normalized value by 256
# self._brightness_delta = 32
self.brightness_delta = 0.125
self.scale = 0.007843 # 1 / 127.5
self.data_anchor_sampling_prob = 0.5
self.min_face_size = 8.0
def draw_image(faces_pred, img, resize_val):
    """
    Draw every face record in `faces_pred` onto `img` as a rectangle outline.
    Args:
        faces_pred (sequence): face records; each one is handed unchanged to
            draw_rotate_rectange.
        img: the image that draw_rotate_rectange draws on.
        resize_val: factor applied to each face coordinate before drawing.
    """
    outline_color = (0, 255, 0)
    outline_thickness = 3
    for face in faces_pred:
        draw_rotate_rectange(img, face, resize_val, outline_color,
                             outline_thickness)
def draw_rotate_rectange(img, face, resize_val, color, thickness):
    """
    Draw the four sides of one box onto `img` with cv2.line.
    Args:
        img: image passed to cv2.line.
        face (sequence): box record; entries 1, 2, 3 and 4 are used as
            x1, y1, x2 and y2. Entry 0 is not read.
        resize_val: factor each coordinate is multiplied by before it is
            converted to int.
        color: line color passed to cv2.line.
        thickness: line thickness passed to cv2.line.
    """
    x1 = int(face[1] * resize_val)
    y1 = int(face[2] * resize_val)
    x2 = int(face[3] * resize_val)
    y2 = int(face[4] * resize_val)
    # One segment per side, two along y1/y2 and two along x1/x2.
    cv2.line(img, (x1, y1), (x2, y1), color, thickness)
    cv2.line(img, (x2, y1), (x2, y2), color, thickness)
    cv2.line(img, (x1, y1), (x1, y2), color, thickness)
    cv2.line(img, (x2, y2), (x1, y2), color, thickness)
def to_chw_bgr(image):
    """
    Transpose image from HWC to CHW and from RGB to BGR.
    Args:
        image (np.array): an image with HWC and RGB layout.
    Returns:
        np.array: the image with CHW layout and the first three channels
            in reversed order.
    """
    if image.ndim == 3:
        # (H, W, C) -> (C, H, W) in a single step.
        image = np.transpose(image, (2, 0, 1))
    # Pick channels 2, 1, 0 along the leading axis (RGB -> BGR).
    channel_order = [2, 1, 0]
    return image[channel_order, :, :]
def preprocess(img, bbox_labels, mode, settings, image_path):
......@@ -108,9 +103,6 @@ def preprocess(img, bbox_labels, mode, settings, image_path):
batch_sampler, bbox_labels, img_width, img_height, scale_array,
settings.resize_width, settings.resize_height)
img = np.array(img)
# Debug
# img_save = Image.fromarray(img)
# img_save.save('img_orig.jpg')
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sampled_labels = image_util.crop_image_sampling(
......@@ -119,17 +111,7 @@ def preprocess(img, bbox_labels, mode, settings, image_path):
settings.min_face_size)
img = img.astype('uint8')
# Debug: visualize the gt bbox
visualize_bbox = 0
if visualize_bbox:
img_show = img
draw_image(sampled_labels, img_show, settings.resize_height)
img_show = Image.fromarray(img_show)
img_show.save('final_img_show.jpg')
img = Image.fromarray(img)
# Debug
# img.save('final_img.jpg')
else:
# hard-code here
......@@ -173,12 +155,8 @@ def preprocess(img, bbox_labels, mode, settings, image_path):
tmp = sampled_labels[i][1]
sampled_labels[i][1] = 1 - sampled_labels[i][3]
sampled_labels[i][3] = 1 - tmp
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = to_chw_bgr(img)
img = img.astype('float32')
img -= settings.img_mean
img = img * settings.scale
......@@ -192,25 +170,24 @@ def load_file_list(input_txt):
file_dict = {}
num_class = 0
for i in range(len(lines_input_txt)):
tmp_line_txt = lines_input_txt[i].strip('\n\t\r')
if '--' in tmp_line_txt:
line_txt = lines_input_txt[i].strip('\n\t\r')
if '--' in line_txt:
if i != 0:
num_class += 1
file_dict[num_class] = []
dict_name = tmp_line_txt
file_dict[num_class].append(tmp_line_txt)
if '--' not in tmp_line_txt:
if len(tmp_line_txt) > 6:
split_str = tmp_line_txt.split(' ')
file_dict[num_class].append(line_txt)
if '--' not in line_txt:
if len(line_txt) > 6:
split_str = line_txt.split(' ')
x1_min = float(split_str[0])
y1_min = float(split_str[1])
x2_max = float(split_str[2])
y2_max = float(split_str[3])
tmp_line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
x2_max) + ' ' + str(y2_max)
file_dict[num_class].append(tmp_line_txt)
file_dict[num_class].append(line_txt)
else:
file_dict[num_class].append(tmp_line_txt)
file_dict[num_class].append(line_txt)
return file_dict
......@@ -248,7 +225,7 @@ def train_generator(settings, file_list, batch_size, shuffle=True):
label_offs = [0]
for index_image in file_dict.keys():
image_name = file_dict[index_image][0] + '.jpg'
image_name = file_dict[index_image][0]
image_path = os.path.join(settings.data_dir, image_name)
im = Image.open(image_path)
if im.mode == 'L':
......@@ -331,7 +308,7 @@ def test(settings, file_list):
def reader():
for index_image in file_dict.keys():
image_name = file_dict[index_image][0] + '.jpg'
image_name = file_dict[index_image][0]
image_path = os.path.join(settings.data_dir, image_name)
im = Image.open(image_path)
if im.mode == 'L':
......@@ -351,12 +328,7 @@ def infer(settings, image_path):
img = img.resize((settings.resize_width, settings.resize_height),
Image.ANTIALIAS)
img = np.array(img)
# HWC to CHW
if len(img.shape) == 3:
img = np.swapaxes(img, 1, 2)
img = np.swapaxes(img, 1, 0)
# RBG to BGR
img = img[[2, 1, 0], :, :]
img = to_chw_bgr(img)
img = img.astype('float32')
img -= settings.img_mean
img = img * settings.scale
......
......@@ -5,27 +5,26 @@ import time
import argparse
import functools
import reader
import paddle
import paddle.fluid as fluid
from pyramidbox import PyramidBox
import reader
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('parallel', bool, True, "parallel")
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 12, "Minibatch size.")
add_arg('parallel', bool, True, "Whether use multi-GPU/threads or not.")
add_arg('learning_rate', float, 0.001, "The start learning rate.")
add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('num_passes', int, 160, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('model_save_dir', str, 'output', "The path to save model.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image height.")
add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.")
add_arg('resize_w', int, 640, "The resized image width.")
add_arg('with_mem_opt', bool, True, "Whether to use memory optimization or not.")
add_arg('pretrained_model', str, './vgg_ilsvrc_16_fc_reduced/', "The init model path.")
#yapf: enable
......@@ -145,7 +144,7 @@ def train(args, config, train_file_list, optimizer_method):
fetch_list=fetches)
end_time = time.time()
fetch_vars = [np.mean(np.array(v)) for v in fetch_vars]
if batch_id % 1 == 0:
if batch_id % 10 == 0:
if not args.use_pyramidbox:
print("Pass {0}, batch {1}, loss {2}, time {3}".format(
pass_id, batch_id, fetch_vars[0],
......@@ -164,8 +163,8 @@ if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_dir = 'data/WIDERFACE/WIDER_train/images/'
train_file_list = 'label/train_gt_widerface.res'
data_dir = 'data/WIDER_train/images/'
train_file_list = 'data/wider_face_split/wider_face_train_bbx_gt.txt'
config = reader.Settings(
data_dir=data_dir,
......
import os
from PIL import Image
from PIL import ImageDraw
def draw_bbox(image, bbox):
    """
    Draw one bounding box on image.
    Args:
        image (PIL.Image): a PIL Image object; it is drawn on in place.
        bbox (np.array|list|tuple): (xmin, ymin, xmax, ymax).
    """
    draw = ImageDraw.Draw(image)
    # Unpack the `bbox` argument. The previous code read the undefined name
    # `box`, which raised NameError on every call.
    xmin, ymin, xmax, ymax = bbox
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
    # Closed polyline through the four corners, ending at the start point.
    draw.line(
        [(left, top), (left, bottom), (right, bottom), (right, top),
         (left, top)],
        width=4,
        fill='red')
def draw_bboxes(image_file, bboxes, labels=None, output_dir=None):
    """
    Draw bounding boxes on image and save the result.
    Args:
        image_file (string): input image path.
        bboxes (np.array): bounding boxes, each (xmin, ymin, xmax, ymax).
        labels (list of string): the label names of bboxes; when given it
            must have one entry per bbox.
        output_dir (string): output directory, created when missing. When
            omitted, the image is saved under its base name relative to the
            current working directory.
    """
    if labels:
        assert len(bboxes) == len(labels)
    image = Image.open(image_file)
    draw = ImageDraw.Draw(image)
    for i, bbox in enumerate(bboxes):
        xmin, ymin, xmax, ymax = bbox
        (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
        # Closed polyline through the four corners, ending at the start point.
        draw.line(
            [(left, top), (left, bottom), (right, bottom), (right, top),
             (left, top)],
            width=4,
            fill='red')
        # Labels are only written on RGB images; the text fill below is an
        # RGB triple.
        if labels and image.mode == 'RGB':
            draw.text((left, top), labels[i], (255, 255, 0))
    # basename handles the platform's path separators, not only '/'.
    output_file = os.path.basename(image_file)
    if output_dir:
        # Saving into a directory that does not exist fails, so create it.
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        output_file = os.path.join(output_dir, output_file)
    print("The image with bbox is saved as {}".format(output_file))
    image.save(output_file)
......@@ -4,68 +4,131 @@ import numpy as np
import argparse
import functools
from PIL import Image
from PIL import ImageDraw
import paddle
import paddle.fluid as fluid
import reader
from pyramidbox import PyramidBox
from utility import add_arguments, print_arguments
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('confs_threshold', float, 0.25, "Confidence threshold to draw bbox.")
add_arg('image_path', str, '', "The data root path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('use_gpu', bool, True, "Whether use GPU or not.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('data_dir', str, 'data/WIDER_val/images/', "The validation dataset path.")
add_arg('model_dir', str, '', "The model path.")
add_arg('pred_dir', str, 'pred', "The path to save the evaluation results.")
add_arg('file_list', str, 'data/wider_face_split/wider_face_val_bbx_gt.txt', "The validation dataset path.")
# yapf: enable
def draw_bounding_box_on_image(image_path, nms_out, confs_threshold):
    """
    Outline in red every detection that is not below the score threshold and
    save the image under ./infer_results/<class directory>/<file name>.
    Args:
        image_path (string): '/'-separated image path; its last two
            components are used as class directory and file name.
        nms_out (iterable): detections, each (xmin, ymin, xmax, ymax, score).
        confs_threshold (float): detections scoring below this are skipped.
    """
    image = Image.open(image_path)
    painter = ImageDraw.Draw(image)
    for xmin, ymin, xmax, ymax, score in nms_out:
        if score < confs_threshold:
            continue
        # Closed polyline through the four corners of the box.
        corners = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                   (xmin, ymin)]
        painter.line(corners, width=4, fill='red')
    path_parts = image_path.split('/')
    image_name = path_parts[-1]
    image_class = path_parts[-2]
    print("image with bbox drawed saved as {}".format(image_name))
    target = ('./infer_results/' + image_class.encode('utf-8') + '/' +
              image_name.encode('utf-8'))
    image.save(target)
def infer(args, config):
batch_size = 1
model_dir = args.model_dir
data_dir = args.data_dir
file_list = args.file_list
pred_dir = args.pred_dir
if not os.path.exists(model_dir):
raise ValueError("The model path [%s] does not exist." % (model_dir))
test_reader = reader.test(config, file_list)
for image, image_path in test_reader():
shrink, max_shrink = get_shrink(image.size[1], image.size[0])
det0 = detect_face(image, shrink)
det1 = flip_test(image, shrink)
[det2, det3] = multi_scale_test(image, max_shrink)
det4 = multi_scale_test_pyramid(image, max_shrink)
det = np.row_stack((det0, det1, det2, det3, det4))
dets = bbox_vote(det)
def write_to_txt(image_path, f, nms_out):
save_widerface_bboxes(image_path, dets, pred_dir)
print("Finish evaluation.")
def save_widerface_bboxes(image_path, bboxes_scores, output_dir):
"""
Save predicted results, including bbox and score into text file.
Args:
image_path (string): file name.
bboxes_scores (np.array|list): the predicted bboxed and scores, layout
is (xmin, ymin, xmax, ymax, score)
output_dir (string): output directory.
"""
image_name = image_path.split('/')[-1]
image_class = image_path.split('/')[-2]
f.write('{:s}\n'.format(
image_class.encode('utf-8') + '/' + image_name.encode('utf-8')))
f.write('{:d}\n'.format(nms_out.shape[0]))
for dt in nms_out:
xmin, ymin, xmax, ymax, score = dt
image_name = image_name.encode('utf-8')
image_class = image_class.encode('utf-8')
odir = os.path.join(output_dir, image_class)
if not os.path.exists(odir):
os.makedirs(odir)
ofname = os.path.join(odir, '%s.txt' % (image_name[:-4]))
f = open(ofname, 'w')
f.write('{:s}\n'.format(image_class + '/' + image_name))
f.write('{:d}\n'.format(bboxes_scores.shape[0]))
for box_score in bboxes_scores:
xmin, ymin, xmax, ymax, score = box_score
f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(xmin, ymin, (
xmax - xmin + 1), (ymax - ymin + 1), score))
print("image infer result saved {}".format(image_name[:-4]))
f.close()
print("The predicted result is saved as {}".format(ofname))
def get_round(x, loc):
    """
    Truncate (not round) `x` to `loc` decimal places via its str() form.

    The value is shortened only when its string form has at least three
    characters after the decimal point; otherwise `x` is returned unchanged.
    Args:
        x (float|int): the value to truncate.
        loc (int): number of digits to keep after the decimal point.
    Returns:
        float when digits were dropped, otherwise `x` itself.
    """
    str_before, point, str_after = str(x).partition('.')
    # Without a decimal point (e.g. an int) there is nothing to truncate.
    # The previous code fell off the end here and returned None, which broke
    # arithmetic on the result.
    if not point or len(str_after) < 3:
        return x
    return float(str_before + '.' + str_after[0:loc])
def detect_face(image, shrink):
    """
    Run the detector once on `image` resized by the factor `shrink`.

    Reads the module-level `args` for use_gpu, use_pyramidbox and model_dir.
    Args:
        image (PIL.Image): input image; image.size is read as (width, height).
        shrink (float): resize factor applied before detection; 1 keeps the
            original size.
    Returns:
        np.array: one row per detection with columns
            (xmin, ymin, xmax, ymax, score). A single all-zero row is
            returned when the network output has shape (1,).
    """
    image_shape = [3, image.size[1], image.size[0]]
    if shrink != 1:
        h, w = int(image_shape[1] * shrink), int(image_shape[2] * shrink)
        image = image.resize((w, h), Image.ANTIALIAS)
        image_shape = [3, h, w]

    # Preprocess: HWC/RGB -> CHW/BGR, subtract the channel mean, scale, and
    # add a leading batch axis of length 1.
    img = np.array(image)
    img = reader.to_chw_bgr(img)
    mean = [104., 117., 123.]
    scale = 0.007843
    img = img.astype('float32')
    img -= np.array(mean)[:, np.newaxis, np.newaxis].astype('float32')
    img = img * scale
    img = [img]
    img = np.array(img)

    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    main_program = fluid.Program()
    startup_program = fluid.Program()
    # The network is rebuilt for this call because its input shape depends
    # on the (possibly resized) image.
    with fluid.unique_name.guard():
        with fluid.program_guard(main_program, startup_program):
            network = PyramidBox(
                image_shape, sub_network=args.use_pyramidbox, is_infer=True)
            infer_program, nmsed_out = network.infer(main_program)
            fetches = [nmsed_out]
            fluid.io.load_persistables(
                exe, args.model_dir, main_program=main_program)
            detection, = exe.run(infer_program,
                                 feed={'image': img},
                                 fetch_list=fetches,
                                 return_numpy=False)
            detection = np.array(detection)
    # layout: xmin, ymin, xmax, ymax, score
    if detection.shape == (1, ):
        print("No face detected")
        return np.array([[0, 0, 0, 0, 0]])
    # Column 1 is the score and columns 2..5 the box corners. They are
    # multiplied by the resized width/height and divided by `shrink`
    # (presumably normalized network outputs mapped back to original-image
    # pixels -- TODO confirm).
    det_conf = detection[:, 1]
    det_xmin = image_shape[2] * detection[:, 2] / shrink
    det_ymin = image_shape[1] * detection[:, 3] / shrink
    det_xmax = image_shape[2] * detection[:, 4] / shrink
    det_ymax = image_shape[1] * detection[:, 5] / shrink
    det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
    # `keep_index` was used below without ever being defined, which raised
    # NameError; keep the rows whose score is non-negative.
    keep_index = np.where(det[:, 4] >= 0)[0]
    det = det[keep_index, :]
    return det
def bbox_vote(det):
......@@ -86,7 +149,7 @@ def bbox_vote(det):
inter = w * h
o = inter / (area[0] + area[:] - inter)
# get needed merge det and delete these det
# nms
merge_index = np.where(o >= 0.3)[0]
det_accu = det[merge_index, :]
det = np.delete(det, merge_index, 0)
......@@ -111,78 +174,6 @@ def bbox_vote(det):
return dets
def image_preprocess(image):
    """
    Convert an image into a normalized float32 batch containing one sample.

    Steps: HWC -> CHW, reverse the three channels, subtract the per-channel
    mean (104, 117, 123), multiply by 0.007843 and add a leading batch axis.
    Args:
        image: any object np.array() accepts, laid out as HWC.
    Returns:
        np.array: float32 array whose leading axis has length 1.
    """
    pixels = np.array(image)
    if pixels.ndim == 3:
        # (H, W, C) -> (C, H, W)
        pixels = np.transpose(pixels, (2, 0, 1))
    # Channels 2, 1, 0 along the leading axis, then cast to float32.
    pixels = pixels[[2, 1, 0], :, :].astype('float32')
    channel_mean = np.array([104., 117., 123.]).astype('float32')
    pixels -= channel_mean[:, np.newaxis, np.newaxis]
    pixels = pixels * 0.007843
    return np.array([pixels])
# Run the detector once on `image` resized by the factor `shrink` and return
# an array with one row per detection: (xmin, ymin, xmax, ymax, score).
# Reads the module-level `args` for use_gpu, use_pyramidbox and model_dir.
def detect_face(image, shrink):
# [channels, height, width]; image.size[1] is used as height and
# image.size[0] as width.
image_shape = [3, image.size[1], image.size[0]]
num_classes = 2
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# Resize the image and keep image_shape in step with the new size.
if shrink != 1:
image = image.resize((int(image_shape[2] * shrink),
int(image_shape[1] * shrink)), Image.ANTIALIAS)
image_shape = [
image_shape[0], int(image_shape[1] * shrink),
int(image_shape[2] * shrink)
]
print "image_shape:", image_shape
img = image_preprocess(image)
# A new scope and new programs are created on every call; the network is
# built with this call's image_shape.
scope = fluid.core.Scope()
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.scope_guard(scope):
with fluid.unique_name.guard():
with fluid.program_guard(main_program, startup_program):
fetches = []
network = PyramidBox(
image_shape,
num_classes,
sub_network=args.use_pyramidbox,
is_infer=True)
infer_program, nmsed_out = network.infer(main_program)
fetches = [nmsed_out]
fluid.io.load_persistables(
exe, args.model_dir, main_program=main_program)
detection, = exe.run(infer_program,
feed={'image': img},
fetch_list=fetches,
return_numpy=False)
detection = np.array(detection)
# layout: xmin, ymin, xmax, ymax, score
# An output of shape (1,) is treated as "no detection" and replaced by a
# single all-zero row.
if detection.shape == (1, ):
print("No face detected")
return np.array([[0, 0, 0, 0, 0]])
# Column 1 is the score and columns 2..5 the box corners; they are multiplied
# by the resized width/height and divided by `shrink` (presumably normalized
# outputs mapped back to original-image pixels -- TODO confirm).
det_conf = detection[:, 1]
det_xmin = image_shape[2] * detection[:, 2] / shrink
det_ymin = image_shape[1] * detection[:, 3] / shrink
det_xmax = image_shape[2] * detection[:, 4] / shrink
det_ymax = image_shape[1] * detection[:, 5] / shrink
det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
# Keep only the rows whose score is non-negative.
keep_index = np.where(det[:, 4] >= 0)[0]
det = det[keep_index, :]
return det
def flip_test(image, shrink):
img = image.transpose(Image.FLIP_LEFT_RIGHT)
det_f = detect_face(img, shrink)
......@@ -197,18 +188,18 @@ def flip_test(image, shrink):
def multi_scale_test(image, max_shrink):
# shrink detecting and shrink only detect big face
# Shrink detecting is only used to detect big faces
st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink
det_s = detect_face(image, st)
index = np.where(
np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1)
> 30)[0]
det_s = det_s[index, :]
# enlarge one times
# Enlarge one times
bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2
det_b = detect_face(image, bt)
# enlarge small image x times for small face
# Enlarge small image x times for small faces
if max_shrink > 2:
bt *= 2
while bt < max_shrink:
......@@ -216,12 +207,13 @@ def multi_scale_test(image, max_shrink):
bt *= 2
det_b = np.row_stack((det_b, detect_face(image, max_shrink)))
# enlarge only detect small face
# Enlarged images are only used to detect small faces.
if bt > 1:
index = np.where(
np.minimum(det_b[:, 2] - det_b[:, 0] + 1,
det_b[:, 3] - det_b[:, 1] + 1) < 100)[0]
det_b = det_b[index, :]
# Shrinked images are only used to detect big faces.
else:
index = np.where(
np.maximum(det_b[:, 2] - det_b[:, 0] + 1,
......@@ -231,23 +223,24 @@ def multi_scale_test(image, max_shrink):
def multi_scale_test_pyramid(image, max_shrink):
# shrink detecting and shrink only detect big face
# Use image pyramids to detect faces
det_b = detect_face(image, 0.25)
index = np.where(
np.maximum(det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1)
> 30)[0]
det_b = det_b[index, :]
st = [0.5, 0.75, 1.25, 1.5, 1.75, 2.25]
st = [0.75, 1.25, 1.5, 1.75]
for i in range(len(st)):
if (st[i] <= max_shrink):
det_temp = detect_face(image, st[i])
# enlarge only detect small face
# Enlarged images are only used to detect small faces.
if st[i] > 1:
index = np.where(
np.minimum(det_temp[:, 2] - det_temp[:, 0] + 1,
det_temp[:, 3] - det_temp[:, 1] + 1) < 100)[0]
det_temp = det_temp[index, :]
# Shrinked images are only used to detect big faces.
else:
index = np.where(
np.maximum(det_temp[:, 2] - det_temp[:, 0] + 1,
......@@ -257,13 +250,28 @@ def multi_scale_test_pyramid(image, max_shrink):
return det_b
def get_im_shrink(image_shape):
max_shrink_v1 = (0x7fffffff / 577.0 /
(image_shape[1] * image_shape[2]))**0.5
max_shrink_v2 = (
(678 * 1024 * 2.0 * 2.0) / (image_shape[1] * image_shape[2]))**0.5
max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3
def get_shrink(height, width):
"""
Args:
height (int): image height.
width (int): image width.
"""
# avoid out of memory
max_shrink_v1 = (0x7fffffff / 577.0 / (height * width))**0.5
max_shrink_v2 = ((678 * 1024 * 2.0 * 2.0) / (height * width))**0.5
def get_round(x, loc):
str_x = str(x)
if '.' in str_x:
str_before, str_after = str_x.split('.')
len_after = len(str_after)
if len_after >= 3:
str_final = str_before + '.' + str_after[0:loc]
return float(str_final)
else:
return x
max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3
if max_shrink >= 1.5 and max_shrink < 2:
max_shrink = max_shrink - 0.1
elif max_shrink >= 2 and max_shrink < 3:
......@@ -275,60 +283,12 @@ def get_im_shrink(image_shape):
elif max_shrink >= 5:
max_shrink = max_shrink - 0.5
print 'max_shrink = ', max_shrink
shrink = max_shrink if max_shrink < 1 else 1
print "shrink = ", shrink
return shrink, max_shrink
# Run the detection passes on every image yielded by reader.test and write
# one result text file per image under ./infer_results/<class directory>/.
# Raises ValueError when args.model_dir does not exist.
# NOTE(review): `file_list` is not a parameter; it is read from module scope.
def infer(args, batch_size, data_args):
if not os.path.exists(args.model_dir):
raise ValueError("The model path [%s] does not exist." %
(args.model_dir))
infer_reader = paddle.batch(
reader.test(data_args, file_list), batch_size=batch_size)
for batch_id, img in enumerate(infer_reader()):
# Only the first sample of each batch is read: (image, image_path).
image = img[0][0]
image_path = img[0][1]
# image.size: [width, height]
image_shape = [3, image.size[1], image.size[0]]
shrink, max_shrink = get_im_shrink(image_shape)
# Detections from the base pass, the flipped pass, the multi-scale passes
# and the pyramid pass are stacked row-wise and merged by bbox_vote.
det0 = detect_face(image, shrink)
det1 = flip_test(image, shrink)
[det2, det3] = multi_scale_test(image, max_shrink)
det4 = multi_scale_test_pyramid(image, max_shrink)
det = np.row_stack((det0, det1, det2, det3, det4))
dets = bbox_vote(det)
# The last two '/'-separated path components are used as file name and
# class directory.
image_name = image_path.split('/')[-1]
image_class = image_path.split('/')[-2]
if not os.path.exists('./infer_results/' + image_class.encode('utf-8')):
os.makedirs('./infer_results/' + image_class.encode('utf-8'))
# The output name drops the last four characters of the image file name
# and appends '.txt'.
# NOTE(review): `f` is not closed in this function -- confirm whether
# write_to_txt closes it.
f = open('./infer_results/' + image_class.encode('utf-8') + '/' +
image_name.encode('utf-8')[:-4] + '.txt', 'w')
write_to_txt(image_path, f, dets)
# draw_bounding_box_on_image(image_path, dets, args.confs_threshold)
print "Done"
if __name__ == '__main__':
args = parser.parse_args()
print_arguments(args)
data_dir = 'data/WIDERFACE/WIDER_val/images/'
file_list = 'label/val_gt_widerface.res'
data_args = reader.Settings(
data_dir=data_dir,
mean_value=[104., 117., 123],
apply_distort=False,
apply_expand=False,
ap_version='11point')
infer(args, batch_size=1, data_args=data_args)
config = reader.Settings(data_dir=args.data_dir)
infer(args, config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册