未验证 提交 897d86ac 编写于 作者: S sunxl1988 提交者: GitHub

test=dygraph sync reader from static ppdet (#1084)

sync reader from static ppdet
上级 ce71cdc2
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import copy import copy
import functools import functools
import collections import collections
...@@ -167,6 +168,8 @@ class Reader(object): ...@@ -167,6 +168,8 @@ class Reader(object):
Default True. Default True.
mixup_epoch (int): mixup epoc number. Default is -1, meaning mixup_epoch (int): mixup epoc number. Default is -1, meaning
not use mixup. not use mixup.
cutmix_epoch (int): cutmix epoc number. Default is -1, meaning
not use cutmix.
class_aware_sampling (bool): whether use class-aware sampling or not. class_aware_sampling (bool): whether use class-aware sampling or not.
Default False. Default False.
worker_num (int): number of working threads/processes. worker_num (int): number of working threads/processes.
...@@ -191,6 +194,7 @@ class Reader(object): ...@@ -191,6 +194,7 @@ class Reader(object):
drop_last=False, drop_last=False,
drop_empty=True, drop_empty=True,
mixup_epoch=-1, mixup_epoch=-1,
cutmix_epoch=-1,
class_aware_sampling=False, class_aware_sampling=False,
worker_num=-1, worker_num=-1,
use_process=False, use_process=False,
...@@ -241,6 +245,7 @@ class Reader(object): ...@@ -241,6 +245,7 @@ class Reader(object):
# sampling # sampling
self._mixup_epoch = mixup_epoch self._mixup_epoch = mixup_epoch
self._cutmix_epoch = cutmix_epoch
self._class_aware_sampling = class_aware_sampling self._class_aware_sampling = class_aware_sampling
self._load_img = False self._load_img = False
...@@ -253,6 +258,8 @@ class Reader(object): ...@@ -253,6 +258,8 @@ class Reader(object):
self._pos = -1 self._pos = -1
self._epoch = -1 self._epoch = -1
self._curr_iter = 0
# multi-process # multi-process
self._worker_num = worker_num self._worker_num = worker_num
self._parallel = None self._parallel = None
...@@ -274,6 +281,11 @@ class Reader(object): ...@@ -274,6 +281,11 @@ class Reader(object):
def reset(self): def reset(self):
"""implementation of Dataset.reset """implementation of Dataset.reset
""" """
if self._epoch < 0:
self._epoch = 0
else:
self._epoch += 1
self.indexes = [i for i in range(self.size())] self.indexes = [i for i in range(self.size())]
if self._class_aware_sampling: if self._class_aware_sampling:
self.indexes = np.random.choice( self.indexes = np.random.choice(
...@@ -283,17 +295,18 @@ class Reader(object): ...@@ -283,17 +295,18 @@ class Reader(object):
p=self.img_weights) p=self.img_weights)
if self._shuffle: if self._shuffle:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
np.random.seed(self._epoch + trainer_id)
np.random.shuffle(self.indexes) np.random.shuffle(self.indexes)
if self._mixup_epoch > 0 and len(self.indexes) < 2: if self._mixup_epoch > 0 and len(self.indexes) < 2:
logger.debug("Disable mixup for dataset samples " logger.debug("Disable mixup for dataset samples "
"less than 2 samples") "less than 2 samples")
self._mixup_epoch = -1 self._mixup_epoch = -1
if self._cutmix_epoch > 0 and len(self.indexes) < 2:
if self._epoch < 0: logger.info("Disable cutmix for dataset samples "
self._epoch = 0 "less than 2 samples")
else: self._cutmix_epoch = -1
self._epoch += 1
self._pos = 0 self._pos = 0
...@@ -306,6 +319,7 @@ class Reader(object): ...@@ -306,6 +319,7 @@ class Reader(object):
if self.drained(): if self.drained():
raise StopIteration raise StopIteration
batch = self._load_batch() batch = self._load_batch()
self._curr_iter += 1
if self._drop_last and len(batch) < self._batch_size: if self._drop_last and len(batch) < self._batch_size:
raise StopIteration raise StopIteration
if self._worker_num > -1: if self._worker_num > -1:
...@@ -321,6 +335,7 @@ class Reader(object): ...@@ -321,6 +335,7 @@ class Reader(object):
break break
pos = self.indexes[self._pos] pos = self.indexes[self._pos]
sample = copy.deepcopy(self._roidbs[pos]) sample = copy.deepcopy(self._roidbs[pos])
sample["curr_iter"] = self._curr_iter
self._pos += 1 self._pos += 1
if self._drop_empty and self._fields and 'gt_mask' in self._fields: if self._drop_empty and self._fields and 'gt_mask' in self._fields:
...@@ -343,9 +358,18 @@ class Reader(object): ...@@ -343,9 +358,18 @@ class Reader(object):
mix_idx = np.random.randint(1, num) mix_idx = np.random.randint(1, num)
mix_idx = self.indexes[(mix_idx + self._pos - 1) % num] mix_idx = self.indexes[(mix_idx + self._pos - 1) % num]
sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx]) sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx])
sample['mixup']["curr_iter"] = self._curr_iter
if self._load_img: if self._load_img:
sample['mixup']['image'] = self._load_image(sample['mixup'][ sample['mixup']['image'] = self._load_image(sample['mixup'][
'im_file']) 'im_file'])
if self._epoch < self._cutmix_epoch:
num = len(self.indexes)
mix_idx = np.random.randint(1, num)
sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx])
sample['cutmix']["curr_iter"] = self._curr_iter
if self._load_img:
sample['cutmix']['image'] = self._load_image(sample[
'cutmix']['im_file'])
batch.append(sample) batch.append(sample)
bs += 1 bs += 1
......
...@@ -41,7 +41,8 @@ class WIDERFaceDataSet(DataSet): ...@@ -41,7 +41,8 @@ class WIDERFaceDataSet(DataSet):
image_dir=None, image_dir=None,
anno_path=None, anno_path=None,
sample_num=-1, sample_num=-1,
with_background=True): with_background=True,
with_lmk=False):
super(WIDERFaceDataSet, self).__init__( super(WIDERFaceDataSet, self).__init__(
image_dir=image_dir, image_dir=image_dir,
anno_path=anno_path, anno_path=anno_path,
...@@ -53,6 +54,7 @@ class WIDERFaceDataSet(DataSet): ...@@ -53,6 +54,7 @@ class WIDERFaceDataSet(DataSet):
self.with_background = with_background self.with_background = with_background
self.roidbs = None self.roidbs = None
self.cname2cid = None self.cname2cid = None
self.with_lmk = with_lmk
def load_roidb_and_cname2cid(self): def load_roidb_and_cname2cid(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path) anno_path = os.path.join(self.dataset_dir, self.anno_path)
...@@ -62,33 +64,23 @@ class WIDERFaceDataSet(DataSet): ...@@ -62,33 +64,23 @@ class WIDERFaceDataSet(DataSet):
records = [] records = []
ct = 0 ct = 0
file_lists = _load_file_list(txt_file) file_lists = self._load_file_list(txt_file)
cname2cid = widerface_label(self.with_background) cname2cid = widerface_label(self.with_background)
for item in file_lists: for item in file_lists:
im_fname = item[0] im_fname = item[0]
im_id = np.array([ct]) im_id = np.array([ct])
gt_bbox = np.zeros((len(item) - 2, 4), dtype=np.float32) gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32)
gt_class = np.ones((len(item) - 2, 1), dtype=np.int32) gt_class = np.ones((len(item) - 1, 1), dtype=np.int32)
gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32)
lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32)
for index_box in range(len(item)): for index_box in range(len(item)):
if index_box >= 2: if index_box < 1:
temp_info_box = item[index_box].split(' ')
xmin = float(temp_info_box[0])
ymin = float(temp_info_box[1])
w = float(temp_info_box[2])
h = float(temp_info_box[3])
# Filter out wrong labels
if w < 0 or h < 0:
logger.warn('Illegal box with w: {}, h: {} in '
'img: {}, and it will be ignored'.format(
w, h, im_fname))
continue continue
xmin = max(0, xmin) gt_bbox[index_box - 1] = item[index_box][0]
ymin = max(0, ymin) if self.with_lmk:
xmax = xmin + w gt_lmk_labels[index_box - 1] = item[index_box][1]
ymax = ymin + h lmk_ignore_flag[index_box - 1] = item[index_box][2]
gt_bbox[index_box - 2] = [xmin, ymin, xmax, ymax]
im_fname = os.path.join(image_dir, im_fname = os.path.join(image_dir,
im_fname) if image_dir else im_fname im_fname) if image_dir else im_fname
widerface_rec = { widerface_rec = {
...@@ -97,7 +89,10 @@ class WIDERFaceDataSet(DataSet): ...@@ -97,7 +89,10 @@ class WIDERFaceDataSet(DataSet):
'gt_bbox': gt_bbox, 'gt_bbox': gt_bbox,
'gt_class': gt_class, 'gt_class': gt_class,
} }
# logger.debug if self.with_lmk:
widerface_rec['gt_keypoint'] = gt_lmk_labels
widerface_rec['keypoint_ignore'] = lmk_ignore_flag
if len(item) != 0: if len(item) != 0:
records.append(widerface_rec) records.append(widerface_rec)
...@@ -108,8 +103,7 @@ class WIDERFaceDataSet(DataSet): ...@@ -108,8 +103,7 @@ class WIDERFaceDataSet(DataSet):
logger.debug('{} samples in file {}'.format(ct, anno_path)) logger.debug('{} samples in file {}'.format(ct, anno_path))
self.roidbs, self.cname2cid = records, cname2cid self.roidbs, self.cname2cid = records, cname2cid
def _load_file_list(self, input_txt):
def _load_file_list(input_txt):
with open(input_txt, 'r') as f_dir: with open(input_txt, 'r') as f_dir:
lines_input_txt = f_dir.readlines() lines_input_txt = f_dir.readlines()
...@@ -123,17 +117,48 @@ def _load_file_list(input_txt): ...@@ -123,17 +117,48 @@ def _load_file_list(input_txt):
file_dict[num_class] = [] file_dict[num_class] = []
file_dict[num_class].append(line_txt) file_dict[num_class].append(line_txt)
if '.jpg' not in line_txt: if '.jpg' not in line_txt:
if len(line_txt) > 6: if len(line_txt) <= 6:
continue
result_boxs = []
split_str = line_txt.split(' ') split_str = line_txt.split(' ')
x1_min = float(split_str[0]) xmin = float(split_str[0])
y1_min = float(split_str[1]) ymin = float(split_str[1])
x2_max = float(split_str[2]) w = float(split_str[2])
y2_max = float(split_str[3]) h = float(split_str[3])
line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str( # Filter out wrong labels
x2_max) + ' ' + str(y2_max) if w < 0 or h < 0:
file_dict[num_class].append(line_txt) logger.warn('Illegal box with w: {}, h: {} in '
else: 'img: {}, and it will be ignored'.format(
file_dict[num_class].append(line_txt) w, h, file_dict[num_class][0]))
continue
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = xmin + w
ymax = ymin + h
gt_bbox = [xmin, ymin, xmax, ymax]
result_boxs.append(gt_bbox)
if self.with_lmk:
assert len(split_str) > 18, 'When `with_lmk=True`, the number' \
'of characters per line in the annotation file should' \
'exceed 18.'
lmk0_x = float(split_str[5])
lmk0_y = float(split_str[6])
lmk1_x = float(split_str[8])
lmk1_y = float(split_str[9])
lmk2_x = float(split_str[11])
lmk2_y = float(split_str[12])
lmk3_x = float(split_str[14])
lmk3_y = float(split_str[15])
lmk4_x = float(split_str[17])
lmk4_y = float(split_str[18])
lmk_ignore_flag = 0 if lmk0_x == -1 else 1
gt_lmk_label = [
lmk0_x, lmk0_y, lmk1_x, lmk1_y, lmk2_x, lmk2_y, lmk3_x,
lmk3_y, lmk4_x, lmk4_y
]
result_boxs.append(gt_lmk_label)
result_boxs.append(lmk_ignore_flag)
file_dict[num_class].append(result_boxs)
return list(file_dict.values()) return list(file_dict.values())
......
...@@ -26,13 +26,17 @@ import cv2 ...@@ -26,13 +26,17 @@ import cv2
import numpy as np import numpy as np
from .operators import register_op, BaseOperator from .operators import register_op, BaseOperator
from .op_helper import jaccard_overlap from .op_helper import jaccard_overlap, gaussian2D
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
'PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget', 'PadBatch',
'Gt2FCOSTarget' 'RandomShape',
'PadMultiScaleTest',
'Gt2YoloTarget',
'Gt2FCOSTarget',
'Gt2TTFTarget',
] ]
...@@ -41,17 +45,15 @@ class PadBatch(BaseOperator): ...@@ -41,17 +45,15 @@ class PadBatch(BaseOperator):
""" """
Pad a batch of samples so they can be divisible by a stride. Pad a batch of samples so they can be divisible by a stride.
The layout of each image should be 'CHW'. The layout of each image should be 'CHW'.
Args: Args:
pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
height and width is divisible by `pad_to_stride`. height and width is divisible by `pad_to_stride`.
""" """
def __init__(self, pad_to_stride=0, use_padded_im_info=True, pad_gt=False): def __init__(self, pad_to_stride=0, use_padded_im_info=True):
super(PadBatch, self).__init__() super(PadBatch, self).__init__()
self.pad_to_stride = pad_to_stride self.pad_to_stride = pad_to_stride
self.use_padded_im_info = use_padded_im_info self.use_padded_im_info = use_padded_im_info
self.pad_gt = pad_gt
def __call__(self, samples, context=None): def __call__(self, samples, context=None):
""" """
...@@ -61,9 +63,9 @@ class PadBatch(BaseOperator): ...@@ -61,9 +63,9 @@ class PadBatch(BaseOperator):
coarsest_stride = self.pad_to_stride coarsest_stride = self.pad_to_stride
if coarsest_stride == 0: if coarsest_stride == 0:
return samples return samples
max_shape = np.array([data['image'].shape for data in samples]).max( max_shape = np.array([data['image'].shape for data in samples]).max(
axis=0) axis=0)
if coarsest_stride > 0: if coarsest_stride > 0:
max_shape[1] = int( max_shape[1] = int(
np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride) np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
...@@ -80,52 +82,6 @@ class PadBatch(BaseOperator): ...@@ -80,52 +82,6 @@ class PadBatch(BaseOperator):
data['image'] = padding_im data['image'] = padding_im
if self.use_padded_im_info: if self.use_padded_im_info:
data['im_info'][:2] = max_shape[1:3] data['im_info'][:2] = max_shape[1:3]
if self.pad_gt:
gt_num = []
if data['gt_poly'] is not None and len(data['gt_poly']) > 0:
pad_mask = True
else:
pad_mask = False
if pad_mask:
poly_num = []
poly_part_num = []
point_num = []
for data in samples:
gt_num.append(data['gt_bbox'].shape[0])
if pad_mask:
poly_num.append(len(data['gt_poly']))
for poly in data['gt_poly']:
poly_part_num.append(int(len(poly)))
for p_p in poly:
point_num.append(int(len(p_p) / 2))
gt_num_max = max(gt_num)
gt_box_data = np.zeros([gt_num_max, 4])
gt_class_data = np.zeros([gt_num_max])
is_crowd_data = np.ones([gt_num_max])
if pad_mask:
poly_num_max = max(poly_num)
poly_part_num_max = max(poly_part_num)
point_num_max = max(point_num)
gt_masks_data = -np.ones(
[poly_num_max, poly_part_num_max, point_num_max, 2])
for i, data in enumerate(samples):
gt_num = data['gt_bbox'].shape[0]
gt_box_data[0:gt_num, :] = data['gt_bbox']
gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd'])
if pad_mask:
for j, poly in enumerate(data['gt_poly']):
for k, p_p in enumerate(poly):
pp_np = np.array(p_p).reshape(-1, 2)
gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
data['gt_poly'] = gt_masks_data
data['gt_bbox'] = gt_box_data
data['gt_class'] = gt_class_data
data['is_crowd_data'] = is_crowd_data
return samples return samples
...@@ -136,13 +92,12 @@ class RandomShape(BaseOperator): ...@@ -136,13 +92,12 @@ class RandomShape(BaseOperator):
select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR, select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR,
cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is
False, use cv2.INTER_NEAREST. False, use cv2.INTER_NEAREST.
Args: Args:
sizes (list): list of int, random choose a size from these sizes (list): list of int, random choose a size from these
random_inter (bool): whether to randomly interpolation, defalut true. random_inter (bool): whether to randomly interpolation, defalut true.
""" """
def __init__(self, sizes=[], random_inter=False): def __init__(self, sizes=[], random_inter=False, resize_box=False):
super(RandomShape, self).__init__() super(RandomShape, self).__init__()
self.sizes = sizes self.sizes = sizes
self.random_inter = random_inter self.random_inter = random_inter
...@@ -153,6 +108,7 @@ class RandomShape(BaseOperator): ...@@ -153,6 +108,7 @@ class RandomShape(BaseOperator):
cv2.INTER_CUBIC, cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4, cv2.INTER_LANCZOS4,
] if random_inter else [] ] if random_inter else []
self.resize_box = resize_box
def __call__(self, samples, context=None): def __call__(self, samples, context=None):
shape = np.random.choice(self.sizes) shape = np.random.choice(self.sizes)
...@@ -166,6 +122,12 @@ class RandomShape(BaseOperator): ...@@ -166,6 +122,12 @@ class RandomShape(BaseOperator):
im = cv2.resize( im = cv2.resize(
im, None, None, fx=scale_x, fy=scale_y, interpolation=method) im, None, None, fx=scale_x, fy=scale_y, interpolation=method)
samples[i]['image'] = im samples[i]['image'] = im
if self.resize_box and 'gt_bbox' in samples[i] and len(samples[0][
'gt_bbox']) > 0:
scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
samples[i]['gt_bbox'] = np.clip(samples[i]['gt_bbox'] *
scale_array, 0,
float(shape) - 1)
return samples return samples
...@@ -525,3 +487,99 @@ class Gt2FCOSTarget(BaseOperator): ...@@ -525,3 +487,99 @@ class Gt2FCOSTarget(BaseOperator):
sample['centerness{}'.format(lvl)] = np.reshape( sample['centerness{}'.format(lvl)] = np.reshape(
ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1]) ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
return samples return samples
@register_op
class Gt2TTFTarget(BaseOperator):
"""
Gt2TTFTarget
Generate TTFNet targets by ground truth data
Args:
num_classes(int): the number of classes.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
alpha(float): the alpha parameter to generate gaussian target.
0.54 by default.
"""
def __init__(self, num_classes, down_ratio=4, alpha=0.54):
super(Gt2TTFTarget, self).__init__()
self.down_ratio = down_ratio
self.num_classes = num_classes
self.alpha = alpha
def __call__(self, samples, context=None):
output_size = samples[0]['image'].shape[1]
feat_size = output_size // self.down_ratio
for sample in samples:
heatmap = np.zeros(
(self.num_classes, feat_size, feat_size), dtype='float32')
box_target = np.ones(
(4, feat_size, feat_size), dtype='float32') * -1
reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
area = bbox_w * bbox_h
boxes_areas_log = np.log(area)
boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
boxes_area_topk_log = boxes_areas_log[boxes_ind]
gt_bbox = gt_bbox[boxes_ind]
gt_class = gt_class[boxes_ind]
feat_gt_bbox = gt_bbox / self.down_ratio
feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])
ct_inds = np.stack(
[(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
(gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
axis=1) / self.down_ratio
h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')
for k in range(len(gt_bbox)):
cls_id = gt_class[k]
fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32')
self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
h_radiuses_alpha[k],
w_radiuses_alpha[k])
heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
box_target_inds = fake_heatmap > 0
box_target[:, box_target_inds] = gt_bbox[k][:, None]
local_heatmap = fake_heatmap[box_target_inds]
ct_div = np.sum(local_heatmap)
local_heatmap *= boxes_area_topk_log[k]
reg_weight[0, box_target_inds] = local_heatmap / ct_div
sample['ttf_heatmap'] = heatmap
sample['ttf_box_target'] = box_target
sample['ttf_reg_weight'] = reg_weight
return samples
def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
h, w = 2 * h_radius + 1, 2 * w_radius + 1
sigma_x = w / 6
sigma_y = h / 6
gaussian = gaussian2D((h, w), sigma_x, sigma_y)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, w_radius), min(width - x, w_radius + 1)
top, bottom = min(y, h_radius), min(height - y, h_radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius -
left:w_radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
masked_heatmap, masked_gaussian)
return heatmap
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import numpy as np
from PIL import Image
class GridMask(object):
def __init__(self,
use_h=True,
use_w=True,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.7,
upper_iter=360000):
super(GridMask, self).__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.prob = prob
self.st_prob = prob
self.upper_iter = upper_iter
def __call__(self, x, curr_iter):
self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter)
if np.random.rand() > self.prob:
return x
_, h, w = x.shape
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(2, h)
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
for i in range(hh // d):
s = d * i + st_h
t = min(s + self.l, hh)
mask[s:t, :] *= 0
if self.use_w:
for i in range(ww // d):
s = d * i + st_w
t = min(s + self.l, ww)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2
+ w].astype(np.float32)
if self.mode == 1:
mask = 1 - mask
mask = np.expand_dims(mask, axis=0)
if self.offset:
offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32)
x = (x * mask + offset * (1 - mask)).astype(x.dtype)
else:
x = (x * mask).astype(x.dtype)
return x
...@@ -61,10 +61,13 @@ def is_overlap(object_bbox, sample_bbox): ...@@ -61,10 +61,13 @@ def is_overlap(object_bbox, sample_bbox):
return True return True
def filter_and_process(sample_bbox, bboxes, labels, scores=None): def filter_and_process(sample_bbox, bboxes, labels, scores=None,
keypoints=None):
new_bboxes = [] new_bboxes = []
new_labels = [] new_labels = []
new_scores = [] new_scores = []
new_keypoints = []
new_kp_ignore = []
for i in range(len(bboxes)): for i in range(len(bboxes)):
new_bbox = [0, 0, 0, 0] new_bbox = [0, 0, 0, 0]
obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]] obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
...@@ -84,9 +87,24 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None): ...@@ -84,9 +87,24 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None):
new_labels.append([labels[i][0]]) new_labels.append([labels[i][0]])
if scores is not None: if scores is not None:
new_scores.append([scores[i][0]]) new_scores.append([scores[i][0]])
if keypoints is not None:
sample_keypoint = keypoints[0][i]
for j in range(len(sample_keypoint)):
kp_len = sample_height if j % 2 else sample_width
sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0]
sample_keypoint[j] = (
sample_keypoint[j] - sample_coord) / kp_len
sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0)
new_keypoints.append(sample_keypoint)
new_kp_ignore.append(keypoints[1][i])
bboxes = np.array(new_bboxes) bboxes = np.array(new_bboxes)
labels = np.array(new_labels) labels = np.array(new_labels)
scores = np.array(new_scores) scores = np.array(new_scores)
if keypoints is not None:
keypoints = np.array(new_keypoints)
new_kp_ignore = np.array(new_kp_ignore)
return bboxes, labels, scores, (keypoints, new_kp_ignore)
return bboxes, labels, scores return bboxes, labels, scores
...@@ -420,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap): ...@@ -420,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap):
def draw_gaussian(heatmap, center, radius, k=1, delte=6): def draw_gaussian(heatmap, center, radius, k=1, delte=6):
diameter = 2 * radius + 1 diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / delte) sigma = diameter / delte
gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma)
x, y = center x, y = center
...@@ -435,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6): ...@@ -435,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6):
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
def gaussian2D(shape, sigma=1): def gaussian2D(shape, sigma_x=1, sigma_y=1):
m, n = [(ss - 1.) / 2. for ss in shape] m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1] y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
sigma_y)))
h[h < np.finfo(h.dtype).eps * h.max()] = 0 h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h return h
...@@ -32,9 +32,10 @@ import logging ...@@ -32,9 +32,10 @@ import logging
import random import random
import math import math
import numpy as np import numpy as np
import os
import cv2 import cv2
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance, ImageDraw
from ppdet.core.workspace import serializable from ppdet.core.workspace import serializable
from ppdet.modeling.ops import AnchorGrid from ppdet.modeling.ops import AnchorGrid
...@@ -89,21 +90,24 @@ class BaseOperator(object): ...@@ -89,21 +90,24 @@ class BaseOperator(object):
@register_op @register_op
class DecodeImage(BaseOperator): class DecodeImage(BaseOperator):
def __init__(self, to_rgb=True, with_mixup=False): def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False):
""" Transform the image data to numpy format. """ Transform the image data to numpy format.
Args: Args:
to_rgb (bool): whether to convert BGR to RGB to_rgb (bool): whether to convert BGR to RGB
with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score
with_cutmix (bool): whether or not to cutmix image and gt_bbbox/gt_score
""" """
super(DecodeImage, self).__init__() super(DecodeImage, self).__init__()
self.to_rgb = to_rgb self.to_rgb = to_rgb
self.with_mixup = with_mixup self.with_mixup = with_mixup
self.with_cutmix = with_cutmix
if not isinstance(self.to_rgb, bool): if not isinstance(self.to_rgb, bool):
raise TypeError("{}: input type is invalid.".format(self)) raise TypeError("{}: input type is invalid.".format(self))
if not isinstance(self.with_mixup, bool): if not isinstance(self.with_mixup, bool):
raise TypeError("{}: input type is invalid.".format(self)) raise TypeError("{}: input type is invalid.".format(self))
if not isinstance(self.with_cutmix, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None): def __call__(self, sample, context=None):
""" load image if 'im_file' field is not empty but 'image' is""" """ load image if 'im_file' field is not empty but 'image' is"""
...@@ -142,6 +146,10 @@ class DecodeImage(BaseOperator): ...@@ -142,6 +146,10 @@ class DecodeImage(BaseOperator):
# decode mixup image # decode mixup image
if self.with_mixup and 'mixup' in sample: if self.with_mixup and 'mixup' in sample:
self.__call__(sample['mixup'], context) self.__call__(sample['mixup'], context)
# decode cutmix image
if self.with_cutmix and 'cutmix' in sample:
self.__call__(sample['cutmix'], context)
return sample return sample
...@@ -156,7 +164,6 @@ class MultiscaleTestResize(BaseOperator): ...@@ -156,7 +164,6 @@ class MultiscaleTestResize(BaseOperator):
use_flip=True): use_flip=True):
""" """
Rescale image to the each size in target size, and capped at max_size. Rescale image to the each size in target size, and capped at max_size.
Args: Args:
origin_target_size(int): original target size of image's short side. origin_target_size(int): original target size of image's short side.
origin_max_size(int): original max size of image. origin_max_size(int): original max size of image.
...@@ -265,7 +272,6 @@ class ResizeImage(BaseOperator): ...@@ -265,7 +272,6 @@ class ResizeImage(BaseOperator):
if max_size != 0. if max_size != 0.
If target_size is list, selected a scale randomly as the specified If target_size is list, selected a scale randomly as the specified
target size. target size.
Args: Args:
target_size (int|list): the target size of image's short side, target_size (int|list): the target size of image's short side,
multi-scale training is adopted when type is list. multi-scale training is adopted when type is list.
...@@ -392,6 +398,16 @@ class RandomFlipImage(BaseOperator): ...@@ -392,6 +398,16 @@ class RandomFlipImage(BaseOperator):
flipped_segms.append(_flip_rle(segm, height, width)) flipped_segms.append(_flip_rle(segm, height, width))
return flipped_segms return flipped_segms
def flip_keypoint(self, gt_keypoint, width):
for i in range(gt_keypoint.shape[1]):
if i % 2 == 0:
old_x = gt_keypoint[:, i].copy()
if self.is_normalized:
gt_keypoint[:, i] = 1 - old_x
else:
gt_keypoint[:, i] = width - old_x - 1
return gt_keypoint
def __call__(self, sample, context=None): def __call__(self, sample, context=None):
"""Filp the image and bounding box. """Filp the image and bounding box.
Operators: Operators:
...@@ -439,12 +455,130 @@ class RandomFlipImage(BaseOperator): ...@@ -439,12 +455,130 @@ class RandomFlipImage(BaseOperator):
if self.is_mask_flip and len(sample['gt_poly']) != 0: if self.is_mask_flip and len(sample['gt_poly']) != 0:
sample['gt_poly'] = self.flip_segms(sample['gt_poly'], sample['gt_poly'] = self.flip_segms(sample['gt_poly'],
height, width) height, width)
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = self.flip_keypoint(
sample['gt_keypoint'], width)
sample['flipped'] = True sample['flipped'] = True
sample['image'] = im sample['image'] = im
sample = samples if batch_input else samples[0] sample = samples if batch_input else samples[0]
return sample return sample
@register_op
class RandomErasingImage(BaseOperator):
def __init__(self, prob=0.5, sl=0.02, sh=0.4, r1=0.3):
"""
Random Erasing Data Augmentation, see https://arxiv.org/abs/1708.04896
Args:
prob (float): probability to carry out random erasing
sl (float): lower limit of the erasing area ratio
sh (float): upper limit of the erasing area ratio
r1 (float): aspect ratio of the erasing region
"""
super(RandomErasingImage, self).__init__()
self.prob = prob
self.sl = sl
self.sh = sh
self.r1 = r1
def __call__(self, sample, context=None):
samples = sample
batch_input = True
if not isinstance(samples, Sequence):
batch_input = False
samples = [samples]
for sample in samples:
gt_bbox = sample['gt_bbox']
im = sample['image']
if not isinstance(im, np.ndarray):
raise TypeError("{}: image is not a numpy array.".format(self))
if len(im.shape) != 3:
raise ImageError("{}: image is not 3-dimensional.".format(self))
for idx in range(gt_bbox.shape[0]):
if self.prob <= np.random.rand():
continue
x1, y1, x2, y2 = gt_bbox[idx, :]
w_bbox = x2 - x1 + 1
h_bbox = y2 - y1 + 1
area = w_bbox * h_bbox
target_area = random.uniform(self.sl, self.sh) * area
aspect_ratio = random.uniform(self.r1, 1 / self.r1)
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < w_bbox and h < h_bbox:
off_y1 = random.randint(0, int(h_bbox - h))
off_x1 = random.randint(0, int(w_bbox - w))
im[int(y1 + off_y1):int(y1 + off_y1 + h), int(x1 + off_x1):
int(x1 + off_x1 + w), :] = 0
sample['image'] = im
sample = samples if batch_input else samples[0]
return sample
@register_op
class GridMaskOp(BaseOperator):
def __init__(self,
use_h=True,
use_w=True,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.7,
upper_iter=360000):
"""
GridMask Data Augmentation, see https://arxiv.org/abs/2001.04086
Args:
use_h (bool): whether to mask vertically
use_w (boo;): whether to mask horizontally
rotate (float): angle for the mask to rotate
offset (float): mask offset
ratio (float): mask ratio
mode (int): gridmask mode
prob (float): max probability to carry out gridmask
upper_iter (int): suggested to be equal to global max_iter
"""
super(GridMaskOp, self).__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.prob = prob
self.upper_iter = upper_iter
from .gridmask_utils import GridMask
self.gridmask_op = GridMask(
use_h,
use_w,
rotate=rotate,
offset=offset,
ratio=ratio,
mode=mode,
prob=prob,
upper_iter=upper_iter)
def __call__(self, sample, context=None):
samples = sample
batch_input = True
if not isinstance(samples, Sequence):
batch_input = False
samples = [samples]
for sample in samples:
sample['image'] = self.gridmask_op(sample['image'],
sample['curr_iter'])
if not batch_input:
samples = samples[0]
return sample
@register_op @register_op
class AutoAugmentImage(BaseOperator): class AutoAugmentImage(BaseOperator):
def __init__(self, is_normalized=False, autoaug_type="v1"): def __init__(self, is_normalized=False, autoaug_type="v1"):
...@@ -733,8 +867,17 @@ class ExpandImage(BaseOperator): ...@@ -733,8 +867,17 @@ class ExpandImage(BaseOperator):
im = Image.fromarray(im) im = Image.fromarray(im)
expand_im.paste(im, (int(w_off), int(h_off))) expand_im.paste(im, (int(w_off), int(h_off)))
expand_im = np.asarray(expand_im) expand_im = np.asarray(expand_im)
gt_bbox, gt_class, _ = filter_and_process(expand_bbox, gt_bbox, if 'gt_keypoint' in sample.keys(
gt_class) ) and 'keypoint_ignore' in sample.keys():
keypoints = (sample['gt_keypoint'],
sample['keypoint_ignore'])
gt_bbox, gt_class, _, gt_keypoints = filter_and_process(
expand_bbox, gt_bbox, gt_class, keypoints=keypoints)
sample['gt_keypoint'] = gt_keypoints[0]
sample['keypoint_ignore'] = gt_keypoints[1]
else:
gt_bbox, gt_class, _ = filter_and_process(expand_bbox,
gt_bbox, gt_class)
sample['image'] = expand_im sample['image'] = expand_im
sample['gt_bbox'] = gt_bbox sample['gt_bbox'] = gt_bbox
sample['gt_class'] = gt_class sample['gt_class'] = gt_class
...@@ -808,7 +951,7 @@ class CropImage(BaseOperator): ...@@ -808,7 +951,7 @@ class CropImage(BaseOperator):
sample_bbox = sampled_bbox.pop(idx) sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox) sample_bbox = clip_bbox(sample_bbox)
crop_bbox, crop_class, crop_score = \ crop_bbox, crop_class, crop_score = \
filter_and_process(sample_bbox, gt_bbox, gt_class, gt_score) filter_and_process(sample_bbox, gt_bbox, gt_class, scores=gt_score)
if self.avoid_no_bbox: if self.avoid_no_bbox:
if len(crop_bbox) < 1: if len(crop_bbox) < 1:
continue continue
...@@ -911,8 +1054,16 @@ class CropImageWithDataAchorSampling(BaseOperator): ...@@ -911,8 +1054,16 @@ class CropImageWithDataAchorSampling(BaseOperator):
idx = int(np.random.uniform(0, len(sampled_bbox))) idx = int(np.random.uniform(0, len(sampled_bbox)))
sample_bbox = sampled_bbox.pop(idx) sample_bbox = sampled_bbox.pop(idx)
if 'gt_keypoint' in sample.keys():
keypoints = (sample['gt_keypoint'],
sample['keypoint_ignore'])
crop_bbox, crop_class, crop_score, gt_keypoints = \
filter_and_process(sample_bbox, gt_bbox, gt_class,
scores=gt_score,
keypoints=keypoints)
else:
crop_bbox, crop_class, crop_score = filter_and_process( crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, gt_score) sample_bbox, gt_bbox, gt_class, scores=gt_score)
crop_bbox, crop_class, crop_score = bbox_area_sampling( crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size, crop_bbox, crop_class, crop_score, self.target_size,
self.min_size) self.min_size)
...@@ -926,6 +1077,9 @@ class CropImageWithDataAchorSampling(BaseOperator): ...@@ -926,6 +1077,9 @@ class CropImageWithDataAchorSampling(BaseOperator):
sample['gt_bbox'] = crop_bbox sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class sample['gt_class'] = crop_class
sample['gt_score'] = crop_score sample['gt_score'] = crop_score
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = gt_keypoints[0]
sample['keypoint_ignore'] = gt_keypoints[1]
return sample return sample
return sample return sample
...@@ -947,8 +1101,16 @@ class CropImageWithDataAchorSampling(BaseOperator): ...@@ -947,8 +1101,16 @@ class CropImageWithDataAchorSampling(BaseOperator):
sample_bbox = sampled_bbox.pop(idx) sample_bbox = sampled_bbox.pop(idx)
sample_bbox = clip_bbox(sample_bbox) sample_bbox = clip_bbox(sample_bbox)
if 'gt_keypoint' in sample.keys():
keypoints = (sample['gt_keypoint'],
sample['keypoint_ignore'])
crop_bbox, crop_class, crop_score, gt_keypoints = \
filter_and_process(sample_bbox, gt_bbox, gt_class,
scores=gt_score,
keypoints=keypoints)
else:
crop_bbox, crop_class, crop_score = filter_and_process( crop_bbox, crop_class, crop_score = filter_and_process(
sample_bbox, gt_bbox, gt_class, gt_score) sample_bbox, gt_bbox, gt_class, scores=gt_score)
# sampling bbox according the bbox area # sampling bbox according the bbox area
crop_bbox, crop_class, crop_score = bbox_area_sampling( crop_bbox, crop_class, crop_score = bbox_area_sampling(
crop_bbox, crop_class, crop_score, self.target_size, crop_bbox, crop_class, crop_score, self.target_size,
...@@ -966,6 +1128,9 @@ class CropImageWithDataAchorSampling(BaseOperator): ...@@ -966,6 +1128,9 @@ class CropImageWithDataAchorSampling(BaseOperator):
sample['gt_bbox'] = crop_bbox sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class sample['gt_class'] = crop_class
sample['gt_score'] = crop_score sample['gt_score'] = crop_score
if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = gt_keypoints[0]
sample['keypoint_ignore'] = gt_keypoints[1]
return sample return sample
return sample return sample
...@@ -987,6 +1152,17 @@ class NormalizeBox(BaseOperator): ...@@ -987,6 +1152,17 @@ class NormalizeBox(BaseOperator):
gt_bbox[i][2] = gt_bbox[i][2] / width gt_bbox[i][2] = gt_bbox[i][2] / width
gt_bbox[i][3] = gt_bbox[i][3] / height gt_bbox[i][3] = gt_bbox[i][3] / height
sample['gt_bbox'] = gt_bbox sample['gt_bbox'] = gt_bbox
if 'gt_keypoint' in sample.keys():
gt_keypoint = sample['gt_keypoint']
for i in range(gt_keypoint.shape[1]):
if i % 2:
gt_keypoint[:, i] = gt_keypoint[:, i] / height
else:
gt_keypoint[:, i] = gt_keypoint[:, i] / width
sample['gt_keypoint'] = gt_keypoint
return sample return sample
...@@ -998,7 +1174,6 @@ class Permute(BaseOperator): ...@@ -998,7 +1174,6 @@ class Permute(BaseOperator):
Args: Args:
to_bgr (bool): confirm whether to convert RGB to BGR to_bgr (bool): confirm whether to convert RGB to BGR
channel_first (bool): confirm whether to change channel channel_first (bool): confirm whether to change channel
""" """
super(Permute, self).__init__() super(Permute, self).__init__()
self.to_bgr = to_bgr self.to_bgr = to_bgr
...@@ -1094,6 +1269,84 @@ class MixupImage(BaseOperator): ...@@ -1094,6 +1269,84 @@ class MixupImage(BaseOperator):
return sample return sample
@register_op
class CutmixImage(BaseOperator):
def __init__(self, alpha=1.5, beta=1.5):
"""
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features, see https://https://arxiv.org/abs/1905.04899
Cutmix image and gt_bbbox/gt_score
Args:
alpha (float): alpha parameter of beta distribute
beta (float): beta parameter of beta distribute
"""
super(CutmixImage, self).__init__()
self.alpha = alpha
self.beta = beta
if self.alpha <= 0.0:
raise ValueError("alpha shold be positive in {}".format(self))
if self.beta <= 0.0:
raise ValueError("beta shold be positive in {}".format(self))
def _rand_bbox(self, img1, img2, factor):
""" _rand_bbox """
h = max(img1.shape[0], img2.shape[0])
w = max(img1.shape[1], img2.shape[1])
cut_rat = np.sqrt(1. - factor)
cut_w = np.int(w * cut_rat)
cut_h = np.int(h * cut_rat)
# uniform
cx = np.random.randint(w)
cy = np.random.randint(h)
bbx1 = np.clip(cx - cut_w // 2, 0, w)
bby1 = np.clip(cy - cut_h // 2, 0, h)
bbx2 = np.clip(cx + cut_w // 2, 0, w)
bby2 = np.clip(cy + cut_h // 2, 0, h)
img_1 = np.zeros((h, w, img1.shape[2]), 'float32')
img_1[:img1.shape[0], :img1.shape[1], :] = \
img1.astype('float32')
img_2 = np.zeros((h, w, img2.shape[2]), 'float32')
img_2[:img2.shape[0], :img2.shape[1], :] = \
img2.astype('float32')
img_1[bby1:bby2, bbx1:bbx2, :] = img2[bby1:bby2, bbx1:bbx2, :]
return img_1
def __call__(self, sample, context=None):
if 'cutmix' not in sample:
return sample
factor = np.random.beta(self.alpha, self.beta)
factor = max(0.0, min(1.0, factor))
if factor >= 1.0:
sample.pop('cutmix')
return sample
if factor <= 0.0:
return sample['cutmix']
img1 = sample['image']
img2 = sample['cutmix']['image']
img = self._rand_bbox(img1, img2, factor)
gt_bbox1 = sample['gt_bbox']
gt_bbox2 = sample['cutmix']['gt_bbox']
gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
gt_class1 = sample['gt_class']
gt_class2 = sample['cutmix']['gt_class']
gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
gt_score1 = sample['gt_score']
gt_score2 = sample['cutmix']['gt_score']
gt_score = np.concatenate(
(gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
sample['image'] = img
sample['gt_bbox'] = gt_bbox
sample['gt_score'] = gt_score
sample['gt_class'] = gt_class
sample['h'] = img.shape[0]
sample['w'] = img.shape[1]
sample.pop('cutmix')
return sample
@register_op @register_op
class RandomInterpImage(BaseOperator): class RandomInterpImage(BaseOperator):
def __init__(self, target_size=0, max_size=0): def __init__(self, target_size=0, max_size=0):
...@@ -1129,7 +1382,6 @@ class RandomInterpImage(BaseOperator): ...@@ -1129,7 +1382,6 @@ class RandomInterpImage(BaseOperator):
@register_op @register_op
class Resize(BaseOperator): class Resize(BaseOperator):
"""Resize image and bbox. """Resize image and bbox.
Args: Args:
target_dim (int or list): target size, can be a single number or a list target_dim (int or list): target size, can be a single number or a list
(for random shape). (for random shape).
...@@ -1162,6 +1414,7 @@ class Resize(BaseOperator): ...@@ -1162,6 +1414,7 @@ class Resize(BaseOperator):
scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32) scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0, sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0,
dim - 1) dim - 1)
sample['scale_factor'] = [scale_x, scale_y] * 2
sample['h'] = resize_h sample['h'] = resize_h
sample['w'] = resize_w sample['w'] = resize_w
...@@ -1173,7 +1426,6 @@ class Resize(BaseOperator): ...@@ -1173,7 +1426,6 @@ class Resize(BaseOperator):
@register_op @register_op
class ColorDistort(BaseOperator): class ColorDistort(BaseOperator):
"""Random color distortion. """Random color distortion.
Args: Args:
hue (list): hue settings. hue (list): hue settings.
in [lower, upper, probability] format. in [lower, upper, probability] format.
...@@ -1185,6 +1437,8 @@ class ColorDistort(BaseOperator): ...@@ -1185,6 +1437,8 @@ class ColorDistort(BaseOperator):
in [lower, upper, probability] format. in [lower, upper, probability] format.
random_apply (bool): whether to apply in random (yolo) or fixed (SSD) random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
order. order.
hsv_format (bool): whether to convert color from BGR to HSV
random_channel (bool): whether to swap channels randomly
""" """
def __init__(self, def __init__(self,
...@@ -1192,13 +1446,17 @@ class ColorDistort(BaseOperator): ...@@ -1192,13 +1446,17 @@ class ColorDistort(BaseOperator):
saturation=[0.5, 1.5, 0.5], saturation=[0.5, 1.5, 0.5],
contrast=[0.5, 1.5, 0.5], contrast=[0.5, 1.5, 0.5],
brightness=[0.5, 1.5, 0.5], brightness=[0.5, 1.5, 0.5],
random_apply=True): random_apply=True,
hsv_format=False,
random_channel=False):
super(ColorDistort, self).__init__() super(ColorDistort, self).__init__()
self.hue = hue self.hue = hue
self.saturation = saturation self.saturation = saturation
self.contrast = contrast self.contrast = contrast
self.brightness = brightness self.brightness = brightness
self.random_apply = random_apply self.random_apply = random_apply
self.hsv_format = hsv_format
self.random_channel = random_channel
def apply_hue(self, img): def apply_hue(self, img):
low, high, prob = self.hue low, high, prob = self.hue
...@@ -1206,6 +1464,11 @@ class ColorDistort(BaseOperator): ...@@ -1206,6 +1464,11 @@ class ColorDistort(BaseOperator):
return img return img
img = img.astype(np.float32) img = img.astype(np.float32)
if self.hsv_format:
img[..., 0] += random.uniform(low, high)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
return img
# XXX works, but result differ from HSV version # XXX works, but result differ from HSV version
delta = np.random.uniform(low, high) delta = np.random.uniform(low, high)
...@@ -1225,8 +1488,10 @@ class ColorDistort(BaseOperator): ...@@ -1225,8 +1488,10 @@ class ColorDistort(BaseOperator):
if np.random.uniform(0., 1.) < prob: if np.random.uniform(0., 1.) < prob:
return img return img
delta = np.random.uniform(low, high) delta = np.random.uniform(low, high)
img = img.astype(np.float32) img = img.astype(np.float32)
if self.hsv_format:
img[..., 1] *= delta
return img
gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32) gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
gray = gray.sum(axis=2, keepdims=True) gray = gray.sum(axis=2, keepdims=True)
gray *= (1.0 - delta) gray *= (1.0 - delta)
...@@ -1273,12 +1538,24 @@ class ColorDistort(BaseOperator): ...@@ -1273,12 +1538,24 @@ class ColorDistort(BaseOperator):
if np.random.randint(0, 2): if np.random.randint(0, 2):
img = self.apply_contrast(img) img = self.apply_contrast(img)
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
img = self.apply_saturation(img) img = self.apply_saturation(img)
img = self.apply_hue(img) img = self.apply_hue(img)
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
else: else:
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
img = self.apply_saturation(img) img = self.apply_saturation(img)
img = self.apply_hue(img) img = self.apply_hue(img)
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
img = self.apply_contrast(img) img = self.apply_contrast(img)
if self.random_channel:
if np.random.randint(0, 2):
img = img[..., np.random.permutation(3)]
sample['image'] = img sample['image'] = img
return sample return sample
...@@ -1346,7 +1623,6 @@ class CornerRandColor(ColorDistort): ...@@ -1346,7 +1623,6 @@ class CornerRandColor(ColorDistort):
@register_op @register_op
class NormalizePermute(BaseOperator): class NormalizePermute(BaseOperator):
"""Normalize and permute channel order. """Normalize and permute channel order.
Args: Args:
mean (list): mean values in RGB order. mean (list): mean values in RGB order.
std (list): std values in RGB order. std (list): std values in RGB order.
...@@ -1376,7 +1652,6 @@ class NormalizePermute(BaseOperator): ...@@ -1376,7 +1652,6 @@ class NormalizePermute(BaseOperator):
@register_op @register_op
class RandomExpand(BaseOperator): class RandomExpand(BaseOperator):
"""Random expand the canvas. """Random expand the canvas.
Args: Args:
ratio (float): maximum expansion ratio. ratio (float): maximum expansion ratio.
prob (float): probability to expand. prob (float): probability to expand.
...@@ -1468,7 +1743,6 @@ class RandomExpand(BaseOperator): ...@@ -1468,7 +1743,6 @@ class RandomExpand(BaseOperator):
@register_op @register_op
class RandomCrop(BaseOperator): class RandomCrop(BaseOperator):
"""Random crop image and bboxes. """Random crop image and bboxes.
Args: Args:
aspect_ratio (list): aspect ratio of cropped region. aspect_ratio (list): aspect ratio of cropped region.
in [min, max] format. in [min, max] format.
...@@ -1595,11 +1869,23 @@ class RandomCrop(BaseOperator): ...@@ -1595,11 +1869,23 @@ class RandomCrop(BaseOperator):
found = False found = False
for i in range(self.num_attempts): for i in range(self.num_attempts):
scale = np.random.uniform(*self.scaling) scale = np.random.uniform(*self.scaling)
if self.aspect_ratio is not None:
min_ar, max_ar = self.aspect_ratio min_ar, max_ar = self.aspect_ratio
aspect_ratio = np.random.uniform( aspect_ratio = np.random.uniform(
max(min_ar, scale**2), min(max_ar, scale**-2)) max(min_ar, scale**2), min(max_ar, scale**-2))
crop_h = int(h * scale / np.sqrt(aspect_ratio)) h_scale = scale / np.sqrt(aspect_ratio)
crop_w = int(w * scale * np.sqrt(aspect_ratio)) w_scale = scale * np.sqrt(aspect_ratio)
else:
h_scale = np.random.uniform(*self.scaling)
w_scale = np.random.uniform(*self.scaling)
crop_h = h * h_scale
crop_w = w * w_scale
if self.aspect_ratio is None:
if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
continue
crop_h = int(crop_h)
crop_w = int(crop_w)
crop_y = np.random.randint(0, h - crop_h) crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w) crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
...@@ -1751,7 +2037,6 @@ class BboxXYXY2XYWH(BaseOperator): ...@@ -1751,7 +2037,6 @@ class BboxXYXY2XYWH(BaseOperator):
return sample return sample
@register_op
class Lighting(BaseOperator): class Lighting(BaseOperator):
""" """
Lighting the imagen by eigenvalues and eigenvectors Lighting the imagen by eigenvalues and eigenvectors
...@@ -1991,7 +2276,6 @@ class CornerRatio(BaseOperator): ...@@ -1991,7 +2276,6 @@ class CornerRatio(BaseOperator):
class RandomScaledCrop(BaseOperator): class RandomScaledCrop(BaseOperator):
"""Resize image and bbox based on long side (with optional random scaling), """Resize image and bbox based on long side (with optional random scaling),
then crop or pad image to target size. then crop or pad image to target size.
Args: Args:
target_dim (int): target size. target_dim (int): target size.
scale_range (list): random scale range. scale_range (list): random scale range.
...@@ -2046,7 +2330,6 @@ class RandomScaledCrop(BaseOperator): ...@@ -2046,7 +2330,6 @@ class RandomScaledCrop(BaseOperator):
@register_op @register_op
class ResizeAndPad(BaseOperator): class ResizeAndPad(BaseOperator):
"""Resize image and bbox, then pad image to target size. """Resize image and bbox, then pad image to target size.
Args: Args:
target_dim (int): target size target_dim (int): target size
interp (int): interpolation method, default to `cv2.INTER_LINEAR`. interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
...@@ -2085,7 +2368,6 @@ class ResizeAndPad(BaseOperator): ...@@ -2085,7 +2368,6 @@ class ResizeAndPad(BaseOperator):
@register_op @register_op
class TargetAssign(BaseOperator): class TargetAssign(BaseOperator):
"""Assign regression target and labels. """Assign regression target and labels.
Args: Args:
image_size (int or list): input image size, a single integer or list of image_size (int or list): input image size, a single integer or list of
[h, w]. Default: 512 [h, w]. Default: 512
...@@ -2184,3 +2466,69 @@ class TargetAssign(BaseOperator): ...@@ -2184,3 +2466,69 @@ class TargetAssign(BaseOperator):
targets[matched_indices] = matched_targets targets[matched_indices] = matched_targets
sample['fg_num'] = np.array(len(matched_targets), dtype=np.int32) sample['fg_num'] = np.array(len(matched_targets), dtype=np.int32)
return sample return sample
@register_op
class DebugVisibleImage(BaseOperator):
"""
In debug mode, visualize images according to `gt_box`.
(Currently only supported when not cropping and flipping image.)
"""
def __init__(self, output_dir='output/debug', is_normalized=False):
super(DebugVisibleImage, self).__init__()
self.is_normalized = is_normalized
self.output_dir = output_dir
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
if not isinstance(self.is_normalized, bool):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
image = Image.open(sample['im_file']).convert('RGB')
out_file_name = sample['im_file'].split('/')[-1]
width = sample['w']
height = sample['h']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
draw = ImageDraw.Draw(image)
for i in range(gt_bbox.shape[0]):
if self.is_normalized:
gt_bbox[i][0] = gt_bbox[i][0] * width
gt_bbox[i][1] = gt_bbox[i][1] * height
gt_bbox[i][2] = gt_bbox[i][2] * width
gt_bbox[i][3] = gt_bbox[i][3] * height
xmin, ymin, xmax, ymax = gt_bbox[i]
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
(xmin, ymin)],
width=2,
fill='green')
# draw label
text = str(gt_class[i][0])
tw, th = draw.textsize(text)
draw.rectangle(
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill='green')
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
if 'gt_keypoint' in sample.keys():
gt_keypoint = sample['gt_keypoint']
if self.is_normalized:
for i in range(gt_keypoint.shape[1]):
if i % 2:
gt_keypoint[:, i] = gt_keypoint[:, i] * height
else:
gt_keypoint[:, i] = gt_keypoint[:, i] * width
for i in range(gt_keypoint.shape[0]):
keypoint = gt_keypoint[i]
for j in range(int(keypoint.shape[0] / 2)):
x1 = round(keypoint[2 * j]).astype(np.int32)
y1 = round(keypoint[2 * j + 1]).astype(np.int32)
draw.ellipse(
(x1, y1, x1 + 5, y1i + 5),
fill='green',
outline='green')
save_path = os.path.join(self.output_dir, out_file_name)
image.save(save_path, quality=95)
return sample
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册