Unverified · Commit 996a4fa5 · Authored by W wangguanzhong · Committed by GitHub

Add CenterNet (#4098)

* add centernet

* add centernet

* add r50 for centernet

* support deploy for centernet

* add doc

* update pretrain

* clean code

* update r50 doc

* unify for_mot

* update code according to review

* update doc

* refine preprocess

* update doc
Parent d896574f
English | [简体中文](README_cn.md)
# CenterNet (CenterNet: Objects as Points)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Citations](#Citations)
## Introduction
[CenterNet](http://arxiv.org/abs/1904.07850) is an anchor-free detector that models an object as a single point -- the center point of its bounding box. The detector uses keypoint estimation to find center points and regresses to all other object properties. This center-point based approach, CenterNet, is end-to-end differentiable, simpler, faster, and more accurate than corresponding bounding-box based detectors.
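At inference time this reduces to picking peaks in the class heatmap and reading the per-pixel size and offset predictions at those peaks (the paper replaces NMS with a 3x3 max-pool peak extraction). Below is a minimal NumPy sketch of the decoding idea, for illustration only and not the actual `CenterNetPostProcess` implementation:
```
import numpy as np

def decode_centers(heatmap, wh, offset, k=100):
    # heatmap: (num_classes, H, W); wh, offset: (2, H, W), all on the output grid
    c, h, w = heatmap.shape
    flat = heatmap.reshape(c, -1)
    topk = np.argsort(flat.ravel())[::-1][:k]      # top-k peaks over classes and positions
    cls, pos = np.unravel_index(topk, flat.shape)
    ys, xs = pos // w, pos % w
    boxes = []
    for ci, x, y in zip(cls, xs, ys):
        ox, oy = offset[:, y, x]                   # sub-pixel center offset
        bw, bh = wh[:, y, x]                       # predicted box width / height
        cx, cy = x + ox, y + oy
        boxes.append([cx - bw / 2, cy - bh / 2, cx + bw / 2, cy + bh / 2,
                      flat[ci, y * w + x], ci])    # [x1, y1, x2, y2, score, class]
    return np.array(boxes)
```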
## Model Zoo
### CenterNet Results on COCO-val 2017
| backbone | input shape | mAP | FPS | download | config |
| :--------------| :------- | :----: | :------: | :----: |:-----: |
| DLA-34(paper) | 512x512 | 37.4 | - | - | - |
| DLA-34 | 512x512 | 37.6 | - | [model](https://bj.bcebos.com/v1/paddledet/models/centernet_dla34_140e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/centernet/centernet_dla34_140e_coco.yml) |
| ResNet50 + DLAUp | 512x512 | 38.9 | - | [model](https://bj.bcebos.com/v1/paddledet/models/centernet_r50_140e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/centernet/centernet_r50_140e_coco.yml) |
## Citations
```
@article{zhou2019objects,
title={Objects as points},
author={Zhou, Xingyi and Wang, Dequan and Kr{\"a}henb{\"u}hl, Philipp},
journal={arXiv preprint arXiv:1904.07850},
year={2019}
}
```
Simplified Chinese | [English](README.md)
# CenterNet (CenterNet: Objects as Points)
## Table of Contents
- [Introduction](#Introduction)
- [Model Zoo](#Model_Zoo)
- [Citations](#Citations)
## Introduction
[CenterNet](http://arxiv.org/abs/1904.07850) is an anchor-free detector that represents an object as a single point -- the center of its bounding box. CenterNet locates the center point via keypoint estimation and regresses the object's other properties from it. This center-point based approach is end-to-end trainable and more efficient than anchor-based detectors.
## Model Zoo
### CenterNet Results on COCO-val 2017
| Backbone | Input Shape | mAP | FPS | Download | Config |
| :--------------| :------- | :----: | :------: | :----: |:-----: |
| DLA-34(paper) | 512x512 | 37.4 | - | - | - |
| DLA-34 | 512x512 | 37.6 | - | [model](https://bj.bcebos.com/v1/paddledet/models/centernet_dla34_140e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/centernet/centernet_dla34_140e_coco.yml) |
| ResNet50 + DLAUp | 512x512 | 38.9 | - | [model](https://bj.bcebos.com/v1/paddledet/models/centernet_r50_140e_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/centernet/centernet_r50_140e_coco.yml) |
## Citations
```
@article{zhou2019objects,
title={Objects as points},
author={Zhou, Xingyi and Wang, Dequan and Kr{\"a}henb{\"u}hl, Philipp},
journal={arXiv preprint arXiv:1904.07850},
year={2019}
}
```
architecture: CenterNet
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/DLA34_pretrain.pdparams
CenterNet:
backbone: DLA
neck: CenterNetDLAFPN
head: CenterNetHead
post_process: CenterNetPostProcess
DLA:
depth: 34
CenterNetDLAFPN:
down_ratio: 4
CenterNetHead:
head_planes: 256
regress_ltrb: False
CenterNetPostProcess:
max_per_img: 100
regress_ltrb: False
architecture: CenterNet
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_pretrained.pdparams
norm_type: sync_bn
use_ema: true
ema_decay: 0.9998
CenterNet:
backbone: ResNet
neck: CenterNetDLAFPN
head: CenterNetHead
post_process: CenterNetPostProcess
ResNet:
depth: 50
variant: d
return_idx: [0, 1, 2, 3]
freeze_at: -1
norm_decay: 0.
dcn_v2_stages: [3]
CenterNetDLAFPN:
first_level: 0
last_level: 4
down_ratio: 4
dcn_v2: False
CenterNetHead:
head_planes: 256
regress_ltrb: False
CenterNetPostProcess:
max_per_img: 100
regress_ltrb: False
worker_num: 4
TrainReader:
inputs_def:
image_shape: [3, 512, 512]
sample_transforms:
- Decode: {}
- FlipWarpAffine: {keep_res: False, input_h: 512, input_w: 512, use_random: True}
- CenterRandColor: {}
- Lighting: {eigval: [0.2141788, 0.01817699, 0.00341571], eigvec: [[-0.58752847, -0.69563484, 0.41340352], [-0.5832747, 0.00994535, -0.81221408], [-0.56089297, 0.71832671, 0.41158938]]}
- NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834], is_scale: False}
- Permute: {}
- Gt2CenterNetTarget: {down_ratio: 4, max_objs: 128}
batch_size: 16
shuffle: True
drop_last: True
use_shared_memory: True
EvalReader:
sample_transforms:
- Decode: {}
- WarpAffine: {keep_res: True, input_h: 512, input_w: 512}
- NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834]}
- Permute: {}
batch_size: 1
TestReader:
inputs_def:
image_shape: [3, 512, 512]
sample_transforms:
- Decode: {}
- WarpAffine: {keep_res: True, input_h: 512, input_w: 512}
- NormalizeImage: {mean: [0.40789655, 0.44719303, 0.47026116], std: [0.2886383 , 0.27408165, 0.27809834]}
- Permute: {}
batch_size: 1
epoch: 140
LearningRate:
base_lr: 0.0005
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [90, 120]
use_warmup: False
OptimizerBuilder:
optimizer:
type: Adam
regularizer: NULL
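As a quick sanity check on the schedule above (a hand computation assuming the usual PiecewiseDecay semantics of multiplying the base LR by `gamma` at each milestone epoch):
```
import math

def lr_at(epoch, base_lr=5e-4, milestones=(90, 120), gamma=0.1):
    # piecewise-constant LR: multiply by gamma once each milestone is reached
    lr = base_lr
    for m in milestones:
        if epoch >= m:
            lr *= gamma
    return lr

assert math.isclose(lr_at(0), 5e-4)    # epochs [0, 90)
assert math.isclose(lr_at(100), 5e-5)  # epochs [90, 120)
assert math.isclose(lr_at(130), 5e-6)  # epochs [120, 140)
```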
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_140e.yml',
'_base_/centernet_dla34.yml',
'_base_/centernet_reader.yml',
]
weights: output/centernet_dla34_140e_coco/model_final
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_140e.yml',
'_base_/centernet_r50.yml',
'_base_/centernet_reader.yml',
]
weights: output/centernet_r50_140e_coco/model_final
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
for_mot: True
FairMOT:
detector: CenterNet
......@@ -12,10 +13,9 @@ CenterNet:
neck: CenterNetDLAFPN
head: CenterNetHead
post_process: CenterNetPostProcess
for_mot: True
CenterNetPostProcess:
for_mot: True
max_per_img: 500
JDETracker:
conf_thres: 0.4
......
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/centernet_hardnet85_coco.pdparams
for_mot: True
FairMOT:
detector: CenterNet
......@@ -12,7 +13,6 @@ CenterNet:
neck: CenterNetHarDNetFPN
head: CenterNetHead
post_process: CenterNetPostProcess
for_mot: True
HarDNet:
depth_wise: False
......@@ -32,7 +32,8 @@ FairMOTEmbeddingHead:
ch_head: 512
CenterNetPostProcess:
for_mot: True
max_per_img: 500
regress_ltrb: True
JDETracker:
conf_thres: 0.4
......
architecture: FairMOT
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/HRNet_W18_C_pretrained.pdparams
for_mot: True
FairMOT:
detector: CenterNet
......@@ -12,7 +13,6 @@ CenterNet:
head: CenterNetHead
post_process: CenterNetPostProcess
neck: CenterNetDLAFPN
for_mot: True
HRNet:
width: 18
......@@ -28,7 +28,7 @@ CenterNetDLAFPN:
dcn_v2: False
CenterNetPostProcess:
for_mot: True
max_per_img: 500
JDETracker:
conf_thres: 0.4
......
......@@ -26,7 +26,7 @@ from paddle.inference import create_predictor
from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine
from visualize import visualize_box_mask
from utils import argsparser, Timer, get_current_memory_mb
......@@ -45,6 +45,7 @@ SUPPORT_MODELS = {
'DeepSORT',
'GFL',
'PicoDet',
'CenterNet',
}
......
......@@ -245,6 +245,139 @@ class LetterBoxResize(object):
return im, im_info
class WarpAffine(object):
"""Warp affine the image
"""
def __init__(self,
keep_res=False,
pad=31,
input_h=512,
input_w=512,
scale=0.4,
shift=0.1):
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
def _get_3rd_point(self, a, b):
assert len(
a) == 2, 'input of _get_3rd_point should be point with length of 2'
assert len(
b) == 2, 'input of _get_3rd_point should be point with length of 2'
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
return third_pt
def rotate_point(self, pt, angle_rad):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle in radians
Returns:
list[float]: Rotated point.
"""
assert len(pt) == 2
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
new_x = pt[0] * cs - pt[1] * sn
new_y = pt[0] * sn + pt[1] * cs
rotated_pt = [new_x, new_y]
return rotated_pt
def get_affine_transform(self,
center,
input_size,
rot,
output_size,
shift=(0., 0.),
inv=False):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
input_size (np.ndarray[2, ]): Size of input feature (width, height).
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]
rot_rad = np.pi * rot / 180
src_dir = self.rotate_point([0., src_w * -0.5], rot_rad)
dst_dir = np.array([0., dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = self._get_3rd_point(src[0, :], src[1, :])
dst = np.zeros((3, 2), dtype=np.float32)
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
dst[2, :] = self._get_3rd_point(dst[0, :], dst[1, :])
if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
return trans
def __call__(self, im, im_info):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
h, w = img.shape[:2]
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
trans_input = self.get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
return inp, im_info
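# Note on the keep_res branch above: with pad=31, (x | 31) + 1 is the smallest
# multiple of 32 strictly greater than x, so the warped input is always
# 32-aligned, e.g.:
#   (480 | 31) + 1 == 512
#   (511 | 31) + 1 == 512
#   (512 | 31) + 1 == 544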
def preprocess(im, preprocess_ops):
# process image by preprocess_ops
im_info = {
......
......@@ -25,18 +25,26 @@ import cv2
import math
import numpy as np
from .operators import register_op, BaseOperator, Resize
from .op_helper import jaccard_overlap, gaussian2D
from .op_helper import jaccard_overlap, gaussian2D, gaussian_radius, draw_umich_gaussian
from .atss_assigner import ATSSAssigner
from scipy import ndimage
from ppdet.modeling import bbox_utils
from ppdet.utils.logger import setup_logger
from ppdet.modeling.keypoint_utils import get_affine_transform, affine_transform
logger = setup_logger(__name__)
__all__ = [
'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget',
'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseRCNNTarget', 'PadMaskBatch',
'Gt2GFLTarget'
'PadBatch',
'BatchRandomResize',
'Gt2YoloTarget',
'Gt2FCOSTarget',
'Gt2TTFTarget',
'Gt2Solov2Target',
'Gt2SparseRCNNTarget',
'PadMaskBatch',
'Gt2GFLTarget',
'Gt2CenterNetTarget',
]
......@@ -967,3 +975,84 @@ class PadMaskBatch(BaseOperator):
data['gt_rbox'] = rbox
return samples
@register_op
class Gt2CenterNetTarget(BaseOperator):
"""Gt2CenterNetTarget
Generate CenterNet targets from ground truth
Args:
down_ratio (int): The down sample ratio between output feature and
input image.
num_classes (int): The number of classes, 80 by default.
max_objs (int): The maximum objects detected, 128 by default.
"""
def __init__(self, down_ratio, num_classes=80, max_objs=128):
super(Gt2CenterNetTarget, self).__init__()
self.down_ratio = down_ratio
self.num_classes = num_classes
self.max_objs = max_objs
def __call__(self, sample, context=None):
input_h, input_w = sample['image'].shape[1:]
output_h = input_h // self.down_ratio
output_w = input_w // self.down_ratio
num_classes = self.num_classes
c = sample['center']
s = sample['scale']
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
wh = np.zeros((self.max_objs, 2), dtype=np.float32)
dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
reg = np.zeros((self.max_objs, 2), dtype=np.float32)
ind = np.zeros((self.max_objs), dtype=np.int64)
reg_mask = np.zeros((self.max_objs), dtype=np.int32)
cat_spec_wh = np.zeros(
(self.max_objs, num_classes * 2), dtype=np.float32)
cat_spec_mask = np.zeros(
(self.max_objs, num_classes * 2), dtype=np.int32)
trans_output = get_affine_transform(c, [s, s], 0, [output_w, output_h])
gt_det = []
for i, (bbox, cls) in enumerate(zip(gt_bbox, gt_class)):
cls = int(cls)
bbox[:2] = affine_transform(bbox[:2], trans_output)
bbox[2:] = affine_transform(bbox[2:], trans_output)
bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
if h > 0 and w > 0:
radius = gaussian_radius((math.ceil(h), math.ceil(w)), 0.7)
radius = max(0, int(radius))
ct = np.array(
[(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2],
dtype=np.float32)
ct_int = ct.astype(np.int32)
draw_umich_gaussian(hm[cls], ct_int, radius)
wh[i] = 1. * w, 1. * h
ind[i] = ct_int[1] * output_w + ct_int[0]
reg[i] = ct - ct_int
reg_mask[i] = 1
cat_spec_wh[i, cls * 2:cls * 2 + 2] = wh[i]
cat_spec_mask[i, cls * 2:cls * 2 + 2] = 1
gt_det.append([
ct[0] - w / 2, ct[1] - h / 2, ct[0] + w / 2, ct[1] + h / 2,
1, cls
])
sample.pop('gt_bbox', None)
sample.pop('gt_class', None)
sample.pop('center', None)
sample.pop('scale', None)
sample.pop('is_crowd', None)
sample.pop('difficult', None)
sample['heatmap'] = hm
sample['index_mask'] = reg_mask
sample['index'] = ind
sample['size'] = wh
sample['offset'] = reg
return sample
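# Shape summary for the default reader settings (512x512 input, down_ratio=4,
# i.e. a 128x128 output grid; illustrative note):
#   sample['heatmap']    -> (80, 128, 128)  per-class center heatmap
#   sample['size']       -> (128, 2)        (w, h) for up to max_objs objects
#   sample['offset']     -> (128, 2)        sub-pixel center offset ct - int(ct)
#   sample['index']      -> (128,)          flattened center index y * 128 + x
#   sample['index_mask'] -> (128,)          1 for real objects, 0 for padding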
......@@ -32,6 +32,7 @@ from .operators import BaseOperator, register_op
from .batch_operators import Gt2TTFTarget
from ppdet.modeling.bbox_utils import bbox_iou_np_expand
from ppdet.utils.logger import setup_logger
from .op_helper import gaussian_radius
logger = setup_logger(__name__)
__all__ = [
......@@ -583,7 +584,7 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
bbox_xy[3] = bbox_xy[1] + bbox_xy[3]
if h > 0 and w > 0:
radius = self.gaussian_radius((math.ceil(h), math.ceil(w)))
radius = gaussian_radius((math.ceil(h), math.ceil(w)), 0.7)
radius = max(0, int(radius))
ct = np.array([bbox[0], bbox[1]], dtype=np.float32)
ct_int = ct.astype(np.int32)
......@@ -612,25 +613,3 @@ class Gt2FairMOTTarget(Gt2TTFTarget):
sample.pop('gt_score', None)
sample.pop('gt_ide', None)
return samples
def gaussian_radius(self, det_size, min_overlap=0.7):
height, width = det_size
a1 = 1
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
r1 = (b1 + sq1) / 2
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
r2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
r3 = (b3 + sq3) / 2
return min(r1, r2, r3)
......@@ -420,19 +420,19 @@ def gaussian_radius(bbox_size, min_overlap):
b1 = (height + width)
c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
radius1 = (b1 - sq1) / (2 * a1)
radius1 = (b1 + sq1) / (2 * a1)
a2 = 4
b2 = 2 * (height + width)
c2 = (1 - min_overlap) * width * height
sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
radius2 = (b2 - sq2) / (2 * a2)
radius2 = (b2 + sq2) / 2
a3 = 4 * min_overlap
b3 = -2 * min_overlap * (height + width)
c3 = (min_overlap - 1) * width * height
sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
radius3 = (b3 + sq3) / (2 * a3)
radius3 = (b3 + sq3) / 2
return min(radius1, radius2, radius3)
......@@ -521,3 +521,33 @@ def filter_bbox(bbox, w, h, area_thr=0.25, wh_thr=2, ar_thr=20):
mask = (area_ratio > area_thr) & (
(wh > wh_thr).all(1)) & (ar_ratio < ar_thr)
return mask
def draw_umich_gaussian(heatmap, center, radius, k=1):
"""
draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126
"""
diameter = 2 * radius + 1
gaussian = gaussian2D(
(diameter, diameter), sigma_x=diameter / 6, sigma_y=diameter / 6)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, radius), min(width - x, radius + 1)
top, bottom = min(y, radius), min(height - y, radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
return heatmap
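# Usage sketch (added for illustration):
#   hm = np.zeros((8, 8), dtype=np.float32)
#   draw_umich_gaussian(hm, (4, 3), radius=2)  # center is passed as (x, y)
#   # the peak lands at hm[3, 4]; repeated draws take an element-wise maximum,
#   # so heatmaps of overlapping objects do not add up beyond the peak value.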
def get_border(border, size):
i = 1
while size - border // i <= border // i:
i *= 2
return border // i
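# get_border shrinks the crop margin until a valid range for the random crop
# center exists, i.e. [border, size - border) is non-empty. For example:
#   get_border(128, 500) -> 128
#   get_border(128, 200) -> 64
#   get_border(128, 100) -> 32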
......@@ -48,9 +48,10 @@ from .op_helper import (satisfy_sample_constraint, filter_and_process,
generate_sample_bbox, clip_bbox, data_anchor_sampling,
satisfy_sample_constraint_coverage, crop_image_sampling,
generate_sample_bbox_square, bbox_area_sampling,
is_poly, transform_bbox)
is_poly, transform_bbox, get_border)
from ppdet.utils.logger import setup_logger
from ppdet.modeling.keypoint_utils import get_affine_transform, affine_transform
logger = setup_logger(__name__)
registered_ops = []
......@@ -2962,3 +2963,189 @@ class RandomSizeCrop(BaseOperator):
region = self.get_crop_params(sample['image'].shape[:2], [h, w])
return self.crop(sample, region)
@register_op
class WarpAffine(BaseOperator):
def __init__(self,
keep_res=False,
pad=31,
input_h=512,
input_w=512,
scale=0.4,
shift=0.1):
"""WarpAffine
Warp affine the image
"""
super(WarpAffine, self).__init__()
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.scale = scale
self.shift = shift
def apply(self, sample, context=None):
img = sample['image']
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
return sample
h, w = img.shape[:2]
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
sample['image'] = inp
return sample
@register_op
class FlipWarpAffine(BaseOperator):
def __init__(self,
keep_res=False,
pad=31,
input_h=512,
input_w=512,
not_rand_crop=False,
scale=0.4,
shift=0.1,
flip=0.5,
is_scale=True,
use_random=True):
"""FlipWarpAffine
1. Random Crop
2. Flip the image horizontally
3. Warp affine the image
"""
super(FlipWarpAffine, self).__init__()
self.keep_res = keep_res
self.pad = pad
self.input_h = input_h
self.input_w = input_w
self.not_rand_crop = not_rand_crop
self.scale = scale
self.shift = shift
self.flip = flip
self.is_scale = is_scale
self.use_random = use_random
def apply(self, sample, context=None):
img = sample['image']
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0:
return sample
h, w = img.shape[:2]
if self.keep_res:
input_h = (h | self.pad) + 1
input_w = (w | self.pad) + 1
s = np.array([input_w, input_h], dtype=np.float32)
c = np.array([w // 2, h // 2], dtype=np.float32)
else:
s = max(h, w) * 1.0
input_h, input_w = self.input_h, self.input_w
c = np.array([w / 2., h / 2.], dtype=np.float32)
if self.use_random:
gt_bbox = sample['gt_bbox']
if not self.not_rand_crop:
s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
w_border = get_border(128, w)
h_border = get_border(128, h)
c[0] = np.random.randint(low=w_border, high=w - w_border)
c[1] = np.random.randint(low=h_border, high=h - h_border)
else:
sf = self.scale
cf = self.shift
c[0] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
c[1] += s * np.clip(np.random.randn() * cf, -2 * cf, 2 * cf)
s = s * np.clip(np.random.randn() * sf + 1, 1 - sf, 1 + sf)
if np.random.random() < self.flip:
img = img[:, ::-1, :]
c[0] = w - c[0] - 1
oldx1 = gt_bbox[:, 0].copy()
oldx2 = gt_bbox[:, 2].copy()
gt_bbox[:, 0] = w - oldx2 - 1
gt_bbox[:, 2] = w - oldx1 - 1
sample['gt_bbox'] = gt_bbox
trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
if not self.use_random:
img = cv2.resize(img, (w, h))
inp = cv2.warpAffine(
img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
if self.is_scale:
inp = (inp.astype(np.float32) / 255.)
sample['image'] = inp
sample['center'] = c
sample['scale'] = s
return sample
@register_op
class CenterRandColor(BaseOperator):
"""Random color for CenterNet series models.
Args:
saturation (float): saturation settings.
contrast (float): contrast settings.
brightness (float): brightness settings.
"""
def __init__(self, saturation=0.4, contrast=0.4, brightness=0.4):
super(CenterRandColor, self).__init__()
self.saturation = saturation
self.contrast = contrast
self.brightness = brightness
def apply_saturation(self, img, img_gray):
alpha = 1. + np.random.uniform(
low=-self.saturation, high=self.saturation)
self._blend(alpha, img, img_gray[:, :, None])
return img
def apply_contrast(self, img, img_gray):
alpha = 1. + np.random.uniform(low=-self.contrast, high=self.contrast)
img_mean = img_gray.mean()
self._blend(alpha, img, img_mean)
return img
def apply_brightness(self, img, img_gray):
alpha = 1 + np.random.uniform(
low=-self.brightness, high=self.brightness)
img *= alpha
return img
def _blend(self, alpha, img, img_mean):
img *= alpha
img_mean *= (1 - alpha)
img += img_mean
def __call__(self, sample, context=None):
img = sample['image']
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
functions = [
self.apply_brightness,
self.apply_contrast,
self.apply_saturation,
]
distortions = np.random.permutation(functions)
for func in distortions:
img = func(img, img_gray)
sample['image'] = img
return sample
......@@ -45,6 +45,7 @@ TRT_MIN_SUBGRAPH = {
'FairMOT': 5,
'GFL': 16,
'PicoDet': 3,
'CenterNet': 5,
}
KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
......
......@@ -37,6 +37,7 @@ class CenterNet(BaseArch):
"""
__category__ = 'architecture'
__inject__ = ['post_process']
__shared__ = ['for_mot']
def __init__(self,
backbone,
......@@ -71,6 +72,8 @@ class CenterNet(BaseArch):
head_out = self.head(neck_feat, self.inputs)
if self.for_mot:
head_out.update({'neck_feat': neck_feat})
elif self.training:
head_out['loss'] = head_out.pop('det_loss')
return head_out
def get_pred(self):
......
......@@ -28,7 +28,7 @@ class BaseArch(nn.Layer):
'mean']).reshape((1, 3, 1, 1))
self.std = paddle.to_tensor(item['NormalizeImage'][
'std']).reshape((1, 3, 1, 1))
if item['NormalizeImage']['is_scale']:
if item['NormalizeImage'].get('is_scale', True):
self.scale = 1. / 255.
break
if self.data_format == 'NHWC':
......
......@@ -16,7 +16,7 @@ import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingUniform
from paddle.nn.initializer import Constant, Uniform
from ppdet.core.workspace import register
from ppdet.modeling.losses import CTFocalLoss
......@@ -35,10 +35,9 @@ class ConvLayer(nn.Layer):
bias_attr = False
fan_in = ch_in * kernel_size**2
bound = 1 / math.sqrt(fan_in)
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
param_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
if bias:
bias_attr = paddle.ParamAttr(
initializer=nn.initializer.Uniform(-bound, bound))
bias_attr = paddle.ParamAttr(initializer=Constant(0.))
self.conv = nn.Conv2D(
in_channels=ch_in,
out_channels=ch_out,
......
......@@ -48,8 +48,7 @@ def get_affine_transform(center,
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
input_size (np.ndarray[2, ]): Size of input feature (width, height).
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
......@@ -61,10 +60,11 @@ def get_affine_transform(center,
np.ndarray: The transform matrix.
"""
assert len(center) == 2
assert len(input_size) == 2
assert len(output_size) == 2
assert len(shift) == 2
if not isinstance(input_size, (np.ndarray, list)):
input_size = np.array([input_size, input_size], dtype=np.float32)
scale_tmp = input_size
shift = np.array(shift)
......@@ -77,6 +77,7 @@ def get_affine_transform(center,
dst_dir = np.array([0., dst_w * -0.5])
src = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
......@@ -108,8 +109,10 @@ def _get_3rd_point(a, b):
Returns:
np.ndarray: The 3rd point.
"""
assert len(a) == 2
assert len(b) == 2
assert len(
a) == 2, 'input of _get_3rd_point should be point with length of 2'
assert len(
b) == 2, 'input of _get_3rd_point should be point with length of 2'
direction = a - b
third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
......
......@@ -53,11 +53,12 @@ class CTFocalLoss(object):
bg_map = paddle.cast(target < 1, 'float32')
bg_map.stop_gradient = True
neg_weights = paddle.pow(1 - target, 4) * bg_map
neg_weights = paddle.pow(1 - target, 4)
pos_loss = 0 - paddle.log(pred) * paddle.pow(1 - pred,
self.gamma) * fg_map
neg_loss = 0 - paddle.log(1 - pred) * paddle.pow(
pred, self.gamma) * neg_weights
pred, self.gamma) * neg_weights * bg_map
pos_loss = paddle.sum(pos_loss)
neg_loss = paddle.sum(neg_loss)
......
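For reference, the lines above implement the penalty-reduced pixel-wise focal loss from the CenterNet paper, with `gamma` as the focusing exponent α (2 in the paper) and the fixed exponent 4 as β; up to normalization by the number of positives it reads:
```
L = -\sum_{xyc}
\begin{cases}
(1 - \hat{Y}_{xyc})^{\alpha} \log \hat{Y}_{xyc} & \text{if } Y_{xyc} = 1 \\
(1 - Y_{xyc})^{\beta} \, \hat{Y}_{xyc}^{\alpha} \log(1 - \hat{Y}_{xyc}) & \text{otherwise}
\end{cases}
```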
......@@ -16,8 +16,9 @@ import numpy as np
import math
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn.initializer import Uniform
import paddle.nn.functional as F
from paddle.nn.initializer import KaimingUniform
from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.backbones.hardnet import ConvLayer, HarDBlock
......@@ -27,7 +28,7 @@ __all__ = ['CenterNetDLAFPN', 'CenterNetHarDNetFPN']
def fill_up_weights(up):
weight = up.weight
weight = up.weight.numpy()
f = math.ceil(weight.shape[2] / 2)
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(weight.shape[2]):
......@@ -36,6 +37,7 @@ def fill_up_weights(up):
(1 - math.fabs(i / f - c)) * (1 - math.fabs(j / f - c))
for c in range(1, weight.shape[0]):
weight[c, 0, :, :] = weight[0, 0, :, :]
up.weight.set_value(weight)
class IDAUp(nn.Layer):
......@@ -44,6 +46,8 @@ class IDAUp(nn.Layer):
for i in range(1, len(ch_ins)):
ch_in = ch_ins[i]
up_s = int(up_strides[i])
fan_in = ch_in * 3 * 3
stdv = 1. / math.sqrt(fan_in)
proj = nn.Sequential(
ConvNormLayer(
ch_in,
......@@ -54,7 +58,8 @@ class IDAUp(nn.Layer):
bias_on=dcn_v2,
norm_decay=None,
dcn_lr_scale=1.,
dcn_regularizer=None),
dcn_regularizer=None,
initializer=Uniform(-stdv, stdv)),
nn.ReLU())
node = nn.Sequential(
ConvNormLayer(
......@@ -66,21 +71,23 @@ class IDAUp(nn.Layer):
bias_on=dcn_v2,
norm_decay=None,
dcn_lr_scale=1.,
dcn_regularizer=None),
dcn_regularizer=None,
initializer=Uniform(-stdv, stdv)),
nn.ReLU())
param_attr = paddle.ParamAttr(initializer=KaimingUniform())
kernel_size = up_s * 2
fan_in = ch_out * kernel_size * kernel_size
stdv = 1. / math.sqrt(fan_in)
up = nn.Conv2DTranspose(
ch_out,
ch_out,
kernel_size=up_s * 2,
weight_attr=param_attr,
stride=up_s,
padding=up_s // 2,
groups=ch_out,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
bias_attr=False)
# TODO: uncomment fill_up_weights
#fill_up_weights(up)
fill_up_weights(up)
setattr(self, 'proj_' + str(i), proj)
setattr(self, 'up_' + str(i), up)
setattr(self, 'node_' + str(i), node)
......@@ -138,9 +145,10 @@ class CenterNetDLAFPN(nn.Layer):
last_level (int): the last level of input feature fed into the upsampling block
out_channel (int): the channel of the output feature, 0 by default means
the channel of the input feature whose down ratio is `down_ratio`
first_level (int): the first level of input feature fed into the upsampling
block, -1 by default and it will be calculated by down_ratio
dcn_v2 (bool): whether use the DCNv2, true by default
first_level (int|None): the first level of input feature fed into the upsampling block.
if None, the first level is log2(down_ratio)
"""
def __init__(self,
......@@ -148,11 +156,13 @@ class CenterNetDLAFPN(nn.Layer):
down_ratio=4,
last_level=5,
out_channel=0,
first_level=-1,
dcn_v2=True):
dcn_v2=True,
first_level=None):
super(CenterNetDLAFPN, self).__init__()
self.first_level = int(np.log2(
down_ratio)) if first_level == -1 else first_level
down_ratio)) if first_level is None else first_level
assert self.first_level >= 0, "first level in CenterNetDLAFPN should be greater than or equal to 0, but received {}".format(
self.first_level)
self.down_ratio = down_ratio
self.last_level = last_level
scales = [2**i for i in range(len(in_channels[self.first_level:]))]
......@@ -212,8 +222,9 @@ class CenterNetHarDNetFPN(nn.Layer):
[96, 214, 458, 784] by default, means the channels of HarDNet85
num_layers (int): HarDNet layers, 85 by default
down_ratio (int): the down ratio from images to heatmap, 4 by default
first_level (int): the first level of input feature fed into the
upsampling block
first_level (int|None): the first level of input feature fed into the upsampling block.
if None, the first level is log2(down_ratio) - 1
last_level (int): the last level of input feature fed into the upsampling block
out_channel (int): the channel of the output feature, 0 by default means
the channel of the input feature whose down ratio is `down_ratio`
......@@ -223,17 +234,20 @@ class CenterNetHarDNetFPN(nn.Layer):
in_channels,
num_layers=85,
down_ratio=4,
first_level=-1,
first_level=None,
last_level=4,
out_channel=0):
super(CenterNetHarDNetFPN, self).__init__()
self.first_level = int(np.log2(
down_ratio)) - 1 if first_level == -1 else first_level
down_ratio)) - 1 if first_level is None else first_level
assert self.first_level >= 0, "first level in CenterNetHarDNetFPN should be greater than or equal to 0, but received {}".format(
self.first_level)
self.down_ratio = down_ratio
self.last_level = last_level
self.last_pool = nn.AvgPool2D(kernel_size=2, stride=2)
assert num_layers in [68, 85], "HarDNet-{} is not supported.".format(num_layers)
assert num_layers in [68, 85], "HarDNet-{} is not supported.".format(
num_layers)
if num_layers == 85:
self.last_proj = ConvLayer(784, 256, kernel_size=1)
self.last_blk = HarDBlock(768, 80, 1.7, 8)
......
......@@ -418,7 +418,7 @@ class CenterNetPostProcess(TTFBox):
"""
__shared__ = ['down_ratio']
__shared__ = ['down_ratio', 'for_mot']
def __init__(self,
max_per_img=500,
......