From d43e6d9a95d6c879b55de1d8a7eb3a582cdd9650 Mon Sep 17 00:00:00 2001
From: wangguanzhong
Date: Mon, 13 Jul 2020 19:07:22 +0800
Subject: [PATCH] add ttfnet (#1054)

---
 configs/anchor_free/README.md            |  20 +-
 configs/anchor_free/ttfnet_darknet.yml   | 141 +++++++++
 deploy/python/infer.py                   |  19 +-
 ppdet/data/transform/batch_operators.py  | 117 ++++++-
 ppdet/data/transform/op_helper.py        |   8 +-
 ppdet/data/transform/operators.py        |  65 ++--
 ppdet/modeling/anchor_heads/__init__.py  |   2 +
 ppdet/modeling/anchor_heads/ttf_head.py  | 383 +++++++++++++++++++++++
 ppdet/modeling/architectures/__init__.py |   2 +
 ppdet/modeling/architectures/ttfnet.py   | 132 ++++++++
 ppdet/modeling/backbones/darknet.py      |   6 +-
 ppdet/modeling/losses/giou_loss.py       |  42 ++-
 ppdet/modeling/ops.py                    |  63 ++--
 ppdet/utils/coco_eval.py                 |   1 +
 ppdet/utils/eval_utils.py                |   2 +
 tools/export_model.py                    |   3 +
 16 files changed, 932 insertions(+), 74 deletions(-)
 create mode 100644 configs/anchor_free/ttfnet_darknet.yml
 create mode 100644 ppdet/modeling/anchor_heads/ttf_head.py
 create mode 100644 ppdet/modeling/architectures/ttfnet.py

diff --git a/configs/anchor_free/README.md b/configs/anchor_free/README.md
index 4d2b39f80..a1f6994a1 100644
--- a/configs/anchor_free/README.md
+++ b/configs/anchor_free/README.md
@@ -12,10 +12,12 @@
 ## Model Zoo and Baselines
 The table below lists the network architectures currently supported in PaddleDetection; see [Algorithm Details](#算法细节) for specifics.
 
-| | ResNet50 | ResNet50-vd | Hourglass104 |
-|:------------------------:|:--------:|:--------------------------:|:------------------------:|
-| [CornerNet-Squeeze](#CornerNet-Squeeze) | x | ✓ | ✓ |
-| [FCOS](#FCOS) | ✓ | x | x |
+| | ResNet50 | ResNet50-vd | Hourglass104 | DarkNet53 |
+|:------------------------:|:--------:|:-------------:|:-------------:|:-------------:|
+| [CornerNet-Squeeze](#CornerNet-Squeeze) | x | ✓ | ✓ | x |
+| [FCOS](#FCOS) | ✓ | x | x | x |
+| [TTFNet](#TTFNet) | x | x | x | ✓ |
+
 
 ### Model Zoo
 
 | FCOS | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 39.8 | 18.85 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_1x.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_1x.yml) |
 | FCOS+multiscale_train | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 42.0 | 19.05 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_multiscale_2x.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_multiscale_2x.yml) |
 | FCOS+DCN | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 44.4 | 13.66 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/fcos_dcn_r50_fpn_1x.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_dcn_r50_fpn_1x.yml) |
+| TTFNet | DarkNet53 | 12 | [DarkNet53_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar) | 32.9 | 85.92 | [Download](https://paddlemodels.bj.bcebos.com/object_detection/ttfnet_darknet.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/ttfnet_darknet.yml) |
 
 **Notes:**
 
@@ -64,5 +67,14 @@
 
 - A single center-ness branch predicts whether the current point is an object center, which suppresses low-quality false detections
 
+## TTFNet
+
+**Introduction:** [TTFNet](https://arxiv.org/abs/1909.00700) is a training-time-friendly network for real-time object detection. It addresses the slow convergence of CenterNet by proposing a new way of generating training samples with Gaussian kernels, which effectively removes the ambiguity found in anchor-free heads. Its simple, lightweight structure also makes it easy to extend to other tasks.
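To make the Gaussian-kernel idea above concrete, here is a minimal NumPy sketch by the editor, not code from this patch; `alpha` and the `sigma = diameter / 6` choice mirror `Gt2TTFTarget` in `ppdet/data/transform/batch_operators.py` further down:

```python
import numpy as np

def gaussian2d(h, w, sigma_x, sigma_y):
    # elliptical Gaussian on an (h, w) window centered at the origin
    y, x = np.ogrid[-(h // 2):h // 2 + 1, -(w // 2):w // 2 + 1]
    return np.exp(-(x * x / (2 * sigma_x ** 2) + y * y / (2 * sigma_y ** 2)))

heatmap = np.zeros((128, 128), dtype=np.float32)  # one class channel of the target
cx, cy, bw, bh = 64, 64, 40, 24                   # box center and size on the feature map
alpha = 0.54                                      # shrink factor, as in Gt2TTFTarget
h_r, w_r = int(bh / 2. * alpha), int(bw / 2. * alpha)
g = gaussian2d(2 * h_r + 1, 2 * w_r + 1,
               sigma_x=(2 * w_r + 1) / 6., sigma_y=(2 * h_r + 1) / 6.)
# every pixel under the kernel becomes a positive training sample, weighted by
# its Gaussian value, instead of a single (and possibly ambiguous) center point
heatmap[cy - h_r:cy + h_r + 1, cx - w_r:cx + w_r + 1] = np.maximum(
    heatmap[cy - h_r:cy + h_r + 1, cx - w_r:cx + w_r + 1], g)
```

The same pixels also receive box-regression targets; their weights (the normalized Gaussian values scaled by the log of the box area) are what `ttf_reg_weight` stores in the batch transform below.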
+
+**Features:**
+
+- Simple structure: only two heads are needed, for object location and size, and the time-consuming post-processing steps are removed
+- Short training time: with a DarkNet53 backbone, about 2 hours of training on 8 V100 GPUs is enough to reach good accuracy
+
 ## Contributing
 Contributions to the anchor-free detection models in PaddleDetection are very welcome; feel free to submit a PR for us to review. Feedback is also appreciated: file an issue and we will respond promptly.
diff --git a/configs/anchor_free/ttfnet_darknet.yml b/configs/anchor_free/ttfnet_darknet.yml
new file mode 100644
index 000000000..b3ec9aa03
--- /dev/null
+++ b/configs/anchor_free/ttfnet_darknet.yml
@@ -0,0 +1,141 @@
+architecture: TTFNet
+use_gpu: true
+max_iters: 15000
+log_smooth_window: 20
+save_dir: output
+snapshot_iter: 1000
+metric: COCO
+pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
+weights: output/ttfnet_darknet/model_final
+num_classes: 80
+
+TTFNet:
+  backbone: DarkNet
+  ttf_head: TTFHead
+
+DarkNet:
+  norm_type: bn
+  norm_decay: 0.0004
+  depth: 53
+  freeze_at: 1
+
+TTFHead:
+  head_conv: 128
+  wh_conv: 64
+  hm_head_conv_num: 2
+  wh_head_conv_num: 2
+  wh_offset_base: 16
+  wh_loss: GiouLoss
+
+GiouLoss:
+  loss_weight: 5.
+  do_average: false
+  use_class_weight: false
+
+LearningRate:
+  base_lr: 0.015
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 0.1
+    milestones:
+    - 11250
+    - 13750
+  - !LinearWarmup
+    start_factor: 0.2
+    steps: 500
+
+OptimizerBuilder:
+  optimizer:
+    momentum: 0.9
+    type: Momentum
+  regularizer:
+    factor: 0.0004
+    type: L2
+
+TrainReader:
+  inputs_def:
+    fields: ['image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight']
+  dataset:
+    !COCODataSet
+    image_dir: train2017
+    anno_path: annotations/instances_train2017.json
+    dataset_dir: dataset/coco
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: true
+  - !Resize
+    target_dim: 512
+  - !RandomFlipImage
+    prob: 0.5
+  - !NormalizeImage
+    is_channel_first: false
+    is_scale: false
+    mean: [123.675, 116.28, 103.53]
+    std: [58.395, 57.12, 57.375]
+  - !Permute
+    to_bgr: false
+    channel_first: true
+  batch_transforms:
+  - !Gt2TTFTarget
+    num_classes: 80
+    down_ratio: 4
+  - !PadBatch
+    pad_to_stride: 32
+  batch_size: 12
+  shuffle: true
+  worker_num: 8
+  bufsize: 2
+  use_process: true
+
+EvalReader:
+  inputs_def:
+    image_shape: [3, 512, 512]
+    fields: ['image', 'im_id', 'scale_factor']
+  dataset:
+    !COCODataSet
+    image_dir: val2017
+    anno_path: annotations/instances_val2017.json
+    dataset_dir: dataset/coco
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: True
+  - !Resize
+    target_dim: 512
+  - !NormalizeImage
+    mean: [123.675, 116.28, 103.53]
+    std: [58.395, 57.12, 57.375]
+    is_scale: false
+    is_channel_first: false
+  - !Permute
+    to_bgr: false
+    channel_first: True
+  batch_size: 1
+  drop_empty: false
+  worker_num: 8
+  bufsize: 16
+
+TestReader:
+  inputs_def:
+    image_shape: [3, 512, 512]
+    fields: ['image', 'im_id', 'scale_factor']
+  dataset:
+    !ImageFolder
+    anno_path: annotations/instances_val2017.json
+    with_background: false
+  sample_transforms:
+  - !DecodeImage
+    to_rgb: True
+  - !Resize
+    interp: 1
+    target_dim: 512
+  - !NormalizeImage
+    mean: [123.675, 116.28, 103.53]
+    std: [58.395, 57.12, 57.375]
+    is_scale: false
+    is_channel_first: false
+  - !Permute
+    to_bgr: false
+    channel_first: True
+  batch_size: 1
diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 091fb724d..18876d1c5 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -115,8 +115,7 @@ class Resize(object):
             padding_im[:im_h, :im_w, :] = im
             im = padding_im
-        if self.arch in self.scale_set:
-            im_info['scale'] = im_scale_x
+        im_info['scale'] = [im_scale_x, im_scale_y]
         im_info['resize_shape'] = 
im.shape[:2] return im, im_info @@ -252,18 +251,23 @@ def create_inputs(im, im_info, model_arch='YOLO'): inputs['image'] = im origin_shape = list(im_info['origin_shape']) resize_shape = list(im_info['resize_shape']) - scale = im_info['scale'] + scale_x, scale_y = im_info['scale'] if 'YOLO' in model_arch: im_size = np.array([origin_shape]).astype('int32') inputs['im_size'] = im_size elif 'RetinaNet' in model_arch: + scale = scale_x im_info = np.array([resize_shape + [scale]]).astype('float32') inputs['im_info'] = im_info elif 'RCNN' in model_arch: + scale = scale_x im_info = np.array([resize_shape + [scale]]).astype('float32') im_shape = np.array([origin_shape + [1.]]).astype('float32') inputs['im_info'] = im_info inputs['im_shape'] = im_shape + elif 'TTF' in model_arch: + scale_factor = np.array([scale_x, scale_y] * 2).astype('float32') + inputs['scale_factor'] = scale_factor return inputs @@ -272,7 +276,7 @@ class Config(): Args: model_dir (str): root path of model.yml """ - support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face'] + support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face', 'TTF'] def __init__(self, model_dir): # parsing Yaml config for Preprocess @@ -298,9 +302,8 @@ class Config(): for support_model in self.support_models: if support_model in yml_conf['arch']: return True - raise ValueError( - "Unsupported arch: {}, expect SSD, YOLO, RetinaNet, RCNN and Face". - format(yml_conf['arch'])) + raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[ + 'arch'], self.support_models)) def print_config(self): print('----------- Model Configuration -----------') @@ -450,7 +453,7 @@ class Detector(): np_boxes[:, 3] *= w np_boxes[:, 4] *= h np_boxes[:, 5] *= w - expect_boxes = np_boxes[:, 1] > threshold + expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1) np_boxes = np_boxes[expect_boxes, :] for box in np_boxes: print('class_id:{:d}, confidence:{:.2f},' diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 8da0b1e35..1bed5edaf 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -26,13 +26,17 @@ import cv2 import numpy as np from .operators import register_op, BaseOperator -from .op_helper import jaccard_overlap +from .op_helper import jaccard_overlap, gaussian2D logger = logging.getLogger(__name__) __all__ = [ - 'PadBatch', 'RandomShape', 'PadMultiScaleTest', 'Gt2YoloTarget', - 'Gt2FCOSTarget' + 'PadBatch', + 'RandomShape', + 'PadMultiScaleTest', + 'Gt2YoloTarget', + 'Gt2FCOSTarget', + 'Gt2TTFTarget', ] @@ -41,7 +45,6 @@ class PadBatch(BaseOperator): """ Pad a batch of samples so they can be divisible by a stride. The layout of each image should be 'CHW'. - Args: pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure height and width is divisible by `pad_to_stride`. @@ -89,13 +92,12 @@ class RandomShape(BaseOperator): select one an interpolation algorithm [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]. If random_inter is False, use cv2.INTER_NEAREST. - Args: sizes (list): list of int, random choose a size from these random_inter (bool): whether to randomly interpolation, defalut true. 
""" - def __init__(self, sizes=[], random_inter=False): + def __init__(self, sizes=[], random_inter=False, resize_box=False): super(RandomShape, self).__init__() self.sizes = sizes self.random_inter = random_inter @@ -106,6 +108,7 @@ class RandomShape(BaseOperator): cv2.INTER_CUBIC, cv2.INTER_LANCZOS4, ] if random_inter else [] + self.resize_box = resize_box def __call__(self, samples, context=None): shape = np.random.choice(self.sizes) @@ -119,6 +122,12 @@ class RandomShape(BaseOperator): im = cv2.resize( im, None, None, fx=scale_x, fy=scale_y, interpolation=method) samples[i]['image'] = im + if self.resize_box and 'gt_bbox' in samples[i] and len(samples[0][ + 'gt_bbox']) > 0: + scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32) + samples[i]['gt_bbox'] = np.clip(samples[i]['gt_bbox'] * + scale_array, 0, + float(shape) - 1) return samples @@ -478,3 +487,99 @@ class Gt2FCOSTarget(BaseOperator): sample['centerness{}'.format(lvl)] = np.reshape( ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1]) return samples + + +@register_op +class Gt2TTFTarget(BaseOperator): + """ + Gt2TTFTarget + Generate TTFNet targets by ground truth data + + Args: + num_classes(int): the number of classes. + down_ratio(int): the down ratio from images to heatmap, 4 by default. + alpha(float): the alpha parameter to generate gaussian target. + 0.54 by default. + """ + + def __init__(self, num_classes, down_ratio=4, alpha=0.54): + super(Gt2TTFTarget, self).__init__() + self.down_ratio = down_ratio + self.num_classes = num_classes + self.alpha = alpha + + def __call__(self, samples, context=None): + output_size = samples[0]['image'].shape[1] + feat_size = output_size // self.down_ratio + for sample in samples: + heatmap = np.zeros( + (self.num_classes, feat_size, feat_size), dtype='float32') + box_target = np.ones( + (4, feat_size, feat_size), dtype='float32') * -1 + reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32') + + gt_bbox = sample['gt_bbox'] + gt_class = sample['gt_class'] + + bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1 + bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1 + area = bbox_w * bbox_h + boxes_areas_log = np.log(area) + boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1] + boxes_area_topk_log = boxes_areas_log[boxes_ind] + gt_bbox = gt_bbox[boxes_ind] + gt_class = gt_class[boxes_ind] + + feat_gt_bbox = gt_bbox / self.down_ratio + feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1) + feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1], + feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0]) + + ct_inds = np.stack( + [(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2, + (gt_bbox[:, 1] + gt_bbox[:, 3]) / 2], + axis=1) / self.down_ratio + + h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32') + w_radiuses_alpha = (feat_ws / 2. 
* self.alpha).astype('int32') + + for k in range(len(gt_bbox)): + cls_id = gt_class[k] + fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32') + self.draw_truncate_gaussian(fake_heatmap, ct_inds[k], + h_radiuses_alpha[k], + w_radiuses_alpha[k]) + + heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap) + box_target_inds = fake_heatmap > 0 + box_target[:, box_target_inds] = gt_bbox[k][:, None] + + local_heatmap = fake_heatmap[box_target_inds] + ct_div = np.sum(local_heatmap) + local_heatmap *= boxes_area_topk_log[k] + reg_weight[0, box_target_inds] = local_heatmap / ct_div + sample['ttf_heatmap'] = heatmap + sample['ttf_box_target'] = box_target + sample['ttf_reg_weight'] = reg_weight + return samples + + def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius): + h, w = 2 * h_radius + 1, 2 * w_radius + 1 + sigma_x = w / 6 + sigma_y = h / 6 + gaussian = gaussian2D((h, w), sigma_x, sigma_y) + + x, y = int(center[0]), int(center[1]) + + height, width = heatmap.shape[0:2] + + left, right = min(x, w_radius), min(width - x, w_radius + 1) + top, bottom = min(y, h_radius), min(height - y, h_radius + 1) + + masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right] + masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius - + left:w_radius + right] + if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: + heatmap[y - top:y + bottom, x - left:x + right] = np.maximum( + masked_heatmap, masked_gaussian) + return heatmap diff --git a/ppdet/data/transform/op_helper.py b/ppdet/data/transform/op_helper.py index a9d19b96d..02d219546 100644 --- a/ppdet/data/transform/op_helper.py +++ b/ppdet/data/transform/op_helper.py @@ -438,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap): def draw_gaussian(heatmap, center, radius, k=1, delte=6): diameter = 2 * radius + 1 - gaussian = gaussian2D((diameter, diameter), sigma=diameter / delte) + sigma = diameter / delte + gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma) x, y = center @@ -453,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6): np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap) -def gaussian2D(shape, sigma=1): +def gaussian2D(shape, sigma_x=1, sigma_y=1): m, n = [(ss - 1.) / 2. for ss in shape] y, x = np.ogrid[-m:m + 1, -n:n + 1] - h = np.exp(-(x * x + y * y) / (2 * sigma * sigma)) + h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y * + sigma_y))) h[h < np.finfo(h.dtype).eps * h.max()] = 0 return h diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 0875b2696..db73e4174 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -92,7 +92,6 @@ class BaseOperator(object): class DecodeImage(BaseOperator): def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False): """ Transform the image data to numpy format. - Args: to_rgb (bool): whether to convert BGR to RGB with_mixup (bool): whether or not to mixup image and gt_bbbox/gt_score @@ -165,7 +164,6 @@ class MultiscaleTestResize(BaseOperator): use_flip=True): """ Rescale image to the each size in target size, and capped at max_size. - Args: origin_target_size(int): original target size of image's short side. origin_max_size(int): original max size of image. @@ -274,7 +272,6 @@ class ResizeImage(BaseOperator): if max_size != 0. If target_size is list, selected a scale randomly as the specified target size. 
- Args: target_size (int|list): the target size of image's short side, multi-scale training is adopted when type is list. @@ -1177,7 +1174,6 @@ class Permute(BaseOperator): Args: to_bgr (bool): confirm whether to convert RGB to BGR channel_first (bool): confirm whether to change channel - """ super(Permute, self).__init__() self.to_bgr = to_bgr @@ -1386,7 +1382,6 @@ class RandomInterpImage(BaseOperator): @register_op class Resize(BaseOperator): """Resize image and bbox. - Args: target_dim (int or list): target size, can be a single number or a list (for random shape). @@ -1419,6 +1414,7 @@ class Resize(BaseOperator): scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32) sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0, dim - 1) + sample['scale_factor'] = [scale_x, scale_y] * 2 sample['h'] = resize_h sample['w'] = resize_w @@ -1430,7 +1426,6 @@ class Resize(BaseOperator): @register_op class ColorDistort(BaseOperator): """Random color distortion. - Args: hue (list): hue settings. in [lower, upper, probability] format. @@ -1442,6 +1437,8 @@ class ColorDistort(BaseOperator): in [lower, upper, probability] format. random_apply (bool): whether to apply in random (yolo) or fixed (SSD) order. + hsv_format (bool): whether to convert color from BGR to HSV + random_channel (bool): whether to swap channels randomly """ def __init__(self, @@ -1449,13 +1446,17 @@ class ColorDistort(BaseOperator): saturation=[0.5, 1.5, 0.5], contrast=[0.5, 1.5, 0.5], brightness=[0.5, 1.5, 0.5], - random_apply=True): + random_apply=True, + hsv_format=False, + random_channel=False): super(ColorDistort, self).__init__() self.hue = hue self.saturation = saturation self.contrast = contrast self.brightness = brightness self.random_apply = random_apply + self.hsv_format = hsv_format + self.random_channel = random_channel def apply_hue(self, img): low, high, prob = self.hue @@ -1463,6 +1464,11 @@ class ColorDistort(BaseOperator): return img img = img.astype(np.float32) + if self.hsv_format: + img[..., 0] += random.uniform(low, high) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + return img # XXX works, but result differ from HSV version delta = np.random.uniform(low, high) @@ -1482,8 +1488,10 @@ class ColorDistort(BaseOperator): if np.random.uniform(0., 1.) < prob: return img delta = np.random.uniform(low, high) - img = img.astype(np.float32) + if self.hsv_format: + img[..., 1] *= delta + return img gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32) gray = gray.sum(axis=2, keepdims=True) gray *= (1.0 - delta) @@ -1530,12 +1538,24 @@ class ColorDistort(BaseOperator): if np.random.randint(0, 2): img = self.apply_contrast(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) img = self.apply_saturation(img) img = self.apply_hue(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) else: + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) img = self.apply_saturation(img) img = self.apply_hue(img) + if self.hsv_format: + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) img = self.apply_contrast(img) + + if self.random_channel: + if np.random.randint(0, 2): + img = img[..., np.random.permutation(3)] sample['image'] = img return sample @@ -1603,7 +1623,6 @@ class CornerRandColor(ColorDistort): @register_op class NormalizePermute(BaseOperator): """Normalize and permute channel order. - Args: mean (list): mean values in RGB order. std (list): std values in RGB order. 
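As a side note on the reader pipeline: the `NormalizeImage` and `Permute` settings configured in `ttfnet_darknet.yml` earlier in this patch (`is_scale: false`, RGB order, `channel_first: true`) reduce to the following NumPy steps. This is an editor's sketch for illustration, not code from the patch:

```python
import numpy as np

mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)  # RGB means
std = np.array([58.395, 57.12, 57.375], dtype=np.float32)     # RGB stds

img = np.random.randint(0, 256, (512, 512, 3)).astype(np.float32)  # HWC, RGB
img = (img - mean) / std        # is_scale is false, so no prior 1/255 scaling
img = img.transpose((2, 0, 1))  # Permute: HWC -> CHW
```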
@@ -1633,7 +1652,6 @@ class NormalizePermute(BaseOperator): @register_op class RandomExpand(BaseOperator): """Random expand the canvas. - Args: ratio (float): maximum expansion ratio. prob (float): probability to expand. @@ -1725,7 +1743,6 @@ class RandomExpand(BaseOperator): @register_op class RandomCrop(BaseOperator): """Random crop image and bboxes. - Args: aspect_ratio (list): aspect ratio of cropped region. in [min, max] format. @@ -1852,11 +1869,23 @@ class RandomCrop(BaseOperator): found = False for i in range(self.num_attempts): scale = np.random.uniform(*self.scaling) - min_ar, max_ar = self.aspect_ratio - aspect_ratio = np.random.uniform( - max(min_ar, scale**2), min(max_ar, scale**-2)) - crop_h = int(h * scale / np.sqrt(aspect_ratio)) - crop_w = int(w * scale * np.sqrt(aspect_ratio)) + if self.aspect_ratio is not None: + min_ar, max_ar = self.aspect_ratio + aspect_ratio = np.random.uniform( + max(min_ar, scale**2), min(max_ar, scale**-2)) + h_scale = scale / np.sqrt(aspect_ratio) + w_scale = scale * np.sqrt(aspect_ratio) + else: + h_scale = np.random.uniform(*self.scaling) + w_scale = np.random.uniform(*self.scaling) + crop_h = h * h_scale + crop_w = w * w_scale + if self.aspect_ratio is None: + if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0: + continue + + crop_h = int(crop_h) + crop_w = int(crop_w) crop_y = np.random.randint(0, h - crop_h) crop_x = np.random.randint(0, w - crop_w) crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h] @@ -2008,7 +2037,6 @@ class BboxXYXY2XYWH(BaseOperator): return sample -@register_op class Lighting(BaseOperator): """ Lighting the imagen by eigenvalues and eigenvectors @@ -2248,7 +2276,6 @@ class CornerRatio(BaseOperator): class RandomScaledCrop(BaseOperator): """Resize image and bbox based on long side (with optional random scaling), then crop or pad image to target size. - Args: target_dim (int): target size. scale_range (list): random scale range. @@ -2303,7 +2330,6 @@ class RandomScaledCrop(BaseOperator): @register_op class ResizeAndPad(BaseOperator): """Resize image and bbox, then pad image to target size. - Args: target_dim (int): target size interp (int): interpolation method, default to `cv2.INTER_LINEAR`. @@ -2342,7 +2368,6 @@ class ResizeAndPad(BaseOperator): @register_op class TargetAssign(BaseOperator): """Assign regression target and labels. - Args: image_size (int or list): input image size, a single integer or list of [h, w]. Default: 512 diff --git a/ppdet/modeling/anchor_heads/__init__.py b/ppdet/modeling/anchor_heads/__init__.py index f38665c13..80324aa84 100644 --- a/ppdet/modeling/anchor_heads/__init__.py +++ b/ppdet/modeling/anchor_heads/__init__.py @@ -20,6 +20,7 @@ from . import retina_head from . import fcos_head from . import corner_head from . import efficient_head +from . import ttf_head from .rpn_head import * from .yolo_head import * @@ -27,3 +28,4 @@ from .retina_head import * from .fcos_head import * from .corner_head import * from .efficient_head import * +from .ttf_head import * diff --git a/ppdet/modeling/anchor_heads/ttf_head.py b/ppdet/modeling/anchor_heads/ttf_head.py new file mode 100644 index 000000000..df50ca2d1 --- /dev/null +++ b/ppdet/modeling/anchor_heads/ttf_head.py @@ -0,0 +1,383 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.initializer import Normal, Constant, Uniform, Xavier
+from paddle.fluid.regularizer import L2Decay
+from ppdet.core.workspace import register
+from ppdet.modeling.ops import DeformConv, DropBlock
+from ppdet.modeling.losses import GiouLoss
+
+__all__ = ['TTFHead']
+
+
+@register
+class TTFHead(object):
+    """
+    TTFHead
+    Args:
+        head_conv(int): the default channel number of convolution in head.
+            128 by default.
+        num_classes(int): the number of classes, 80 by default.
+        hm_weight(float): the weight of the heatmap branch. 1. by default.
+        wh_weight(float): the weight of the wh branch. 5. by default.
+        wh_offset_base(float): the base offset of width and height.
+            16. by default.
+        planes(tuple): the channel number of convolution in each upsample.
+            (256, 128, 64) by default.
+        shortcut_num(tuple): the number of convolution layers in each shortcut.
+            (1, 2, 3) by default.
+        wh_head_conv_num(int): the number of convolution layers in the wh head.
+            2 by default.
+        hm_head_conv_num(int): the number of convolution layers in the hm head.
+            2 by default.
+        wh_conv(int): the channel number of convolution in the wh head.
+            64 by default.
+        wh_planes(int): the output channel in the wh head. 4 by default.
+        score_thresh(float): the score threshold for predictions.
+            0.01 by default.
+        max_per_img(int): the maximum number of detections per image.
+            100 by default.
+        base_down_ratio(int): the base down_ratio; the actual down_ratio is
+            calculated from base_down_ratio and the number of upsample layers.
+            32 by default.
+        wh_loss(object): `GiouLoss` instance.
+        dcn_upsample(bool): whether to use dcn in upsample. True by default.
+        dcn_head(bool): whether to use dcn in the head. False by default.
+        drop_block(bool): whether to use dropblock. False by default.
+        block_size(int): block_size parameter for drop_block. 3 by default.
+        keep_prob(float): keep_prob parameter for drop_block. 0.9 by default.
+ """ + + __inject__ = ['wh_loss'] + __shared__ = ['num_classes'] + + def __init__(self, + head_conv=128, + num_classes=80, + hm_weight=1., + wh_weight=5., + wh_offset_base=16., + planes=(256, 128, 64), + shortcut_num=(1, 2, 3), + wh_head_conv_num=2, + hm_head_conv_num=2, + wh_conv=64, + wh_planes=4, + score_thresh=0.01, + max_per_img=100, + base_down_ratio=32, + wh_loss='GiouLoss', + dcn_upsample=True, + dcn_head=False, + drop_block=False, + block_size=3, + keep_prob=0.9): + super(TTFHead, self).__init__() + self.head_conv = head_conv + self.num_classes = num_classes + self.hm_weight = hm_weight + self.wh_weight = wh_weight + self.wh_offset_base = wh_offset_base + self.planes = planes + self.shortcut_num = shortcut_num + self.shortcut_len = len(shortcut_num) + self.wh_head_conv_num = wh_head_conv_num + self.hm_head_conv_num = hm_head_conv_num + self.wh_conv = wh_conv + self.wh_planes = wh_planes + self.score_thresh = score_thresh + self.max_per_img = max_per_img + self.down_ratio = base_down_ratio // 2**len(planes) + self.hm_weight = hm_weight + self.wh_weight = wh_weight + self.wh_loss = wh_loss + self.dcn_upsample = dcn_upsample + self.dcn_head = dcn_head + self.drop_block = drop_block + self.block_size = block_size + self.keep_prob = keep_prob + + def shortcut(self, x, out_c, layer_num, kernel_size=3, padding=1, + name=None): + assert layer_num > 0 + for i in range(layer_num): + act = 'relu' if i < layer_num - 1 else None + fan_out = kernel_size * kernel_size * out_c + std = math.sqrt(2. / fan_out) + param_name = name + '.layers.' + str(i * 2) + x = fluid.layers.conv2d( + x, + out_c, + kernel_size, + padding=padding, + act=act, + param_attr=ParamAttr( + initializer=Normal(0, std), name=param_name + '.weight'), + bias_attr=ParamAttr( + learning_rate=2., + regularizer=L2Decay(0.), + name=param_name + '.bias')) + return x + + def upsample(self, x, out_c, name=None): + fan_in = x.shape[1] * 3 * 3 + stdv = 1. 
/ math.sqrt(fan_in) + if self.dcn_upsample: + conv = DeformConv( + x, + out_c, + 3, + initializer=Uniform(-stdv, stdv), + bias_attr=True, + name=name + '.0') + else: + conv = fluid.layers.conv2d( + x, + out_c, + 3, + padding=1, + param_attr=ParamAttr(initializer=Uniform(-stdv, stdv)), + bias_attr=ParamAttr( + learning_rate=2., regularizer=L2Decay(0.))) + + norm_name = name + '.1' + pattr = ParamAttr(name=norm_name + '.weight', initializer=Constant(1.)) + battr = ParamAttr(name=norm_name + '.bias', initializer=Constant(0.)) + bn = fluid.layers.batch_norm( + input=conv, + act='relu', + param_attr=pattr, + bias_attr=battr, + name=norm_name + '.output.1', + moving_mean_name=norm_name + '.running_mean', + moving_variance_name=norm_name + '.running_var') + up = fluid.layers.resize_bilinear( + bn, scale=2, name=name + '.2.upsample') + return up + + def _head(self, + x, + out_c, + conv_num=1, + head_out_c=None, + name=None, + is_test=False): + head_out_c = self.head_conv if not head_out_c else head_out_c + conv_w_std = 0.01 if '.hm' in name else 0.001 + conv_w_init = Normal(0, conv_w_std) + for i in range(conv_num): + conv_name = '{}.{}.conv'.format(name, i) + if self.dcn_head: + x = DeformConv( + x, + head_out_c, + 3, + initializer=conv_w_init, + name=conv_name + '.dcn') + x = fluid.layers.relu(x) + else: + x = fluid.layers.conv2d( + x, + head_out_c, + 3, + padding=1, + param_attr=ParamAttr( + initializer=conv_w_init, name=conv_name + '.weight'), + bias_attr=ParamAttr( + learning_rate=2., + regularizer=L2Decay(0.), + name=conv_name + '.bias'), + act='relu') + if self.drop_block and '.hm' in name: + x = DropBlock( + x, + block_size=self.block_size, + keep_prob=self.keep_prob, + is_test=is_test) + bias_init = float(-np.log((1 - 0.01) / 0.01)) if '.hm' in name else 0. + conv_b_init = Constant(bias_init) + x = fluid.layers.conv2d( + x, + out_c, + 1, + param_attr=ParamAttr( + initializer=conv_w_init, + name='{}.{}.weight'.format(name, conv_num)), + bias_attr=ParamAttr( + learning_rate=2., + regularizer=L2Decay(0.), + name='{}.{}.bias'.format(name, conv_num), + initializer=conv_b_init)) + return x + + def hm_head(self, x, name=None, is_test=False): + hm = self._head( + x, + self.num_classes, + self.hm_head_conv_num, + name=name, + is_test=is_test) + return hm + + def wh_head(self, x, name=None): + planes = self.wh_planes + wh = self._head( + x, planes, self.wh_head_conv_num, self.wh_conv, name=name) + return fluid.layers.relu(wh) + + def get_output(self, input, name=None, is_test=False): + feat = input[-1] + for i, out_c in enumerate(self.planes): + feat = self.upsample( + feat, out_c, name=name + '.deconv_layers.' + str(i)) + if i < self.shortcut_len: + shortcut = self.shortcut( + input[-i - 2], + out_c, + self.shortcut_num[i], + name=name + '.shortcut_layers.' 
+ str(i)) + feat = fluid.layers.elementwise_add(feat, shortcut) + + hm = self.hm_head(feat, name=name + '.hm', is_test=is_test) + wh = self.wh_head(feat, name=name + '.wh') * self.wh_offset_base + + return hm, wh + + def _simple_nms(self, heat, kernel=3): + pad = (kernel - 1) // 2 + hmax = fluid.layers.pool2d(heat, kernel, 'max', pool_padding=pad) + keep = fluid.layers.cast(hmax == heat, 'float32') + return heat * keep + + def _topk(self, scores, k): + cat, height, width = scores.shape[1:] + # batch size is 1 + scores_r = fluid.layers.reshape(scores, [cat, -1]) + topk_scores, topk_inds = fluid.layers.topk(scores_r, k) + topk_ys = topk_inds / width + topk_xs = topk_inds % width + + topk_score_r = fluid.layers.reshape(topk_scores, [-1]) + topk_score, topk_ind = fluid.layers.topk(topk_score_r, k) + topk_clses = fluid.layers.cast(topk_ind / k, 'float32') + + topk_inds = fluid.layers.reshape(topk_inds, [-1]) + topk_ys = fluid.layers.reshape(topk_ys, [-1, 1]) + topk_xs = fluid.layers.reshape(topk_xs, [-1, 1]) + topk_inds = fluid.layers.gather(topk_inds, topk_ind) + topk_ys = fluid.layers.gather(topk_ys, topk_ind) + topk_xs = fluid.layers.gather(topk_xs, topk_ind) + + return topk_score, topk_inds, topk_clses, topk_ys, topk_xs + + def get_bboxes(self, heatmap, wh, scale_factor): + heatmap = fluid.layers.sigmoid(heatmap) + heat = self._simple_nms(heatmap) + scores, inds, clses, ys, xs = self._topk(heat, self.max_per_img) + ys = fluid.layers.cast(ys, 'float32') * self.down_ratio + xs = fluid.layers.cast(xs, 'float32') * self.down_ratio + scores = fluid.layers.unsqueeze(scores, [1]) + clses = fluid.layers.unsqueeze(clses, [1]) + + wh_t = fluid.layers.transpose(wh, [0, 2, 3, 1]) + wh = fluid.layers.reshape(wh_t, [-1, wh_t.shape[-1]]) + wh = fluid.layers.gather(wh, inds) + + x1 = xs - wh[:, 0:1] + y1 = ys - wh[:, 1:2] + x2 = xs + wh[:, 2:3] + y2 = ys + wh[:, 3:4] + bboxes = fluid.layers.concat([x1, y1, x2, y2], axis=1) + bboxes = fluid.layers.elementwise_div(bboxes, scale_factor, axis=-1) + results = fluid.layers.concat([clses, scores, bboxes], axis=1) + # hack: append result with cls=-1 and score=1. to avoid all scores + # are less than score_thresh which may cause error in gather. 
+ fill_r = fluid.layers.assign( + np.array( + [[-1, 1., 0, 0, 0, 0]], dtype='float32')) + results = fluid.layers.concat([results, fill_r]) + scores = results[:, 1] + valid_ind = fluid.layers.where(scores > self.score_thresh) + results = fluid.layers.gather(results, valid_ind) + return {'bbox': results} + + def ct_focal_loss(self, pred_hm, target_hm, gamma=2.0): + fg_map = fluid.layers.cast(target_hm == 1, 'float32') + fg_map.stop_gradient = True + bg_map = fluid.layers.cast(target_hm < 1, 'float32') + bg_map.stop_gradient = True + + neg_weights = fluid.layers.pow(1 - target_hm, 4) * bg_map + pos_loss = 0 - fluid.layers.log(pred_hm) * fluid.layers.pow( + 1 - pred_hm, gamma) * fg_map + neg_loss = 0 - fluid.layers.log(1 - pred_hm) * fluid.layers.pow( + pred_hm, gamma) * neg_weights + pos_loss = fluid.layers.reduce_sum(pos_loss) + neg_loss = fluid.layers.reduce_sum(neg_loss) + + fg_num = fluid.layers.reduce_sum(fg_map) + focal_loss = (pos_loss + neg_loss) / ( + fg_num + fluid.layers.cast(fg_num == 0, 'float32')) + return focal_loss + + def filter_box_by_weight(self, pred, target, weight): + index = fluid.layers.where(weight > 0) + index.stop_gradient = True + weight = fluid.layers.gather_nd(weight, index) + pred = fluid.layers.gather_nd(pred, index) + target = fluid.layers.gather_nd(target, index) + return pred, target, weight + + def get_loss(self, pred_hm, pred_wh, target_hm, box_target, target_weight): + pred_hm = paddle.tensor.clamp( + fluid.layers.sigmoid(pred_hm), 1e-4, 1 - 1e-4) + hm_loss = self.ct_focal_loss(pred_hm, target_hm) * self.hm_weight + shape = fluid.layers.shape(target_hm) + shape.stop_gradient = True + H, W = shape[2], shape[3] + + mask = fluid.layers.reshape(target_weight, [-1, H, W]) + avg_factor = fluid.layers.reduce_sum(mask) + 1e-4 + base_step = self.down_ratio + zero = fluid.layers.fill_constant(shape=[1], value=0, dtype='int32') + shifts_x = paddle.arange(zero, W * base_step, base_step, dtype='int32') + shifts_y = paddle.arange(zero, H * base_step, base_step, dtype='int32') + shift_y, shift_x = paddle.tensor.meshgrid([shifts_y, shifts_x]) + base_loc = fluid.layers.stack([shift_x, shift_y], axis=0) + base_loc.stop_gradient = True + + pred_boxes = fluid.layers.concat( + [0 - pred_wh[:, 0:2, :, :] + base_loc, pred_wh[:, 2:4] + base_loc], + axis=1) + pred_boxes = fluid.layers.transpose(pred_boxes, [0, 2, 3, 1]) + boxes = fluid.layers.transpose(box_target, [0, 2, 3, 1]) + boxes.stop_gradient = True + + pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes, + mask) + mask.stop_gradient = True + wh_loss = self.wh_loss( + pred_boxes, boxes, outside_weight=mask, use_transform=False) + wh_loss = wh_loss / avg_factor + + ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss} + return ttf_loss diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index d558e06db..f29be6cbd 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -27,6 +27,7 @@ from . import blazeface from . import faceboxes from . import fcos from . import cornernet_squeeze +from . 
import ttfnet from .faster_rcnn import * from .mask_rcnn import * @@ -41,3 +42,4 @@ from .blazeface import * from .faceboxes import * from .fcos import * from .cornernet_squeeze import * +from .ttfnet import * diff --git a/ppdet/modeling/architectures/ttfnet.py b/ppdet/modeling/architectures/ttfnet.py new file mode 100644 index 000000000..75ea43ddc --- /dev/null +++ b/ppdet/modeling/architectures/ttfnet.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import OrderedDict + +from paddle import fluid + +from ppdet.experimental import mixed_precision_global_state +from ppdet.core.workspace import register + +__all__ = ['TTFNet'] + + +@register +class TTFNet(object): + """ + TTFNet network, see https://arxiv.org/abs/1909.00700 + + Args: + backbone (object): backbone instance + ttf_head (object): `TTFHead` instance + num_classes (int): the number of classes, 80 by default. + """ + + __category__ = 'architecture' + __inject__ = ['backbone', 'ttf_head'] + __shared__ = ['num_classes'] + + def __init__(self, backbone, ttf_head='TTFHead', num_classes=80): + super(TTFNet, self).__init__() + self.backbone = backbone + self.ttf_head = ttf_head + self.num_classes = num_classes + + def build(self, feed_vars, mode='train'): + im = feed_vars['image'] + + mixed_precision_enabled = mixed_precision_global_state() is not None + + # cast inputs to FP16 + if mixed_precision_enabled: + im = fluid.layers.cast(im, 'float16') + + body_feats = self.backbone(im) + + if isinstance(body_feats, OrderedDict): + body_feat_names = list(body_feats.keys()) + body_feats = [body_feats[name] for name in body_feat_names] + + # cast features back to FP32 + if mixed_precision_enabled: + body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] + + predict_hm, predict_wh = self.ttf_head.get_output( + body_feats, 'ttf_head', is_test=mode == 'test') + if mode == 'train': + heatmap = feed_vars['ttf_heatmap'] + box_target = feed_vars['ttf_box_target'] + reg_weight = feed_vars['ttf_reg_weight'] + loss = self.ttf_head.get_loss(predict_hm, predict_wh, heatmap, + box_target, reg_weight) + total_loss = fluid.layers.sum(list(loss.values())) + loss.update({'loss': total_loss}) + return loss + else: + results = self.ttf_head.get_bboxes(predict_hm, predict_wh, + feed_vars['scale_factor']) + return results + + def _inputs_def(self, image_shape, downsample): + im_shape = [None] + image_shape + H, W = im_shape[2:] + target_h = None if H is None else H // downsample + target_w = None if W is None else W // downsample + # yapf: disable + inputs_def = { + 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, + 'scale_factor': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 0}, + 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, + 'ttf_heatmap': {'shape': [None, self.num_classes, target_h, target_w], 
'dtype': 'float32', 'lod_level': 0}, + 'ttf_box_target': {'shape': [None, 4, target_h, target_w], 'dtype': 'float32', 'lod_level': 0}, + 'ttf_reg_weight': {'shape': [None, 1, target_h, target_w], 'dtype': 'float32', 'lod_level': 0}, + } + # yapf: enable + + return inputs_def + + def build_inputs( + self, + image_shape=[3, None, None], + fields=[ + 'image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight' + ], # for train + use_dataloader=True, + iterable=False, + downsample=4): + inputs_def = self._inputs_def(image_shape, downsample) + feed_vars = OrderedDict([(key, fluid.data( + name=key, + shape=inputs_def[key]['shape'], + dtype=inputs_def[key]['dtype'], + lod_level=inputs_def[key]['lod_level'])) for key in fields]) + loader = fluid.io.DataLoader.from_generator( + feed_list=list(feed_vars.values()), + capacity=16, + use_double_buffer=True, + iterable=iterable) if use_dataloader else None + return feed_vars, loader + + def train(self, feed_vars): + return self.build(feed_vars, mode='train') + + def eval(self, feed_vars): + return self.build(feed_vars, mode='test') + + def test(self, feed_vars): + return self.build(feed_vars, mode='test') diff --git a/ppdet/modeling/backbones/darknet.py b/ppdet/modeling/backbones/darknet.py index 37583ab29..0f5a7a542 100644 --- a/ppdet/modeling/backbones/darknet.py +++ b/ppdet/modeling/backbones/darknet.py @@ -42,13 +42,15 @@ class DarkNet(object): depth=53, norm_type='bn', norm_decay=0., - weight_prefix_name=''): + weight_prefix_name='', + freeze_at=-1): assert depth in [53], "unsupported depth value" self.depth = depth self.norm_type = norm_type self.norm_decay = norm_decay self.depth_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)} self.prefix_name = weight_prefix_name + self.freeze_at = freeze_at def _conv_norm(self, input, @@ -161,6 +163,8 @@ class DarkNet(object): ch_out=32 * 2**i, count=stage, name=self.prefix_name + "stage.{}".format(i)) + if i < self.freeze_at: + block.stop_gradient = True blocks.append(block) if i < len(stages) - 1: # do not downsaple in the last stage downsample_ = self._downsample( diff --git a/ppdet/modeling/losses/giou_loss.py b/ppdet/modeling/losses/giou_loss.py index 5362e37ae..82c07c828 100644 --- a/ppdet/modeling/losses/giou_loss.py +++ b/ppdet/modeling/losses/giou_loss.py @@ -33,14 +33,24 @@ class GiouLoss(object): loss_weight (float): diou loss weight, default as 10 in faster-rcnn is_cls_agnostic (bool): flag of class-agnostic num_classes (int): class num + do_average (bool): whether to average the loss + use_class_weight(bool): whether to use class weight ''' __shared__ = ['num_classes'] - def __init__(self, loss_weight=10., is_cls_agnostic=False, num_classes=81): + def __init__(self, + loss_weight=10., + is_cls_agnostic=False, + num_classes=81, + do_average=True, + use_class_weight=True): super(GiouLoss, self).__init__() self.loss_weight = loss_weight self.is_cls_agnostic = is_cls_agnostic self.num_classes = num_classes + self.do_average = do_average + self.class_weight = 2 if is_cls_agnostic else num_classes + self.use_class_weight = use_class_weight # deltas: NxMx4 def bbox_transform(self, deltas, weights): @@ -78,10 +88,15 @@ class GiouLoss(object): y, inside_weight=None, outside_weight=None, - bbox_reg_weight=[0.1, 0.1, 0.2, 0.2]): + bbox_reg_weight=[0.1, 0.1, 0.2, 0.2], + use_transform=True): eps = 1.e-10 - x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight) - x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight) + if use_transform: + x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight) + x1g, y1g, 
x2g, y2g = self.bbox_transform(y, bbox_reg_weight) + else: + x1, y1, x2, y2 = fluid.layers.split(x, num_or_sections=4, dim=1) + x1g, y1g, x2g, y2g = fluid.layers.split(y, num_or_sections=4, dim=1) x2 = fluid.layers.elementwise_max(x1, x2) y2 = fluid.layers.elementwise_max(y1, y2) @@ -99,9 +114,9 @@ class GiouLoss(object): intsctk = (xkis2 - xkis1) * (ykis2 - ykis1) intsctk = intsctk * fluid.layers.greater_than( xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1) - unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g ) - intsctk + eps + iouk = intsctk / unionk area_c = (xc2 - xc1) * (yc2 - yc1) + eps @@ -116,10 +131,17 @@ class GiouLoss(object): outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1) iou_weights = inside_weight * outside_weight - - class_weight = 2 if self.is_cls_agnostic else self.num_classes - iouk = fluid.layers.reduce_mean((1 - iouk) * iou_weights) * class_weight - miouk = fluid.layers.reduce_mean( - (1 - miouk) * iou_weights) * class_weight + elif outside_weight is not None: + iou_weights = outside_weight + + if self.do_average: + miouk = fluid.layers.reduce_mean((1 - miouk) * iou_weights) + else: + iou_distance = fluid.layers.elementwise_mul( + 1 - miouk, iou_weights, axis=0) + miouk = fluid.layers.reduce_sum(iou_distance) + + if self.use_class_weight: + miouk = miouk * self.class_weight return miouk * self.loss_weight diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py index d456b4c09..30459176f 100644 --- a/ppdet/modeling/ops.py +++ b/ppdet/modeling/ops.py @@ -30,7 +30,8 @@ __all__ = [ 'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner', 'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead', 'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm', - 'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner' + 'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner', + 'DeformConv' ] @@ -43,36 +44,32 @@ def _conv_offset(input, filter_size, stride, padding, act=None, name=None): stride=stride, padding=padding, param_attr=ParamAttr( - initializer=fluid.initializer.Constant(value=0), - name=name + ".w_0"), + initializer=fluid.initializer.Constant(0), name=name + ".w_0"), bias_attr=ParamAttr( - initializer=fluid.initializer.Constant(value=0), + initializer=fluid.initializer.Constant(0), + learning_rate=2., + regularizer=L2Decay(0.), name=name + ".b_0"), act=act, name=name) return out -def DeformConvNorm(input, - num_filters, - filter_size, - stride=1, - groups=1, - norm_decay=0., - norm_type='affine_channel', - norm_groups=32, - dilation=1, - lr_scale=1, - freeze_norm=False, - act=None, - norm_name=None, - initializer=None, - bias_attr=False, - name=None): +def DeformConv(input, + num_filters, + filter_size, + stride=1, + groups=1, + dilation=1, + lr_scale=1, + initializer=None, + bias_attr=False, + name=None): if bias_attr: bias_para = ParamAttr( name=name + "_bias", - initializer=fluid.initializer.Constant(value=0), + initializer=fluid.initializer.Constant(0), + regularizer=L2Decay(0.), learning_rate=lr_scale * 2) else: bias_para = False @@ -109,6 +106,29 @@ def DeformConvNorm(input, bias_attr=bias_para, name=name + ".conv2d.output.1") + return conv + + +def DeformConvNorm(input, + num_filters, + filter_size, + stride=1, + groups=1, + norm_decay=0., + norm_type='affine_channel', + norm_groups=32, + dilation=1, + lr_scale=1, + freeze_norm=False, + act=None, + norm_name=None, + initializer=None, + bias_attr=False, + name=None): + assert norm_type in ['bn', 'sync_bn', 
'affine_channel'] + conv = DeformConv(input, num_filters, filter_size, stride, groups, dilation, + lr_scale, initializer, bias_attr, name) + norm_lr = 0. if freeze_norm else 1. pattr = ParamAttr( name=norm_name + '_scale', @@ -330,7 +350,6 @@ class AnchorGenerator(object): @serializable class AnchorGrid(object): """Generate anchor grid - Args: image_size (int or list): input image size, may be a single integer or list of [h, w]. Default: 512 diff --git a/ppdet/utils/coco_eval.py b/ppdet/utils/coco_eval.py index b065a276a..b54be135c 100644 --- a/ppdet/utils/coco_eval.py +++ b/ppdet/utils/coco_eval.py @@ -261,6 +261,7 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False): for j in range(num): dt = bboxes[k] clsid, score, xmin, ymin, xmax, ymax = dt.tolist() + if clsid < 0: continue catid = (clsid2catid[int(clsid)]) if is_bbox_normalized: diff --git a/ppdet/utils/eval_utils.py b/ppdet/utils/eval_utils.py index 8ba53838d..ef3d11da7 100644 --- a/ppdet/utils/eval_utils.py +++ b/ppdet/utils/eval_utils.py @@ -161,6 +161,8 @@ def eval_run(exe, if 'Corner' in cfg.architecture and post_config is not None: from ppdet.utils.post_process import corner_post_process corner_post_process(res, post_config, cfg.num_classes) + if 'TTFNet' in cfg.architecture: + res['bbox'][1].append([len(res['bbox'][0])]) results.append(res) if iter_id % 100 == 0: logger.info('Test iter {}'.format(iter_id)) diff --git a/tools/export_model.py b/tools/export_model.py index 786fee49b..5024167ca 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -76,6 +76,8 @@ def parse_reader(reader_cfg, metric, arch): params['max_size'] = max(image_shape[ 1:]) if arch in scale_set else 0 params['image_shape'] = image_shape[1:] + if 'target_dim' in params: + params.pop('target_dim') p.update(params) preprocess_list.append(p) batch_transforms = reader_cfg.get('batch_transforms', None) @@ -109,6 +111,7 @@ def dump_infer_config(FLAGS, config): 'RCNN': 40, 'RetinaNet': 40, 'Face': 3, + 'TTFNet': 3, } infer_arch = config['architecture'] -- GitLab
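Editor's note on the output contract: `TTFHead.get_bboxes` appends a dummy row with `clsid = -1` so that `gather` never receives an empty tensor, and both the `clsid < 0` guard added to `bbox2out` and the widened `expect_boxes` filter in `deploy/python/infer.py` exist to drop that row again. A small runnable illustration of this convention, with sample values invented for the example:

```python
import numpy as np

# each prediction row is [clsid, score, x1, y1, x2, y2]
np_boxes = np.array([
    [ 0., 0.92, 10., 20., 110., 220.],  # a real, confident detection
    [17., 0.48, 50., 60.,  90., 160.],  # a low-score detection
    [-1., 1.00,  0.,  0.,   0.,   0.],  # the padding row from get_bboxes
], dtype=np.float32)

threshold = 0.5
keep = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
print(np_boxes[keep])  # only the first row survives
```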