未验证 提交 897d86ac 编写于 作者: S sunxl1988 提交者: GitHub

test=dygraph sync reader from static ppdet (#1084)

sync reader from static ppdet
上级 ce71cdc2
...@@ -16,6 +16,7 @@ from __future__ import absolute_import ...@@ -16,6 +16,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import copy import copy
import functools import functools
import collections import collections
...@@ -167,6 +168,8 @@ class Reader(object): ...@@ -167,6 +168,8 @@ class Reader(object):
Default True. Default True.
mixup_epoch (int): mixup epoc number. Default is -1, meaning mixup_epoch (int): mixup epoc number. Default is -1, meaning
not use mixup. not use mixup.
cutmix_epoch (int): cutmix epoc number. Default is -1, meaning
not use cutmix.
class_aware_sampling (bool): whether use class-aware sampling or not. class_aware_sampling (bool): whether use class-aware sampling or not.
Default False. Default False.
worker_num (int): number of working threads/processes. worker_num (int): number of working threads/processes.
...@@ -191,6 +194,7 @@ class Reader(object): ...@@ -191,6 +194,7 @@ class Reader(object):
drop_last=False, drop_last=False,
drop_empty=True, drop_empty=True,
mixup_epoch=-1, mixup_epoch=-1,
cutmix_epoch=-1,
class_aware_sampling=False, class_aware_sampling=False,
worker_num=-1, worker_num=-1,
use_process=False, use_process=False,
...@@ -241,6 +245,7 @@ class Reader(object): ...@@ -241,6 +245,7 @@ class Reader(object):
# sampling # sampling
self._mixup_epoch = mixup_epoch self._mixup_epoch = mixup_epoch
self._cutmix_epoch = cutmix_epoch
self._class_aware_sampling = class_aware_sampling self._class_aware_sampling = class_aware_sampling
self._load_img = False self._load_img = False
...@@ -253,6 +258,8 @@ class Reader(object): ...@@ -253,6 +258,8 @@ class Reader(object):
self._pos = -1 self._pos = -1
self._epoch = -1 self._epoch = -1
self._curr_iter = 0
# multi-process # multi-process
self._worker_num = worker_num self._worker_num = worker_num
self._parallel = None self._parallel = None
...@@ -274,6 +281,11 @@ class Reader(object): ...@@ -274,6 +281,11 @@ class Reader(object):
def reset(self): def reset(self):
"""implementation of Dataset.reset """implementation of Dataset.reset
""" """
if self._epoch < 0:
self._epoch = 0
else:
self._epoch += 1
self.indexes = [i for i in range(self.size())] self.indexes = [i for i in range(self.size())]
if self._class_aware_sampling: if self._class_aware_sampling:
self.indexes = np.random.choice( self.indexes = np.random.choice(
...@@ -283,17 +295,18 @@ class Reader(object): ...@@ -283,17 +295,18 @@ class Reader(object):
p=self.img_weights) p=self.img_weights)
if self._shuffle: if self._shuffle:
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
np.random.seed(self._epoch + trainer_id)
np.random.shuffle(self.indexes) np.random.shuffle(self.indexes)
if self._mixup_epoch > 0 and len(self.indexes) < 2: if self._mixup_epoch > 0 and len(self.indexes) < 2:
logger.debug("Disable mixup for dataset samples " logger.debug("Disable mixup for dataset samples "
"less than 2 samples") "less than 2 samples")
self._mixup_epoch = -1 self._mixup_epoch = -1
if self._cutmix_epoch > 0 and len(self.indexes) < 2:
if self._epoch < 0: logger.info("Disable cutmix for dataset samples "
self._epoch = 0 "less than 2 samples")
else: self._cutmix_epoch = -1
self._epoch += 1
self._pos = 0 self._pos = 0
...@@ -306,6 +319,7 @@ class Reader(object): ...@@ -306,6 +319,7 @@ class Reader(object):
if self.drained(): if self.drained():
raise StopIteration raise StopIteration
batch = self._load_batch() batch = self._load_batch()
self._curr_iter += 1
if self._drop_last and len(batch) < self._batch_size: if self._drop_last and len(batch) < self._batch_size:
raise StopIteration raise StopIteration
if self._worker_num > -1: if self._worker_num > -1:
...@@ -321,6 +335,7 @@ class Reader(object): ...@@ -321,6 +335,7 @@ class Reader(object):
break break
pos = self.indexes[self._pos] pos = self.indexes[self._pos]
sample = copy.deepcopy(self._roidbs[pos]) sample = copy.deepcopy(self._roidbs[pos])
sample["curr_iter"] = self._curr_iter
self._pos += 1 self._pos += 1
if self._drop_empty and self._fields and 'gt_mask' in self._fields: if self._drop_empty and self._fields and 'gt_mask' in self._fields:
...@@ -343,9 +358,18 @@ class Reader(object): ...@@ -343,9 +358,18 @@ class Reader(object):
mix_idx = np.random.randint(1, num) mix_idx = np.random.randint(1, num)
mix_idx = self.indexes[(mix_idx + self._pos - 1) % num] mix_idx = self.indexes[(mix_idx + self._pos - 1) % num]
sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx]) sample['mixup'] = copy.deepcopy(self._roidbs[mix_idx])
sample['mixup']["curr_iter"] = self._curr_iter
if self._load_img: if self._load_img:
sample['mixup']['image'] = self._load_image(sample['mixup'][ sample['mixup']['image'] = self._load_image(sample['mixup'][
'im_file']) 'im_file'])
if self._epoch < self._cutmix_epoch:
num = len(self.indexes)
mix_idx = np.random.randint(1, num)
sample['cutmix'] = copy.deepcopy(self._roidbs[mix_idx])
sample['cutmix']["curr_iter"] = self._curr_iter
if self._load_img:
sample['cutmix']['image'] = self._load_image(sample[
'cutmix']['im_file'])
batch.append(sample) batch.append(sample)
bs += 1 bs += 1
......
...@@ -41,7 +41,8 @@ class WIDERFaceDataSet(DataSet): ...@@ -41,7 +41,8 @@ class WIDERFaceDataSet(DataSet):
image_dir=None, image_dir=None,
anno_path=None, anno_path=None,
sample_num=-1, sample_num=-1,
with_background=True): with_background=True,
with_lmk=False):
super(WIDERFaceDataSet, self).__init__( super(WIDERFaceDataSet, self).__init__(
image_dir=image_dir, image_dir=image_dir,
anno_path=anno_path, anno_path=anno_path,
...@@ -53,6 +54,7 @@ class WIDERFaceDataSet(DataSet): ...@@ -53,6 +54,7 @@ class WIDERFaceDataSet(DataSet):
self.with_background = with_background self.with_background = with_background
self.roidbs = None self.roidbs = None
self.cname2cid = None self.cname2cid = None
self.with_lmk = with_lmk
def load_roidb_and_cname2cid(self): def load_roidb_and_cname2cid(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path) anno_path = os.path.join(self.dataset_dir, self.anno_path)
...@@ -62,33 +64,23 @@ class WIDERFaceDataSet(DataSet): ...@@ -62,33 +64,23 @@ class WIDERFaceDataSet(DataSet):
records = [] records = []
ct = 0 ct = 0
file_lists = _load_file_list(txt_file) file_lists = self._load_file_list(txt_file)
cname2cid = widerface_label(self.with_background) cname2cid = widerface_label(self.with_background)
for item in file_lists: for item in file_lists:
im_fname = item[0] im_fname = item[0]
im_id = np.array([ct]) im_id = np.array([ct])
gt_bbox = np.zeros((len(item) - 2, 4), dtype=np.float32) gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32)
gt_class = np.ones((len(item) - 2, 1), dtype=np.int32) gt_class = np.ones((len(item) - 1, 1), dtype=np.int32)
gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32)
lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32)
for index_box in range(len(item)): for index_box in range(len(item)):
if index_box >= 2: if index_box < 1:
temp_info_box = item[index_box].split(' ') continue
xmin = float(temp_info_box[0]) gt_bbox[index_box - 1] = item[index_box][0]
ymin = float(temp_info_box[1]) if self.with_lmk:
w = float(temp_info_box[2]) gt_lmk_labels[index_box - 1] = item[index_box][1]
h = float(temp_info_box[3]) lmk_ignore_flag[index_box - 1] = item[index_box][2]
# Filter out wrong labels
if w < 0 or h < 0:
logger.warn('Illegal box with w: {}, h: {} in '
'img: {}, and it will be ignored'.format(
w, h, im_fname))
continue
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = xmin + w
ymax = ymin + h
gt_bbox[index_box - 2] = [xmin, ymin, xmax, ymax]
im_fname = os.path.join(image_dir, im_fname = os.path.join(image_dir,
im_fname) if image_dir else im_fname im_fname) if image_dir else im_fname
widerface_rec = { widerface_rec = {
...@@ -97,7 +89,10 @@ class WIDERFaceDataSet(DataSet): ...@@ -97,7 +89,10 @@ class WIDERFaceDataSet(DataSet):
'gt_bbox': gt_bbox, 'gt_bbox': gt_bbox,
'gt_class': gt_class, 'gt_class': gt_class,
} }
# logger.debug if self.with_lmk:
widerface_rec['gt_keypoint'] = gt_lmk_labels
widerface_rec['keypoint_ignore'] = lmk_ignore_flag
if len(item) != 0: if len(item) != 0:
records.append(widerface_rec) records.append(widerface_rec)
...@@ -108,34 +103,64 @@ class WIDERFaceDataSet(DataSet): ...@@ -108,34 +103,64 @@ class WIDERFaceDataSet(DataSet):
logger.debug('{} samples in file {}'.format(ct, anno_path)) logger.debug('{} samples in file {}'.format(ct, anno_path))
self.roidbs, self.cname2cid = records, cname2cid self.roidbs, self.cname2cid = records, cname2cid
def _load_file_list(self, input_txt):
def _load_file_list(input_txt): with open(input_txt, 'r') as f_dir:
with open(input_txt, 'r') as f_dir: lines_input_txt = f_dir.readlines()
lines_input_txt = f_dir.readlines()
file_dict = {}
file_dict = {} num_class = 0
num_class = 0 for i in range(len(lines_input_txt)):
for i in range(len(lines_input_txt)): line_txt = lines_input_txt[i].strip('\n\t\r')
line_txt = lines_input_txt[i].strip('\n\t\r') if '.jpg' in line_txt:
if '.jpg' in line_txt: if i != 0:
if i != 0: num_class += 1
num_class += 1 file_dict[num_class] = []
file_dict[num_class] = []
file_dict[num_class].append(line_txt)
if '.jpg' not in line_txt:
if len(line_txt) > 6:
split_str = line_txt.split(' ')
x1_min = float(split_str[0])
y1_min = float(split_str[1])
x2_max = float(split_str[2])
y2_max = float(split_str[3])
line_txt = str(x1_min) + ' ' + str(y1_min) + ' ' + str(
x2_max) + ' ' + str(y2_max)
file_dict[num_class].append(line_txt) file_dict[num_class].append(line_txt)
else: if '.jpg' not in line_txt:
file_dict[num_class].append(line_txt) if len(line_txt) <= 6:
continue
return list(file_dict.values()) result_boxs = []
split_str = line_txt.split(' ')
xmin = float(split_str[0])
ymin = float(split_str[1])
w = float(split_str[2])
h = float(split_str[3])
# Filter out wrong labels
if w < 0 or h < 0:
logger.warn('Illegal box with w: {}, h: {} in '
'img: {}, and it will be ignored'.format(
w, h, file_dict[num_class][0]))
continue
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = xmin + w
ymax = ymin + h
gt_bbox = [xmin, ymin, xmax, ymax]
result_boxs.append(gt_bbox)
if self.with_lmk:
assert len(split_str) > 18, 'When `with_lmk=True`, the number' \
'of characters per line in the annotation file should' \
'exceed 18.'
lmk0_x = float(split_str[5])
lmk0_y = float(split_str[6])
lmk1_x = float(split_str[8])
lmk1_y = float(split_str[9])
lmk2_x = float(split_str[11])
lmk2_y = float(split_str[12])
lmk3_x = float(split_str[14])
lmk3_y = float(split_str[15])
lmk4_x = float(split_str[17])
lmk4_y = float(split_str[18])
lmk_ignore_flag = 0 if lmk0_x == -1 else 1
gt_lmk_label = [
lmk0_x, lmk0_y, lmk1_x, lmk1_y, lmk2_x, lmk2_y, lmk3_x,
lmk3_y, lmk4_x, lmk4_y
]
result_boxs.append(gt_lmk_label)
result_boxs.append(lmk_ignore_flag)
file_dict[num_class].append(result_boxs)
return list(file_dict.values())
def widerface_label(with_background=True): def widerface_label(with_background=True):
......
...@@ -26,13 +26,17 @@ import cv2 ...@@ -26,13 +26,17 @@ import cv2
import numpy as np import numpy as np
from .operators import register_op, BaseOperator from .operators import register_op, BaseOperator
from .op_helper import jaccard_overlap from .op_helper import jaccard_overlap, gaussian2D
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
'PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget', 'PadBatch',
'Gt2FCOSTarget' 'RandomShape',
'PadMultiScaleTest',
'Gt2YoloTarget',
'Gt2FCOSTarget',
'Gt2TTFTarget',
] ]
...@@ -41,17 +45,15 @@ class PadBatch(BaseOperator): ...@@ -41,17 +45,15 @@ class PadBatch(BaseOperator):
""" """
Pad a batch of samples so they can be divisible by a stride. Pad a batch of samples so they can be divisible by a stride.
The layout of each image should be 'CHW'. The layout of each image should be 'CHW'.
Args: Args:
pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
height and width is divisible by `pad_to_stride`. height and width is divisible by `pad_to_stride`.
""" """
def __init__(self, pad_to_stride=0, use_padded_im_info=True, pad_gt=False): def __init__(self, pad_to_stride=0, use_padded_im_info=True):
super(PadBatch, self).__init__() super(PadBatch, self).__init__()
self.pad_to_stride = pad_to_stride self.pad_to_stride = pad_to_stride
self.use_padded_im_info = use_padded_im_info self.use_padded_im_info = use_padded_im_info
self.pad_gt = pad_gt
def __call__(self, samples, context=None): def __call__(self, samples, context=None):
""" """
...@@ -61,9 +63,9 @@ class PadBatch(BaseOperator): ...@@ -61,9 +63,9 @@ class PadBatch(BaseOperator):
coarsest_stride = self.pad_to_stride coarsest_stride = self.pad_to_stride
if coarsest_stride == 0: if coarsest_stride == 0:
return samples return samples
max_shape = np.array([data['image'].shape for data in samples]).max( max_shape = np.array([data['image'].shape for data in samples]).max(
axis=0) axis=0)
if coarsest_stride > 0: if coarsest_stride > 0:
max_shape[1] = int( max_shape[1] = int(
np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride) np.ceil(max_shape[1] / coarsest_stride) * coarsest_stride)
...@@ -80,52 +82,6 @@ class PadBatch(BaseOperator): ...@@ -80,52 +82,6 @@ class PadBatch(BaseOperator):
data['image'] = padding_im data['image'] = padding_im
if self.use_padded_im_info: if self.use_padded_im_info:
data['im_info'][:2] = max_shape[1:3] data['im_info'][:2] = max_shape[1:3]
if self.pad_gt:
gt_num = []
if data['gt_poly'] is not None and len(data['gt_poly']) > 0:
pad_mask = True
else:
pad_mask = False
if pad_mask:
poly_num = []
poly_part_num = []
point_num = []
for data in samples:
gt_num.append(data['gt_bbox'].shape[0])
if pad_mask:
poly_num.append(len(data['gt_poly']))
for poly in data['gt_poly']:
poly_part_num.append(int(len(poly)))
for p_p in poly:
point_num.append(int(len(p_p) / 2))
gt_num_max = max(gt_num)
gt_box_data = np.zeros([gt_num_max, 4])
gt_class_data = np.zeros([gt_num_max])
is_crowd_data = np.ones([gt_num_max])
if pad_mask:
poly_num_max = max(poly_num)
poly_part_num_max = max(poly_part_num)
point_num_max = max(point_num)
gt_masks_data = -np.ones(
[poly_num_max, poly_part_num_max, point_num_max, 2])
for i, data in enumerate(samples):
gt_num = data['gt_bbox'].shape[0]
gt_box_data[0:gt_num, :] = data['gt_bbox']
gt_class_data[0:gt_num] = np.squeeze(data['gt_class'])
is_crowd_data[0:gt_num] = np.squeeze(data['is_crowd'])
if pad_mask:
for j, poly in enumerate(data['gt_poly']):
for k, p_p in enumerate(poly):
pp_np = np.array(p_p).reshape(-1, 2)
gt_masks_data[j, k, :pp_np.shape[0], :] = pp_np
data['gt_poly'] = gt_masks_data
data['gt_bbox'] = gt_box_data
data['gt_class'] = gt_class_data
data['is_crowd_data'] = is_crowd_data
return samples return samples
...@@ -136,13 +92,12 @@ class RandomShape(BaseOperator): ...@@ -136,13 +92,12 @@ class RandomShape(BaseOperator):
select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR, select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR,
cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is
False, use cv2.INTER_NEAREST. False, use cv2.INTER_NEAREST.
Args: Args:
sizes (list): list of int, random choose a size from these sizes (list): list of int, random choose a size from these
random_inter (bool): whether to randomly interpolation, defalut true. random_inter (bool): whether to randomly interpolation, defalut true.
""" """
def __init__(self, sizes=[], random_inter=False): def __init__(self, sizes=[], random_inter=False, resize_box=False):
super(RandomShape, self).__init__() super(RandomShape, self).__init__()
self.sizes = sizes self.sizes = sizes
self.random_inter = random_inter self.random_inter = random_inter
...@@ -153,6 +108,7 @@ class RandomShape(BaseOperator): ...@@ -153,6 +108,7 @@ class RandomShape(BaseOperator):
cv2.INTER_CUBIC, cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4, cv2.INTER_LANCZOS4,
] if random_inter else [] ] if random_inter else []
self.resize_box = resize_box
def __call__(self, samples, context=None): def __call__(self, samples, context=None):
shape = np.random.choice(self.sizes) shape = np.random.choice(self.sizes)
...@@ -166,6 +122,12 @@ class RandomShape(BaseOperator): ...@@ -166,6 +122,12 @@ class RandomShape(BaseOperator):
im = cv2.resize( im = cv2.resize(
im, None, None, fx=scale_x, fy=scale_y, interpolation=method) im, None, None, fx=scale_x, fy=scale_y, interpolation=method)
samples[i]['image'] = im samples[i]['image'] = im
if self.resize_box and 'gt_bbox' in samples[i] and len(samples[0][
'gt_bbox']) > 0:
scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
samples[i]['gt_bbox'] = np.clip(samples[i]['gt_bbox'] *
scale_array, 0,
float(shape) - 1)
return samples return samples
...@@ -525,3 +487,99 @@ class Gt2FCOSTarget(BaseOperator): ...@@ -525,3 +487,99 @@ class Gt2FCOSTarget(BaseOperator):
sample['centerness{}'.format(lvl)] = np.reshape( sample['centerness{}'.format(lvl)] = np.reshape(
ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1]) ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
return samples return samples
@register_op
class Gt2TTFTarget(BaseOperator):
"""
Gt2TTFTarget
Generate TTFNet targets by ground truth data
Args:
num_classes(int): the number of classes.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
alpha(float): the alpha parameter to generate gaussian target.
0.54 by default.
"""
def __init__(self, num_classes, down_ratio=4, alpha=0.54):
super(Gt2TTFTarget, self).__init__()
self.down_ratio = down_ratio
self.num_classes = num_classes
self.alpha = alpha
def __call__(self, samples, context=None):
output_size = samples[0]['image'].shape[1]
feat_size = output_size // self.down_ratio
for sample in samples:
heatmap = np.zeros(
(self.num_classes, feat_size, feat_size), dtype='float32')
box_target = np.ones(
(4, feat_size, feat_size), dtype='float32') * -1
reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
area = bbox_w * bbox_h
boxes_areas_log = np.log(area)
boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
boxes_area_topk_log = boxes_areas_log[boxes_ind]
gt_bbox = gt_bbox[boxes_ind]
gt_class = gt_class[boxes_ind]
feat_gt_bbox = gt_bbox / self.down_ratio
feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])
ct_inds = np.stack(
[(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
(gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
axis=1) / self.down_ratio
h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')
for k in range(len(gt_bbox)):
cls_id = gt_class[k]
fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32')
self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
h_radiuses_alpha[k],
w_radiuses_alpha[k])
heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
box_target_inds = fake_heatmap > 0
box_target[:, box_target_inds] = gt_bbox[k][:, None]
local_heatmap = fake_heatmap[box_target_inds]
ct_div = np.sum(local_heatmap)
local_heatmap *= boxes_area_topk_log[k]
reg_weight[0, box_target_inds] = local_heatmap / ct_div
sample['ttf_heatmap'] = heatmap
sample['ttf_box_target'] = box_target
sample['ttf_reg_weight'] = reg_weight
return samples
def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
h, w = 2 * h_radius + 1, 2 * w_radius + 1
sigma_x = w / 6
sigma_y = h / 6
gaussian = gaussian2D((h, w), sigma_x, sigma_y)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, w_radius), min(width - x, w_radius + 1)
top, bottom = min(y, h_radius), min(height - y, h_radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius -
left:w_radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
masked_heatmap, masked_gaussian)
return heatmap
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import numpy as np
from PIL import Image
class GridMask(object):
def __init__(self,
use_h=True,
use_w=True,
rotate=1,
offset=False,
ratio=0.5,
mode=1,
prob=0.7,
upper_iter=360000):
super(GridMask, self).__init__()
self.use_h = use_h
self.use_w = use_w
self.rotate = rotate
self.offset = offset
self.ratio = ratio
self.mode = mode
self.prob = prob
self.st_prob = prob
self.upper_iter = upper_iter
def __call__(self, x, curr_iter):
self.prob = self.st_prob * min(1, 1.0 * curr_iter / self.upper_iter)
if np.random.rand() > self.prob:
return x
_, h, w = x.shape
hh = int(1.5 * h)
ww = int(1.5 * w)
d = np.random.randint(2, h)
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
mask = np.ones((hh, ww), np.float32)
st_h = np.random.randint(d)
st_w = np.random.randint(d)
if self.use_h:
for i in range(hh // d):
s = d * i + st_h
t = min(s + self.l, hh)
mask[s:t, :] *= 0
if self.use_w:
for i in range(ww // d):
s = d * i + st_w
t = min(s + self.l, ww)
mask[:, s:t] *= 0
r = np.random.randint(self.rotate)
mask = Image.fromarray(np.uint8(mask))
mask = mask.rotate(r)
mask = np.asarray(mask)
mask = mask[(hh - h) // 2:(hh - h) // 2 + h, (ww - w) // 2:(ww - w) // 2
+ w].astype(np.float32)
if self.mode == 1:
mask = 1 - mask
mask = np.expand_dims(mask, axis=0)
if self.offset:
offset = (2 * (np.random.rand(h, w) - 0.5)).astype(np.float32)
x = (x * mask + offset * (1 - mask)).astype(x.dtype)
else:
x = (x * mask).astype(x.dtype)
return x
...@@ -61,10 +61,13 @@ def is_overlap(object_bbox, sample_bbox): ...@@ -61,10 +61,13 @@ def is_overlap(object_bbox, sample_bbox):
return True return True
def filter_and_process(sample_bbox, bboxes, labels, scores=None): def filter_and_process(sample_bbox, bboxes, labels, scores=None,
keypoints=None):
new_bboxes = [] new_bboxes = []
new_labels = [] new_labels = []
new_scores = [] new_scores = []
new_keypoints = []
new_kp_ignore = []
for i in range(len(bboxes)): for i in range(len(bboxes)):
new_bbox = [0, 0, 0, 0] new_bbox = [0, 0, 0, 0]
obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]] obj_bbox = [bboxes[i][0], bboxes[i][1], bboxes[i][2], bboxes[i][3]]
...@@ -84,9 +87,24 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None): ...@@ -84,9 +87,24 @@ def filter_and_process(sample_bbox, bboxes, labels, scores=None):
new_labels.append([labels[i][0]]) new_labels.append([labels[i][0]])
if scores is not None: if scores is not None:
new_scores.append([scores[i][0]]) new_scores.append([scores[i][0]])
if keypoints is not None:
sample_keypoint = keypoints[0][i]
for j in range(len(sample_keypoint)):
kp_len = sample_height if j % 2 else sample_width
sample_coord = sample_bbox[1] if j % 2 else sample_bbox[0]
sample_keypoint[j] = (
sample_keypoint[j] - sample_coord) / kp_len
sample_keypoint[j] = max(min(sample_keypoint[j], 1.0), 0.0)
new_keypoints.append(sample_keypoint)
new_kp_ignore.append(keypoints[1][i])
bboxes = np.array(new_bboxes) bboxes = np.array(new_bboxes)
labels = np.array(new_labels) labels = np.array(new_labels)
scores = np.array(new_scores) scores = np.array(new_scores)
if keypoints is not None:
keypoints = np.array(new_keypoints)
new_kp_ignore = np.array(new_kp_ignore)
return bboxes, labels, scores, (keypoints, new_kp_ignore)
return bboxes, labels, scores return bboxes, labels, scores
...@@ -420,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap): ...@@ -420,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap):
def draw_gaussian(heatmap, center, radius, k=1, delte=6): def draw_gaussian(heatmap, center, radius, k=1, delte=6):
diameter = 2 * radius + 1 diameter = 2 * radius + 1
gaussian = gaussian2D((diameter, diameter), sigma=diameter / delte) sigma = diameter / delte
gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma)
x, y = center x, y = center
...@@ -435,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6): ...@@ -435,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6):
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
def gaussian2D(shape, sigma=1): def gaussian2D(shape, sigma_x=1, sigma_y=1):
m, n = [(ss - 1.) / 2. for ss in shape] m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1] y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
sigma_y)))
h[h < np.finfo(h.dtype).eps * h.max()] = 0 h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h return h
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册