Unverified commit 3f1fddbd authored by qingqing01, committed by GitHub

Fix bug in data anchor sampling. (#1010)

* Fix bug in data anchor sampling.
* Add argument for memory optimization.
* Change arguments in detection_output.
Parent 96b2eda5
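
The main bug addressed here is in `data_anchor_sampling`: each `bbox_labels` row stores the class label at index 0 and the normalized `xmin`, `ymin`, `xmax`, `ymax` at indices 1-4 (the same layout the new `transform_labels_sampling` helper assumes), so the old reads at indices 0-3 took the class label as `xmin` and never read `ymax`. Below is a minimal sketch of the corrected indexing; it is not part of the commit, the row layout is taken from this diff, and the sample values are illustrative only.

```python
# Illustrative sketch of the index fix in data_anchor_sampling (not from the
# commit). A bbox_labels row is assumed to be laid out as
# [class_label, xmin, ymin, xmax, ymax, ...] with coordinates normalized to [0, 1].
import numpy as np

bbox_labels = [[1.0, 0.10, 0.20, 0.30, 0.40]]  # one ground-truth face (made-up values)
image_width, image_height = 640, 480

rand_idx = np.random.randint(0, len(bbox_labels))

# Old (buggy) reads: index 0 is the class label, not xmin, and ymax is never read.
# norm_xmin, norm_ymin, norm_xmax, norm_ymax = bbox_labels[rand_idx][0:4]

# Fixed reads: the coordinates live at indices 1..4.
norm_xmin, norm_ymin, norm_xmax, norm_ymax = bbox_labels[rand_idx][1:5]

# Convert back to pixel coordinates, as the sampler does next.
xmin = norm_xmin * image_width
ymin = norm_ymin * image_height
xmax = norm_xmax * image_width
ymax = norm_ymax * image_height
print(xmin, ymin, xmax, ymax)
```

The other two changes are mechanical: a `with_mem_opt` flag gates the `fluid.memory_optimize` call during training, and `PyramidBox.infer` now passes explicit `nms_threshold`, `nms_top_k`, `keep_top_k`, and `score_threshold` values to the detection output.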
@@ -4,4 +4,6 @@ data/
label/
*.swp
*.log
infer_results/
log*
output*
infer_results*
@@ -3,6 +3,7 @@ from PIL import ImageFile
import numpy as np
import random
import math
import cv2
ImageFile.LOAD_TRUNCATED_IMAGES = True #otherwise IOError raised image file is truncated
@@ -107,10 +108,10 @@ def data_anchor_sampling(sampler, bbox_labels, image_width, image_height,
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
if num_gt != 0:
norm_xmin = bbox_labels[rand_idx][0]
norm_ymin = bbox_labels[rand_idx][1]
norm_xmax = bbox_labels[rand_idx][2]
norm_ymax = bbox_labels[rand_idx][3]
norm_xmin = bbox_labels[rand_idx][1]
norm_ymin = bbox_labels[rand_idx][2]
norm_xmax = bbox_labels[rand_idx][3]
norm_ymax = bbox_labels[rand_idx][4]
xmin = norm_xmin * image_width
ymin = norm_ymin * image_height
@@ -321,7 +322,34 @@ def transform_labels(bbox_labels, sample_bbox):
return sample_labels
def crop_image(img, bbox_labels, sample_bbox, image_width, image_height):
def transform_labels_sampling(bbox_labels, sample_bbox, resize_val,
min_face_size):
sample_labels = []
for i in range(len(bbox_labels)):
sample_label = []
object_bbox = bbox(bbox_labels[i][1], bbox_labels[i][2],
bbox_labels[i][3], bbox_labels[i][4])
if not meet_emit_constraint(object_bbox, sample_bbox):
continue
proj_bbox = project_bbox(object_bbox, sample_bbox)
if proj_bbox:
real_width = float((proj_bbox.xmax - proj_bbox.xmin) * resize_val)
real_height = float((proj_bbox.ymax - proj_bbox.ymin) * resize_val)
if real_width * real_height < float(min_face_size * min_face_size):
continue
else:
sample_label.append(bbox_labels[i][0])
sample_label.append(float(proj_bbox.xmin))
sample_label.append(float(proj_bbox.ymin))
sample_label.append(float(proj_bbox.xmax))
sample_label.append(float(proj_bbox.ymax))
sample_label = sample_label + bbox_labels[i][5:]
sample_labels.append(sample_label)
return sample_labels
def crop_image(img, bbox_labels, sample_bbox, image_width, image_height,
resize_width, resize_height, min_face_size):
sample_bbox = clip_bbox(sample_bbox)
xmin = int(sample_bbox.xmin * image_width)
xmax = int(sample_bbox.xmax * image_width)
@@ -329,12 +357,15 @@ def crop_image(img, bbox_labels, sample_bbox, image_width, image_height):
ymax = int(sample_bbox.ymax * image_height)
sample_img = img[ymin:ymax, xmin:xmax]
sample_labels = transform_labels(bbox_labels, sample_bbox)
resize_val = resize_width
sample_labels = transform_labels_sampling(bbox_labels, sample_bbox,
resize_val, min_face_size)
return sample_img, sample_labels
def crop_image_sampling(img, bbox_labels, sample_bbox, image_width,
image_height, resize_width, resize_height):
image_height, resize_width, resize_height,
min_face_size):
# no clipping here
xmin = int(sample_bbox.xmin * image_width)
xmax = int(sample_bbox.xmax * image_width)
@@ -358,14 +389,16 @@ def crop_image_sampling(img, bbox_labels, sample_bbox, image_width,
roi_width = cross_width
roi_height = cross_height
sample_img = np.zeros((width, height, 3))
sample_img[roi_xmin : roi_xmin + roi_width, roi_ymin : roi_ymin + roi_height] = \
img[cross_xmin : cross_xmin + cross_width, cross_ymin : cross_ymin + cross_height]
sample_img = np.zeros((height, width, 3))
sample_img[int(roi_ymin) : int(roi_ymin + roi_height), int(roi_xmin) : int(roi_xmin + roi_width)] = \
img[int(cross_ymin) : int(cross_ymin + cross_height), int(cross_xmin) : int(cross_xmin + cross_width)]
sample_img = cv2.resize(
sample_img, (resize_width, resize_height), interpolation=cv2.INTER_AREA)
sample_labels = transform_labels(bbox_labels, sample_bbox)
resize_val = resize_width
sample_labels = transform_labels_sampling(bbox_labels, sample_bbox,
resize_val, min_face_size)
return sample_img, sample_labels
@@ -385,6 +385,7 @@ class PyramidBox(object):
self.box_vars,
overlap_threshold=0.35,
neg_overlap=0.35)
face_loss.persistable = True
head_loss = fluid.layers.ssd_loss(
self.head_mbox_loc,
self.head_mbox_conf,
@@ -394,9 +395,13 @@ class PyramidBox(object):
self.box_vars,
overlap_threshold=0.35,
neg_overlap=0.35)
head_loss.persistable = True
face_loss = fluid.layers.reduce_sum(face_loss)
face_loss.persistable = True
head_loss = fluid.layers.reduce_sum(head_loss)
head_loss.persistable = True
total_loss = face_loss + head_loss
total_loss.persistable = True
return face_loss, head_loss, total_loss
def infer(self, main_program=None):
@@ -410,5 +415,8 @@ class PyramidBox(object):
self.face_mbox_conf,
self.prior_boxes,
self.box_vars,
nms_threshold=0.45)
nms_threshold=0.3,
nms_top_k=5000,
keep_top_k=750,
score_threshold=0.05)
return test_program, face_nmsed_out
@@ -23,6 +23,7 @@ import os
import time
import copy
import random
import cv2
class Settings(object):
@@ -61,9 +62,29 @@ class Settings(object):
self.brightness_delta = 0.125
self.scale = 0.007843 # 1 / 127.5
self.data_anchor_sampling_prob = 0.5
self.min_face_size = 8.0
def preprocess(img, bbox_labels, mode, settings):
def draw_image(faces_pred, img, resize_val):
for i in range(len(faces_pred)):
draw_rotate_rectange(img, faces_pred[i], resize_val, (0, 255, 0), 3)
def draw_rotate_rectange(img, face, resize_val, color, thickness):
cv2.line(img, (int(face[1] * resize_val), int(face[2] * resize_val)), (int(
face[3] * resize_val), int(face[2] * resize_val)), color, thickness)
cv2.line(img, (int(face[3] * resize_val), int(face[2] * resize_val)), (int(
face[3] * resize_val), int(face[4] * resize_val)), color, thickness)
cv2.line(img, (int(face[1] * resize_val), int(face[2] * resize_val)), (int(
face[1] * resize_val), int(face[4] * resize_val)), color, thickness)
cv2.line(img, (int(face[3] * resize_val), int(face[4] * resize_val)), (int(
face[1] * resize_val), int(face[4] * resize_val)), color, thickness)
def preprocess(img, bbox_labels, mode, settings, image_path):
img_width, img_height = img.size
sampled_labels = bbox_labels
if mode == 'train':
@@ -86,13 +107,28 @@ def preprocess(img, bbox_labels, mode, settings):
batch_sampler, bbox_labels, img_width, img_height, scale_array,
settings.resize_width, settings.resize_height)
img = np.array(img)
# Debug
# img_save = Image.fromarray(img)
# img_save.save('img_orig.jpg')
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sampled_labels = image_util.crop_image_sampling(
img, bbox_labels, sampled_bbox[idx], img_width, img_height,
resize_width, resize_heigh)
settings.resize_width, settings.resize_height,
settings.min_face_size)
img = img.astype('uint8')
# Debug: visualize the gt bbox
visualize_bbox = 0
if visualize_bbox:
img_show = img
draw_image(sampled_labels, img_show, settings.resize_height)
img_show = Image.fromarray(img_show)
img_show.save('final_img_show.jpg')
img = Image.fromarray(img)
# Debug
# img.save('final_img.jpg')
else:
# hard-code here
@@ -118,7 +154,9 @@ def preprocess(img, bbox_labels, mode, settings):
if len(sampled_bbox) > 0:
idx = int(random.uniform(0, len(sampled_bbox)))
img, sampled_labels = image_util.crop_image(
img, bbox_labels, sampled_bbox[idx], img_width, img_height)
img, bbox_labels, sampled_bbox[idx], img_width, img_height,
settings.resize_width, settings.resize_height,
settings.min_face_size)
img = Image.fromarray(img)
@@ -240,7 +278,8 @@ def pyramidbox(settings, file_list, mode, shuffle):
bbox_sample.append(float(ymax) / im_height)
bbox_labels.append(bbox_sample)
im, sample_labels = preprocess(im, bbox_labels, mode, settings)
im, sample_labels = preprocess(im, bbox_labels, mode, settings,
image_path)
sample_labels = np.array(sample_labels)
if len(sample_labels) == 0: continue
im = im.astype('float32')
@@ -18,13 +18,14 @@ add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('parallel', bool, True, "parallel")
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('batch_size', int, 12, "Minibatch size.")
add_arg('num_passes', int, 120, "Epoch number.")
add_arg('num_passes', int, 160, "Epoch number.")
add_arg('use_gpu', bool, True, "Whether use GPU.")
add_arg('use_pyramidbox', bool, True, "Whether use PyramidBox model.")
add_arg('model_save_dir', str, 'output', "The path to save model.")
add_arg('pretrained_model', str, './pretrained/', "The init model path.")
add_arg('resize_h', int, 640, "The resized image height.")
add_arg('resize_w', int, 640, "The resized image width.")
add_arg('with_mem_opt', bool, False, "Whether to use memory optimization or not.")
#yapf: enable
@@ -38,6 +39,7 @@ def train(args, config, train_file_list, optimizer_method):
use_pyramidbox = args.use_pyramidbox
model_save_dir = args.model_save_dir
pretrained_model = args.pretrained_model
with_memory_optimization = args.with_mem_opt
num_classes = 2
image_shape = [3, height, width]
@@ -57,7 +59,7 @@ def train(args, config, train_file_list, optimizer_method):
fetches = [loss]
epocs = 12880 / batch_size
boundaries = [epocs * 40, epocs * 60, epocs * 80, epocs * 100]
boundaries = [epocs * 50, epocs * 80, epocs * 120, epocs * 140]
values = [
learning_rate, learning_rate * 0.5, learning_rate * 0.25,
learning_rate * 0.1, learning_rate * 0.01
@@ -77,7 +79,8 @@ def train(args, config, train_file_list, optimizer_method):
)
optimizer.minimize(loss)
#fluid.memory_optimize(fluid.default_main_program())
if with_memory_optimization:
fluid.memory_optimize(fluid.default_main_program())
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
@@ -155,7 +158,8 @@ if __name__ == '__main__':
data_dir=data_dir,
resize_h=args.resize_h,
resize_w=args.resize_w,
apply_distort=True,
apply_expand=False,
mean_value=[104., 117., 123],
mean_value=[104., 117., 123.],
ap_version='11point')
train(args, config, train_file_list, optimizer_method="momentum")