提交 30489edb 编写于 作者: G Guanghua Yu 提交者: qingqing01

Add data anchor sampling for face detection. (#3450)

* add CropImageWithDataAnchorSampling
* Refine bbox sampling code
* Update comments
* fix import and variable problem
上级 ebe2fc68
......@@ -18,6 +18,9 @@ from __future__ import division
from __future__ import print_function
import numpy as np
import random
import math
import cv2
def meet_emit_constraint(src_bbox, sample_bbox):
......@@ -40,20 +43,35 @@ def clip_bbox(src_bbox):
def bbox_area(src_bbox):
width = src_bbox[2] - src_bbox[0]
height = src_bbox[3] - src_bbox[1]
return width * height
if src_bbox[2] < src_bbox[0] or src_bbox[3] < src_bbox[1]:
return 0.
else:
width = src_bbox[2] - src_bbox[0]
height = src_bbox[3] - src_bbox[1]
return width * height
def is_overlap(object_bbox, sample_bbox):
if object_bbox[0] >= sample_bbox[2] or \
object_bbox[2] <= sample_bbox[0] or \
object_bbox[1] >= sample_bbox[3] or \
object_bbox[3] <= sample_bbox[1]:
return False
else:
return True
def filter_and_process(sample_bbox, bboxes, labels, scores=None):
new_bboxes = []
new_labels = []
new_scores = []
for i in range(len(labels)):
for i in range(len(bboxes)):
new_bbox = [0, 0, 0, 0]
obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
if not meet_emit_constraint(obj_bbox, sample_bbox):
continue
if not is_overlap(obj_bbox, sample_bbox):
continue
sample_width = sample_bbox[2] - sample_bbox[0]
sample_height = sample_bbox[3] - sample_bbox[1]
new_bbox[0] = (obj_bbox[0] - sample_bbox[0]) / sample_width
......@@ -72,6 +90,25 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None):
return bboxes, labels, scores
def bbox_area_sampling(bboxes, labels, scores, target_size, min_size):
new_bboxes = []
new_labels = []
new_scores = []
for i, bbox in enumerate(bboxes):
w = float((bbox[2] - bbox[0]) * target_size)
h = float((bbox[3] - bbox[1]) * target_size)
if w * h < float(min_size * min_size):
continue
else:
new_bboxes.append(bbox)
new_labels.append(labels[i])
new_scores.append(scores[i])
bboxes = np.array(new_bboxes)
labels = np.array(new_labels)
scores = np.array(new_scores)
return bboxes, labels, scores
def generate_sample_bbox(sampler):
scale = np.random.uniform(sampler[2], sampler[3])
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
......@@ -89,6 +126,112 @@ def generate_sample_bbox(sampler):
return sampled_bbox
def generate_sample_bbox_square(sampler, image_width, image_height):
scale = np.random.uniform(sampler[2], sampler[3])
aspect_ratio = np.random.uniform(sampler[4], sampler[5])
aspect_ratio = max(aspect_ratio, (scale**2.0))
aspect_ratio = min(aspect_ratio, 1 / (scale**2.0))
bbox_width = scale * (aspect_ratio**0.5)
bbox_height = scale / (aspect_ratio**0.5)
if image_height < image_width:
bbox_width = bbox_height * image_height / image_width
else:
bbox_height = bbox_width * image_width / image_height
xmin_bound = 1 - bbox_width
ymin_bound = 1 - bbox_height
xmin = np.random.uniform(0, xmin_bound)
ymin = np.random.uniform(0, ymin_bound)
xmax = xmin + bbox_width
ymax = ymin + bbox_height
sampled_bbox = [xmin, ymin, xmax, ymax]
return sampled_bbox
def data_anchor_sampling(bbox_labels, image_width, image_height, scale_array,
resize_width):
num_gt = len(bbox_labels)
# np.random.randint range: [low, high)
rand_idx = np.random.randint(0, num_gt) if num_gt != 0 else 0
if num_gt != 0:
norm_xmin = bbox_labels[rand_idx][0]
norm_ymin = bbox_labels[rand_idx][1]
norm_xmax = bbox_labels[rand_idx][2]
norm_ymax = bbox_labels[rand_idx][3]
xmin = norm_xmin * image_width
ymin = norm_ymin * image_height
wid = image_width * (norm_xmax - norm_xmin)
hei = image_height * (norm_ymax - norm_ymin)
range_size = 0
area = wid * hei
for scale_ind in range(0, len(scale_array) - 1):
if area > scale_array[scale_ind] ** 2 and area < \
scale_array[scale_ind + 1] ** 2:
range_size = scale_ind + 1
break
if area > scale_array[len(scale_array) - 2]**2:
range_size = len(scale_array) - 2
scale_choose = 0.0
if range_size == 0:
rand_idx_size = 0
else:
# np.random.randint range: [low, high)
rng_rand_size = np.random.randint(0, range_size + 1)
rand_idx_size = rng_rand_size % (range_size + 1)
if rand_idx_size == range_size:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = min(2.0 * scale_array[rand_idx_size],
2 * math.sqrt(wid * hei))
scale_choose = random.uniform(min_resize_val, max_resize_val)
else:
min_resize_val = scale_array[rand_idx_size] / 2.0
max_resize_val = 2.0 * scale_array[rand_idx_size]
scale_choose = random.uniform(min_resize_val, max_resize_val)
sample_bbox_size = wid * resize_width / scale_choose
w_off_orig = 0.0
h_off_orig = 0.0
if sample_bbox_size < max(image_height, image_width):
if wid <= sample_bbox_size:
w_off_orig = np.random.uniform(xmin + wid - sample_bbox_size,
xmin)
else:
w_off_orig = np.random.uniform(xmin,
xmin + wid - sample_bbox_size)
if hei <= sample_bbox_size:
h_off_orig = np.random.uniform(ymin + hei - sample_bbox_size,
ymin)
else:
h_off_orig = np.random.uniform(ymin,
ymin + hei - sample_bbox_size)
else:
w_off_orig = np.random.uniform(image_width - sample_bbox_size, 0.0)
h_off_orig = np.random.uniform(image_height - sample_bbox_size, 0.0)
w_off_orig = math.floor(w_off_orig)
h_off_orig = math.floor(h_off_orig)
# Figure out top left coordinates.
w_off = float(w_off_orig / image_width)
h_off = float(h_off_orig / image_height)
sampled_bbox = [
w_off, h_off, w_off + float(sample_bbox_size / image_width),
h_off + float(sample_bbox_size / image_height)
]
return sampled_bbox
else:
return 0
def jaccard_overlap(sample_bbox, object_bbox):
if sample_bbox[0] >= object_bbox[2] or \
sample_bbox[2] <= object_bbox[0] or \
......@@ -108,6 +251,29 @@ def jaccard_overlap(sample_bbox, object_bbox):
return overlap
def intersect_bbox(bbox1, bbox2):
if bbox2[0] > bbox1[2] or bbox2[2] < bbox1[0] or \
bbox2[1] > bbox1[3] or bbox2[3] < bbox1[1]:
intersection_box = [0.0, 0.0, 0.0, 0.0]
else:
intersection_box = [
max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]),
min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
]
return intersection_box
def bbox_coverage(bbox1, bbox2):
inter_box = intersect_bbox(bbox1, bbox2)
intersect_size = bbox_area(inter_box)
if intersect_size > 0:
bbox1_size = bbox_area(bbox1)
return intersect_size / bbox1_size
else:
return 0.
def satisfy_sample_constraint(sampler,
sample_bbox,
gt_bboxes,
......@@ -136,3 +302,87 @@ def satisfy_sample_constraint(sampler,
return np.all(satisfied)
else:
return False
def satisfy_sample_constraint_coverage(sampler, sample_bbox, gt_bboxes):
if sampler[6] == 0 and sampler[7] == 0:
has_jaccard_overlap = False
else:
has_jaccard_overlap = True
if sampler[8] == 0 and sampler[9] == 0:
has_object_coverage = False
else:
has_object_coverage = True
if not has_jaccard_overlap and not has_object_coverage:
return True
found = False
for i in range(len(gt_bboxes)):
object_bbox = [
gt_bboxes[i][0], gt_bboxes[i][1], gt_bboxes[i][2], gt_bboxes[i][3]
]
if has_jaccard_overlap:
overlap = jaccard_overlap(sample_bbox, object_bbox)
if sampler[6] != 0 and \
overlap < sampler[6]:
continue
if sampler[7] != 0 and \
overlap > sampler[7]:
continue
found = True
if has_object_coverage:
object_coverage = bbox_coverage(object_bbox, sample_bbox)
if sampler[8] != 0 and \
object_coverage < sampler[8]:
continue
if sampler[9] != 0 and \
object_coverage > sampler[9]:
continue
found = True
if found:
return True
return found
def crop_image_sampling(img, sample_bbox, image_width, image_height,
target_size):
# no clipping here
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
w_off = xmin
h_off = ymin
width = xmax - xmin
height = ymax - ymin
cross_xmin = max(0.0, float(w_off))
cross_ymin = max(0.0, float(h_off))
cross_xmax = min(float(w_off + width - 1.0), float(image_width))
cross_ymax = min(float(h_off + height - 1.0), float(image_height))
cross_width = cross_xmax - cross_xmin
cross_height = cross_ymax - cross_ymin
roi_xmin = 0 if w_off >= 0 else abs(w_off)
roi_ymin = 0 if h_off >= 0 else abs(h_off)
roi_width = cross_width
roi_height = cross_height
roi_y1 = int(roi_ymin)
roi_y2 = int(roi_ymin + roi_height)
roi_x1 = int(roi_xmin)
roi_x2 = int(roi_xmin + roi_width)
cross_y1 = int(cross_ymin)
cross_y2 = int(cross_ymin + cross_height)
cross_x1 = int(cross_xmin)
cross_x2 = int(cross_xmin + cross_width)
sample_img = np.zeros((height, width, 3))
sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \
img[cross_y1: cross_y2, cross_x1: cross_x2]
sample_img = cv2.resize(
sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
return sample_img
......@@ -31,7 +31,9 @@ from PIL import Image, ImageEnhance
from ppdet.core.workspace import serializable
from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox)
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling)
logger = logging.getLogger(__name__)
......@@ -526,8 +528,6 @@ class CropImage(BaseOperator):
batch_sampler (list): Multiple sets of different
parameters for cropping.
satisfy_all (bool): whether all boxes must satisfy.
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
e.g.[[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
......@@ -538,6 +538,8 @@ class CropImage(BaseOperator):
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap]
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
"""
super(CropImage, self).__init__()
self.batch_sampler = batch_sampler
......@@ -599,6 +601,151 @@ class CropImage(BaseOperator):
return sample
@register_op
class CropImageWithDataAchorSampling(BaseOperator):
def __init__(self,
batch_sampler,
anchor_sampler=None,
target_size=None,
das_anchor_scales=[16, 32, 64, 128],
sampling_prob=0.5,
min_size=8.,
avoid_no_bbox=True):
"""
Args:
anchor_sampler (list): anchor_sampling sets of different
parameters for cropping.
batch_sampler (list): Multiple sets of different
parameters for cropping.
e.g.[[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]]
[[1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
[1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0]]
[max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio,
min overlap, max overlap, min coverage, max coverage]
target_size (bool): target image size.
das_anchor_scales (list[float]): a list of anchor scales in data
anchor smapling.
min_size (float): minimum size of sampled bbox.
avoid_no_bbox (bool): whether to to avoid the
situation where the box does not appear.
"""
super(CropImageWithDataAchorSampling, self).__init__()
self.anchor_sampler = anchor_sampler
self.batch_sampler = batch_sampler
self.target_size = target_size
self.sampling_prob = sampling_prob
self.min_size = min_size
self.avoid_no_bbox = avoid_no_bbox
self.scale_array = np.array(das_anchor_scales)
def __call__(self, sample, context):
"""
Crop the image and modify bounding box.
Operators:
1. Scale the image weight and height.
2. Crop the image according to a radom sample.
3. Rescale the bounding box.
4. Determine if the new bbox is satisfied in the new image.
Returns:
sample: the image, bounding box are replaced.
"""
assert 'image' in sample, "image data not found"
im = sample['image']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
image_width = sample['w']
image_height = sample['h']
gt_score = None
if 'gt_score' in sample:
gt_score = sample['gt_score']
sampled_bbox = []
gt_bbox = gt_bbox.tolist()
prob = np.random.uniform(0., 1.)
if prob > self.sampling_prob: # anchor sampling
assert self.anchor_sampler
for sampler in self.anchor_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = data_anchor_sampling(
gt_bbox, image_width, image_height, self.scale_array,
self.target_size)
if sample_bbox == 0:
break
if satisfy_sample_constraint_coverage(sampler, sample_bbox,
gt_bbox):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, gt_score)
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
im = crop_image_sampling(im, sample_bbox, image_width,
image_height, self.target_size)
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
return sample
return sample
else:
for sampler in self.batch_sampler:
found = 0
for i in range(sampler[1]):
if found >= sampler[0]:
break
sample_bbox = generate_sample_bbox_square(
sampler, image_width, image_height)
if satisfy_sample_constraint_coverage(sampler, sample_bbox,
gt_bbox):
sampled_bbox.append(sample_bbox)
found = found + 1
im = np.array(im)
while sampled_bbox:
idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox)
crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, gt_score)
# sampling bbox according the bbox area
crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size,
self.min_size)
if self.avoid_no_bbox:
if len(crop_bbox) < 1:
continue
xmin = int(sample_bbox[0] * image_width)
xmax = int(sample_bbox[2] * image_width)
ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height)
im = im[ymin:ymax, xmin:xmax]
sample['image'] = im
sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class
sample['gt_score'] = crop_score
return sample
return sample
@register_op
class NormalizeBox(BaseOperator):
"""Transform the bounding box's coornidates to [0,1]."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册