提交 fca6a9c4 编写于 作者: W wangxinxin08

add augmentation ops

fix bugs
上级 caf816a2
......@@ -61,7 +61,10 @@ def is_overlap(object_bbox, sample_bbox):
return True
def filter_and_process(sample_bbox, bboxes, labels, scores=None,
def filter_and_process(sample_bbox,
bboxes,
labels,
scores=None,
keypoints=None):
new_bboxes = []
new_labels = []
......@@ -92,8 +95,8 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None,
for j in range(len(sample_keypoint)):
kp_len = sample_height if j % 2 else sample_width
sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0]
sample_keypoint[j] = (
sample_keypoint[j] - sample_coord) / kp_len
sample_keypoint[j] = (sample_keypoint[j] -
sample_coord) / kp_len
sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0)
new_keypoints.append(sample_keypoint)
new_kp_ignore.append(keypoints[1][i])
......@@ -261,12 +264,12 @@ def jaccard_overlap(sample_bbox, object_bbox):
intersect_ymin = max(sample_bbox[1], object_bbox[1])
intersect_xmax = min(sample_bbox[2], object_bbox[2])
intersect_ymax = min(sample_bbox[3], object_bbox[3])
intersect_size = (intersect_xmax - intersect_xmin) * (
intersect_ymax - intersect_ymin)
intersect_size = (intersect_xmax - intersect_xmin) * (intersect_ymax -
intersect_ymin)
sample_bbox_size = bbox_area(sample_bbox)
object_bbox_size = bbox_area(object_bbox)
overlap = intersect_size / (
sample_bbox_size + object_bbox_size - intersect_size)
overlap = intersect_size / (sample_bbox_size + object_bbox_size -
intersect_size)
return overlap
......@@ -276,8 +279,10 @@ def intersect_bbox(bbox1, bbox2):
intersection_box = [0.0, 0.0, 0.0, 0.0]
else:
intersection_box = [
max(bbox1[0], bbox2[0]), max(bbox1[1], bbox2[1]),
min(bbox1[2], bbox2[2]), min(bbox1[3], bbox2[3])
max(bbox1[0], bbox2[0]),
max(bbox1[1], bbox2[1]),
min(bbox1[2], bbox2[2]),
min(bbox1[3], bbox2[3])
]
return intersection_box
......@@ -401,8 +406,8 @@ def crop_image_sampling(img, sample_bbox, image_width, image_height,
sample_img[roi_y1: roi_y2, roi_x1: roi_x2] = \
img[cross_y1: cross_y2, cross_x1: cross_x2]
sample_img = cv2.resize(
sample_img, (target_size, target_size), interpolation=cv2.INTER_AREA)
sample_img = cv2.resize(sample_img, (target_size, target_size),
interpolation=cv2.INTER_AREA)
return sample_img
......@@ -449,8 +454,8 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6):
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
radius + right]
masked_gaussian = gaussian[radius - top:radius + bottom,
radius - left:radius + right]
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
......@@ -458,7 +463,53 @@ def gaussian2D(shape, sigma_x=1, sigma_y=1):
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
sigma_y)))
h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y /
(2 * sigma_y * sigma_y)))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
def transform_bbox(bbox,
label,
M,
w,
h,
area_thr=0.25,
wh_thr=2,
ar_thr=20,
perspective=False):
# rotate bbox
n = len(bbox)
xy = np.ones((n * 4, 3), dtype=np.float32)
xy[:, :2] = bbox[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)
xy = xy @ M.T
if perspective:
xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8)
else:
xy = xy[:, :2].reshape(n, 8)
# get new bboxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
new_bbox = np.concatenate(
(x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip boxes
new_bbox, mask = clip_bbox(new_bbox, w, h, area_thr)
new_label = label[mask]
return new_bbox, new_label
def clip_bbox(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20):
# clip boxes
area1 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
bbox[:, [0, 2]] = bbox[:, [0, 2]].clip(0, w)
bbox[:, [1, 3]] = bbox[:, [1, 3]].clip(0, h)
# compute
area2 = (bbox[:, 2:4] - bbox[:, 0:2]).prod(1)
area_ratio = area2 / (area1 + 1e-16)
wh = bbox[:, 2:4] - bbox[:, 0:2]
ar_ratio = np.maximum(wh[:, 1] / (wh[:, 0] + 1e-16),
wh[:, 0] / (wh[:, 1] + 1e-16))
mask = (area_ratio > area_thr) & (
(wh > wh_thr).all(1)) & (ar_ratio < ar_thr)
bbox = bbox[mask]
return bbox, mask
......@@ -44,7 +44,7 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling,
is_poly, gaussian_radius, draw_gaussian)
is_poly, gaussian_radius, draw_gaussian, transform_bbox, clip_bbox)
logger = logging.getLogger(__name__)
......@@ -2555,3 +2555,389 @@ class DebugVisibleImage(BaseOperator):
save_path = os.path.join(self.output_dir, out_file_name)
image.save(save_path, quality=95)
return sample
@register_op
class Rotate(BaseOperator):
def __init__(self,
degree,
scale=1.0,
center=None,
area_thr=0.25,
border_value=(114, 114, 114)):
super(Rotate, self).__init__()
self.degree = degree
self.scale = scale
self.center = center
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
im = sample['image']
bbox = sample['gt_bbox']
label = sample['gt_class']
# rotate image
height, width = im.shape[:2]
if self.center is None:
self.center = (width // 2, height // 2)
M = cv2.getRotationMatrix2D(self.center, self.degree, self.scale)
im = cv2.warpAffine(im,
M, (width, height),
borderValue=self.border_value)
# rotate bbox
if bbox.shape[0] > 0:
new_bbox, new_label = transform_bbox(bbox, label, M, width, height, self.area_thr)
else:
new_bbox, new_label = bbox, label
sample['image'] = im
sample['gt_bbox'] = new_bbox.astype(np.float32)
sample['gt_class'] = new_label.astype(np.int32)
return sample
@register_op
class RandomRotate(BaseOperator):
def __init__(self,
degree,
scale=0.0,
center=None,
area_thr=0.25,
border_value=(114, 114, 114)):
super(RandomRotate, self).__init__()
if isinstance(degree, (int, float)):
degree = abs(degree)
degree = (-degree, degree)
elif isinstance(degree, list) or isinstance(degree, tuple):
assert len(degree) == 2, 'len of degree is not equal to 2'
else:
raise ValueError('degree is not reasonable')
self.degree = degree
self.scale = scale
self.center = center
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
degree = random.uniform(*self.degree)
scale = random.uniform(1 - self.scale, 1 + self.scale)
rotate = Rotate(degree, scale, self.center, self.area_thr, self.border_value)
return rotate(sample, context)
@register_op
class Shear(BaseOperator):
def __init__(self, shear, area_thr=0.25, border_value=(114, 114, 114)):
super(Shear, self).__init__()
if isinstance(shear, (int, float)):
shear = (shear, shear)
elif isinstance(shear, list) or isinstance(shear, tuple):
assert len(shear) == 2, 'len of shear is not equal to 2'
else:
raise ValueError('shear is not reasonable')
self.shear = shear
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
im = sample['image']
bbox = sample['gt_bbox']
label = sample['gt_class']
# shear image
height, width = im.shape[:2]
shear_x = math.tan(self.shear[0] * math.pi / 180)
shear_y = math.tan(self.shear[1] * math.pi / 180)
M = np.array([[1, shear_x, 0], [shear_y, 1, 0]])
im = cv2.warpAffine(im,
M, (width, height),
borderValue=self.border_value)
# shear box
if bbox.shape[0] > 0:
new_bbox, new_label = transform_bbox(bbox, label, M, width, height, self.area_thr)
else:
new_bbox, new_label = bbox, label
sample['image'] = im
sample['gt_bbox'] = new_bbox.astype(np.float32)
sample['gt_class'] = new_label.astype(np.int32)
return sample
@register_op
class RandomShear(BaseOperator):
def __init__(self,
shear_x,
shear_y,
area_thr=0.25,
border_value=(114, 114, 114)):
super(RandomShear, self).__init__()
if isinstance(shear_x, (int, float)):
shear_x = abs(shear_x)
shear_x = (-shear_x, shear_x)
elif isinstance(shear_x, list) or isinstance(shear_x, tuple):
assert len(shear_x) == 2, 'len of shear_x is not equal to 2'
else:
raise ValueError('shear_x is not reasonable')
if isinstance(shear_y, (int, float)):
shear_y = abs(shear_y)
shear_y = (-shear_y, shear_y)
elif isinstance(shear_y, list) or isinstance(shear_y, tuple):
assert len(shear_y) == 2, 'len of shear_y is not equal to 2'
else:
raise ValueError('shear_y is not reasonable')
self.shear_x = shear_x
self.shear_y = shear_y
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
shear_x = random.uniform(*self.shear_x)
shear_y = random.uniform(*self.shear_y)
shear = Shear((shear_x, shear_y), self.area_thr, self.border_value)
return shear(sample, context)
@register_op
class Translate(BaseOperator):
def __init__(self, translate, area_thr=0.25, border_value=(114, 114, 114)):
super(Translate, self).__init__()
if isinstance(translate, (int, float)):
translate = (translate, translate)
elif isinstance(translate, list) or isinstance(translate, tuple):
assert len(translate) == 2, 'len of translate is not equal to 2'
else:
raise ValueError('translate is not reasonable')
assert abs(translate[0]) < 1 and abs(translate[1]) < 1, 'translate should be in (-1, 1)'
self.translate = translate
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
im = sample['image']
bbox = sample['gt_bbox']
label = sample['gt_class']
# translate image
height, width = im.shape[:2]
translate_x = int(self.translate[0] * width)
translate_y = int(self.translate[1] * height)
dst_cords = [
max(0, translate_y),
max(0, translate_x),
min(height, translate_y + height),
min(width, translate_x + width)
]
src_cords = [
max(-translate_y, 0),
max(-translate_x, 0),
min(-translate_y + height, height),
min(-translate_x + width, width)
]
canvas = np.ones(im.shape, dtype=np.uint8) * self.border_value
canvas[dst_cords[0]:dst_cords[2],
dst_cords[1]:dst_cords[3], :] = im[src_cords[0]:src_cords[2],
src_cords[1]:src_cords[3], :]
if bbox.shape[0] > 0:
new_bbox = bbox + [translate_x, translate_y, translate_x, translate_y]
# compute
new_bbox, mask = clip_bbox(new_bbox, width, height, self.area_thr)
new_label = label[mask]
else:
new_bbox, new_label = bbox, label
sample['image'] = canvas.astype(np.uint8)
sample['gt_bbox'] = new_bbox.astype(np.float32)
sample['gt_class'] = new_label.astype(np.int32)
return sample
@register_op
class RandomTranslate(BaseOperator):
def __init__(self,
translate_x,
translate_y,
area_thr=0.25,
border_value=(114, 114, 114)):
super(RandomTranslate, self).__init__()
if isinstance(translate_x, (int, float)):
translate_x = abs(translate_x)
translate_x = (-translate_x, translate_x)
elif isinstance(translate_x, list) or isinstance(translate_x, tuple):
assert len(translate_x) == 2, 'len of translate_x is not equal to 2'
else:
raise ValueError('translate_x is not reasonable')
if isinstance(translate_y, (int, float)):
translate_y = abs(translate_y)
translate_y = (-translate_y, translate_y)
elif isinstance(translate_y, list) or isinstance(translate_y, tuple):
assert len(translate_y) == 2, 'len of translate_y is not equal to 2'
else:
raise ValueError('translate_y is not reasonable')
self.translate_x = translate_x
self.translate_y = translate_y
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
translate_x = random.uniform(*self.translate_x)
translate_y = random.uniform(*self.translate_y)
translate = Translate((translate_x, translate_y), self.area_thr, self.border_value)
return translate(sample, context)
@register_op
class Scale(BaseOperator):
def __init__(self, scale, area_thr=0.25, border_value=(114, 114, 114)):
super(Scale, self).__init__()
if isinstance(scale, (int, float)):
scale = (scale, scale)
elif isinstance(scale, list) or isinstance(scale, tuple):
assert len(scale) == 2, 'len of scale is not equal to 2'
else:
raise ValueError('scale is not reasonable')
assert scale[0] > 0. and scale[1] > 0., 'scale should be great than 0'
self.scale = scale
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
im = sample['image']
bbox = sample['gt_bbox']
label = sample['gt_class']
# scale image
height, width = im.shape[:2]
dsize = (int(self.scale[0] * width), int(self.scale[1] * height))
dst_img = cv2.resize(im, dsize)
canvas = np.ones_like(im, dtype=np.uint8) * self.border_value
y_lim = min(height, dsize[1])
x_lim = min(width, dsize[0])
canvas[:y_lim, :x_lim, :] = dst_img[:y_lim, :x_lim, :]
# scale bbox
if bbox.shape[0] > 0:
new_bbox = bbox * [self.scale[0], self.scale[1], self.scale[0], self.scale[1]]
new_bbox, mask = clip_bbox(new_bbox, width, height, self.area_thr)
new_label = label[mask]
else:
new_bbox, new_label = bbox, label
sample['image'] = canvas.astype(np.uint8)
sample['gt_bbox'] = new_bbox.astype(np.float32)
sample['gt_class'] = new_label.astype(np.int32)
return sample
@register_op
class RandomScale(BaseOperator):
def __init__(self, scale_x, scale_y, area_thr=0.25, border_value=(114, 114, 114)):
super(RandomScale, self).__init__()
if isinstance(scale_x, (int, float)):
assert scale_x > 0., 'scale_x should be great than 0'
scale_x = (0., scale_x)
elif isinstance(scale_x, list) or isinstance(scale_x, tuple):
assert len(scale_x) == 2, 'len of scale_x is not equal to 2'
else:
raise ValueError('scale_x is not reasonable')
if isinstance(scale_y, (int, float)):
assert scale_y > 0., 'scale_y should be great than 0'
scale_y = (0., scale_y)
elif isinstance(scale_y, list) or isinstance(scale_y, tuple):
assert len(scale_y) == 2, 'len of scale_y is not equal to 2'
else:
raise ValueError('scale_y is not reasonable')
self.scale_x = scale_x
self.scale_y = scale_y
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
scale_x = random.uniform(*self.scale_x)
scale_y = random.uniform(*self.scale_y)
scale = Scale((scale_x, scale_y), self.area_thr, self.border_value)
return scale(sample, context)
@register_op
class RandomPerspective(BaseOperator):
def __init__(self, degree=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0), area_thr=0.25, border_value=(114,114,114)):
super(RandomPerspective, self).__init__()
self.degree = degree
self.translate = translate
self.scale = scale
self.shear = shear
self.perspective = perspective
self.border = border
self.area_thr = area_thr
self.border_value = border_value
def __call__(self, sample, context=None):
im = sample['image']
bbox = sample['gt_bbox']
label = sample['gt_class']
height = im.shape[0] + self.border[0]
width = im.shape[1] + self.border[1]
# center
C = np.eye(3)
C[0, 2] = -im.shape[1] / 2
C[1, 2] = -im.shape[0] / 2
# perspective
P = np.eye(3)
P[2, 0] = random.uniform(-self.perspective, self.perspective)
P[2, 1] = random.uniform(-self.perspective, self.perspective)
# Rotation and scale
R = np.eye(3)
a = random.uniform(-self.degree, self.degree)
s = random.uniform(1 - self.scale, 1 + self.scale)
R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
S = np.eye(3)
# shear x (deg)
S[0, 1] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)
# shear y (deg)
S[1, 0] = math.tan(random.uniform(-self.shear, self.shear) * math.pi / 180)
# Translation
T = np.eye(3)
T[0, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * width
T[1, 2] = random.uniform(0.5 - self.translate, 0.5 + self.translate) * height
# matmul
M = T @ S @ R @ P @ C
if (self.border[0] != 0) or (self.border[1] != 0) or (M != np.eye(3)).any():
if self.perspective:
im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=self.border_value)
else:
im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=self.border_value)
if bbox.shape[0] > 0:
new_bbox, new_label = transform_bbox(bbox, label, M, width, height, area_thr=self.area_thr, perspective=self.perspective)
else:
new_bbox, new_label = bbox, label
sample['image'] = im
sample['gt_bbox'] = new_bbox.astype(np.float32)
sample['gt_class'] = new_label.astype(np.int32)
return sample
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册