Unverified commit d43e6d9a authored by wangguanzhong, committed by GitHub

add ttfnet (#1054)

Parent 315fd738
......@@ -12,10 +12,12 @@
## Model Zoo and Baselines
The table below lists the network architectures currently supported by PaddleDetection; see [Algorithm Details](#算法细节) for specifics.
|                          | ResNet50 | ResNet50-vd | Hourglass104 | DarkNet53 |
|:------------------------:|:--------:|:-----------:|:------------:|:---------:|
| [CornerNet-Squeeze](#CornerNet-Squeeze) | x | ✓ | ✓ | x |
| [FCOS](#FCOS) | ✓ | x | x | x |
| [TTFNet](#TTFNet) | x | x | x | ✓ |
### Model Zoo
......@@ -31,6 +33,7 @@
| FCOS | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 39.8 | 18.85 | [Download link](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_1x.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_1x.yml) |
| FCOS+multiscale_train | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 42.0 | 19.05 | [Download link](https://paddlemodels.bj.bcebos.com/object_detection/fcos_r50_fpn_multiscale_2x.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_r50_fpn_multiscale_2x.yml) |
| FCOS+DCN | ResNet50 | 2 | [ResNet50\_cos\_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_cos_pretrained.tar) | 44.4 | 13.66 | [Download link](https://paddlemodels.bj.bcebos.com/object_detection/fcos_dcn_r50_fpn_1x.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/fcos_dcn_r50_fpn_1x.yml) |
| TTFNet | DarkNet53 | 12 | [DarkNet53_pretrained](https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar) | 32.9 | 85.92 | [Download link](https://paddlemodels.bj.bcebos.com/object_detection/ttfnet_darknet.pdparams) | [Config file](https://github.com/PaddlePaddle/PaddleDetection/tree/master/configs/anchor_free/ttfnet_darknet.yml) |
**Notes:**
......@@ -64,5 +67,14 @@
- A single-layer center-ness branch predicts whether each location is an object center, suppressing low-quality false positives
## TTFNet
**Introduction:** [TTFNet](https://arxiv.org/abs/1909.00700) is a training-time-friendly network for real-time object detection. It tackles the slow convergence of CenterNet with a new way of generating training samples from Gaussian kernels (sketched below), which effectively removes the ambiguity in the anchor-free head. Its simple, lightweight architecture also makes it easy to extend to other tasks.
**Highlights:**
- Simple structure: only two heads are needed to predict object location and size, and time-consuming post-processing is removed
- Short training time: with a DarkNet53 backbone, about 2 hours of training on 8 V100 GPUs already yields a solid model
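The gaussian-kernel idea can be sketched in a few lines of numpy (a simplified illustration only; the actual implementation is the `Gt2TTFTarget` batch transform added in this commit):

```python
import numpy as np

def gaussian_heatmap(feat_size, box, alpha=0.54):
    """Every pixel inside a truncated gaussian ellipse around the box center
    becomes a positive training sample, weighted by the gaussian value
    (vs. a single center pixel in CenterNet)."""
    x1, y1, x2, y2 = box  # box coordinates on the feature map
    w_r = max(int((x2 - x1) / 2. * alpha), 1)  # truncated gaussian radii
    h_r = max(int((y2 - y1) / 2. * alpha), 1)
    cx, cy = int((x1 + x2) / 2.), int((y1 + y2) / 2.)
    y, x = np.ogrid[-cy:feat_size - cy, -cx:feat_size - cx]
    heat = np.exp(-(x ** 2 / (2. * (w_r / 3.) ** 2) +
                    y ** 2 / (2. * (h_r / 3.) ** 2)))
    heat[heat < np.finfo(heat.dtype).eps * heat.max()] = 0
    return heat

# a 16x12 box on a 32x32 feature map yields dozens of positive samples
print((gaussian_heatmap(32, (4, 4, 20, 16)) > 0).sum())
```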
## How to Contribute Code
Contributions of anchor-free detection models to PaddleDetection are very welcome: submit a PR and we will review it. Feedback is equally appreciated; file an issue and we will respond promptly.
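The new training config added by this commit follows (its path, per the model zoo table above, is `configs/anchor_free/ttfnet_darknet.yml`):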
architecture: TTFNet
use_gpu: true
max_iters: 15000
log_smooth_window: 20
save_dir: output
snapshot_iter: 1000
metric: COCO
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
weights: output/ttfnet_darknet/model_final
num_classes: 80
TTFNet:
backbone: DarkNet
ttf_head: TTFHead
DarkNet:
norm_type: bn
norm_decay: 0.0004
depth: 53
freeze_at: 1
TTFHead:
head_conv: 128
wh_conv: 64
hm_head_conv_num: 2
wh_head_conv_num: 2
wh_offset_base: 16
wh_loss: GiouLoss
GiouLoss:
loss_weight: 5.
do_average: false
use_class_weight: false
LearningRate:
base_lr: 0.015
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 11250
- 13750
- !LinearWarmup
start_factor: 0.2
steps: 500
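# i.e. the lr warms up linearly from 0.2 * 0.015 = 0.003 to 0.015 over the first 500 iterations, then decays 10x at iterations 11250 and 13750 (of 15000 total)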
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0004
type: L2
TrainReader:
inputs_def:
fields: ['image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight']
dataset:
!COCODataSet
image_dir: train2017
anno_path: annotations/instances_train2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
- !Resize
target_dim: 512
- !RandomFlipImage
prob: 0.5
- !NormalizeImage
is_channel_first: false
is_scale: false
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
- !Permute
to_bgr: false
channel_first: true
batch_transforms:
- !Gt2TTFTarget
num_classes: 80
down_ratio: 4
- !PadBatch
pad_to_stride: 32
batch_size: 12
shuffle: true
worker_num: 8
bufsize: 2
use_process: true
EvalReader:
inputs_def:
image_shape: [3, 512, 512]
fields: ['image', 'im_id', 'scale_factor']
dataset:
!COCODataSet
image_dir: val2017
anno_path: annotations/instances_val2017.json
dataset_dir: dataset/coco
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
- !Resize
target_dim: 512
- !NormalizeImage
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
is_scale: false
is_channel_first: false
- !Permute
to_bgr: false
channel_first: true
batch_size: 1
drop_empty: false
worker_num: 8
bufsize: 16
TestReader:
inputs_def:
image_shape: [3, 512, 512]
fields: ['image', 'im_id', 'scale_factor']
dataset:
!ImageFolder
anno_path: annotations/instances_val2017.json
with_background: false
sample_transforms:
- !DecodeImage
to_rgb: true
- !Resize
interp: 1
target_dim: 512
- !NormalizeImage
mean: [123.675, 116.28, 103.53]
std: [58.395, 57.12, 57.375]
is_scale: false
is_channel_first: false
- !Permute
to_bgr: false
channel_first: true
batch_size: 1
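For reference, a minimal numpy/cv2 sketch of what the eval pipeline above does to a single image (assuming an RGB uint8 input, and assuming `Resize` with an integer `target_dim` produces a square `target_dim` x `target_dim` image, as `image_shape: [3, 512, 512]` suggests):

```python
import cv2
import numpy as np

def preprocess(img_rgb, dim=512,
               mean=(123.675, 116.28, 103.53),
               std=(58.395, 57.12, 57.375)):
    h, w = img_rgb.shape[:2]
    im = cv2.resize(img_rgb, (dim, dim))        # Resize: target_dim 512
    scale_factor = np.array([dim / w, dim / h] * 2, dtype=np.float32)
    mean = np.array(mean, dtype=np.float32)
    std = np.array(std, dtype=np.float32)
    im = (im.astype(np.float32) - mean) / std   # NormalizeImage, is_scale: false
    im = im.transpose((2, 0, 1))[np.newaxis]    # Permute to CHW, add batch dim
    return im, scale_factor
```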
......@@ -115,8 +115,7 @@ class Resize(object):
padding_im[:im_h, :im_w, :] = im
im = padding_im
if self.arch in self.scale_set:
im_info['scale'] = [im_scale_x, im_scale_y]
im_info['resize_shape'] = im.shape[:2]
return im, im_info
......@@ -252,18 +251,23 @@ def create_inputs(im, im_info, model_arch='YOLO'):
inputs['image'] = im
origin_shape = list(im_info['origin_shape'])
resize_shape = list(im_info['resize_shape'])
scale_x, scale_y = im_info['scale']
if 'YOLO' in model_arch:
im_size = np.array([origin_shape]).astype('int32')
inputs['im_size'] = im_size
elif 'RetinaNet' in model_arch:
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
inputs['im_info'] = im_info
elif 'RCNN' in model_arch:
scale = scale_x
im_info = np.array([resize_shape + [scale]]).astype('float32')
im_shape = np.array([origin_shape + [1.]]).astype('float32')
inputs['im_info'] = im_info
inputs['im_shape'] = im_shape
elif 'TTF' in model_arch:
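# TTFNet only needs a per-image (scale_x, scale_y, scale_x, scale_y) factor; get_bboxes divides the predicted boxes by it to map them back to the original image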
scale_factor = np.array([scale_x, scale_y] * 2).astype('float32')
inputs['scale_factor'] = scale_factor
return inputs
......@@ -272,7 +276,7 @@ class Config():
Args:
model_dir (str): root path of model.yml
"""
support_models = ['YOLO', 'SSD', 'RetinaNet', 'RCNN', 'Face', 'TTF']
def __init__(self, model_dir):
# parsing Yaml config for Preprocess
......@@ -298,9 +302,8 @@ class Config():
for support_model in self.support_models:
if support_model in yml_conf['arch']:
return True
raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
'arch'], self.support_models))
def print_config(self):
print('----------- Model Configuration -----------')
......@@ -450,7 +453,7 @@ class Detector():
np_boxes[:, 3] *= w
np_boxes[:, 4] *= h
np_boxes[:, 5] *= w
expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
np_boxes = np_boxes[expect_boxes, :]
for box in np_boxes:
print('class_id:{:d}, confidence:{:.2f},'
......
......@@ -26,13 +26,17 @@ import cv2
import numpy as np
from .operators import register_op, BaseOperator
from .op_helper import jaccard_overlap, gaussian2D
logger = logging.getLogger(__name__)
__all__ = [
'PadBatch',
'RandomShape',
'PadMultiScaleTest',
'Gt2YoloTarget',
'Gt2FCOSTarget',
'Gt2TTFTarget',
]
......@@ -41,7 +45,6 @@ class PadBatch(BaseOperator):
"""
Pad a batch of samples so they can be divisible by a stride.
The layout of each image should be 'CHW'.
Args:
pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
height and width is divisible by `pad_to_stride`.
......@@ -89,13 +92,12 @@ class RandomShape(BaseOperator):
randomly select an interpolation algorithm from [cv2.INTER_NEAREST,
cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4];
if random_inter is False, use cv2.INTER_NEAREST.
Args:
sizes (list): list of int, randomly choose a size from these
random_inter (bool): whether to use random interpolation, default False.
resize_box (bool): whether to resize gt_bbox with the image, default False.
"""
def __init__(self, sizes=[], random_inter=False, resize_box=False):
super(RandomShape, self).__init__()
self.sizes = sizes
self.random_inter = random_inter
......@@ -106,6 +108,7 @@ class RandomShape(BaseOperator):
cv2.INTER_CUBIC,
cv2.INTER_LANCZOS4,
] if random_inter else []
self.resize_box = resize_box
def __call__(self, samples, context=None):
shape = np.random.choice(self.sizes)
......@@ -119,6 +122,12 @@ class RandomShape(BaseOperator):
im = cv2.resize(
im, None, None, fx=scale_x, fy=scale_y, interpolation=method)
samples[i]['image'] = im
if self.resize_box and 'gt_bbox' in samples[i] and len(samples[i]['gt_bbox']) > 0:
scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
samples[i]['gt_bbox'] = np.clip(samples[i]['gt_bbox'] * scale_array, 0, float(shape) - 1)
return samples
......@@ -478,3 +487,99 @@ class Gt2FCOSTarget(BaseOperator):
sample['centerness{}'.format(lvl)] = np.reshape(
ctn_targets_by_level[lvl], newshape=[grid_h, grid_w, 1])
return samples
@register_op
class Gt2TTFTarget(BaseOperator):
"""
Gt2TTFTarget
Generate TTFNet training targets from ground-truth data
Args:
num_classes(int): the number of classes.
down_ratio(int): the down ratio from images to heatmap, 4 by default.
alpha(float): the alpha parameter to generate gaussian target.
0.54 by default.
"""
def __init__(self, num_classes, down_ratio=4, alpha=0.54):
super(Gt2TTFTarget, self).__init__()
self.down_ratio = down_ratio
self.num_classes = num_classes
self.alpha = alpha
def __call__(self, samples, context=None):
output_size = samples[0]['image'].shape[1]
feat_size = output_size // self.down_ratio
for sample in samples:
heatmap = np.zeros(
(self.num_classes, feat_size, feat_size), dtype='float32')
box_target = np.ones(
(4, feat_size, feat_size), dtype='float32') * -1
reg_weight = np.zeros((1, feat_size, feat_size), dtype='float32')
gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class']
bbox_w = gt_bbox[:, 2] - gt_bbox[:, 0] + 1
bbox_h = gt_bbox[:, 3] - gt_bbox[:, 1] + 1
area = bbox_w * bbox_h
boxes_areas_log = np.log(area)
boxes_ind = np.argsort(boxes_areas_log, axis=0)[::-1]
boxes_area_topk_log = boxes_areas_log[boxes_ind]
gt_bbox = gt_bbox[boxes_ind]
gt_class = gt_class[boxes_ind]
feat_gt_bbox = gt_bbox / self.down_ratio
feat_gt_bbox = np.clip(feat_gt_bbox, 0, feat_size - 1)
feat_hs, feat_ws = (feat_gt_bbox[:, 3] - feat_gt_bbox[:, 1],
feat_gt_bbox[:, 2] - feat_gt_bbox[:, 0])
ct_inds = np.stack(
[(gt_bbox[:, 0] + gt_bbox[:, 2]) / 2,
(gt_bbox[:, 1] + gt_bbox[:, 3]) / 2],
axis=1) / self.down_ratio
h_radiuses_alpha = (feat_hs / 2. * self.alpha).astype('int32')
w_radiuses_alpha = (feat_ws / 2. * self.alpha).astype('int32')
for k in range(len(gt_bbox)):
cls_id = gt_class[k]
fake_heatmap = np.zeros((feat_size, feat_size), dtype='float32')
self.draw_truncate_gaussian(fake_heatmap, ct_inds[k],
h_radiuses_alpha[k],
w_radiuses_alpha[k])
heatmap[cls_id] = np.maximum(heatmap[cls_id], fake_heatmap)
box_target_inds = fake_heatmap > 0
box_target[:, box_target_inds] = gt_bbox[k][:, None]
local_heatmap = fake_heatmap[box_target_inds]
ct_div = np.sum(local_heatmap)
local_heatmap *= boxes_area_topk_log[k]
reg_weight[0, box_target_inds] = local_heatmap / ct_div
sample['ttf_heatmap'] = heatmap
sample['ttf_box_target'] = box_target
sample['ttf_reg_weight'] = reg_weight
return samples
def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius):
h, w = 2 * h_radius + 1, 2 * w_radius + 1
sigma_x = w / 6
sigma_y = h / 6
gaussian = gaussian2D((h, w), sigma_x, sigma_y)
x, y = int(center[0]), int(center[1])
height, width = heatmap.shape[0:2]
left, right = min(x, w_radius), min(width - x, w_radius + 1)
top, bottom = min(y, h_radius), min(height - y, h_radius + 1)
masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
masked_gaussian = gaussian[h_radius - top:h_radius + bottom, w_radius - left:w_radius + right]
if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
heatmap[y - top:y + bottom, x - left:x + right] = np.maximum(
masked_heatmap, masked_gaussian)
return heatmap
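A quick standalone check of the target drawing above (a sketch, using `Gt2TTFTarget` and `gaussian2D` as defined in this commit):

```python
import numpy as np

op = Gt2TTFTarget(num_classes=80)
heatmap = np.zeros((16, 16), dtype='float32')
op.draw_truncate_gaussian(heatmap, center=(8, 8), h_radius=3, w_radius=5)
print(heatmap[8, 8])        # 1.0 at the box center
print((heatmap > 0).sum())  # every pixel of the truncated ellipse is a positive sample
```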
......@@ -438,7 +438,8 @@ def gaussian_radius(bbox_size, min_overlap):
def draw_gaussian(heatmap, center, radius, k=1, delte=6):
diameter = 2 * radius + 1
sigma = diameter / delte
gaussian = gaussian2D((diameter, diameter), sigma_x=sigma, sigma_y=sigma)
x, y = center
......@@ -453,10 +454,11 @@ def draw_gaussian(heatmap, center, radius, k=1, delte=6):
np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
def gaussian2D(shape, sigma_x=1, sigma_y=1):
m, n = [(ss - 1.) / 2. for ss in shape]
y, x = np.ogrid[-m:m + 1, -n:n + 1]
h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y * sigma_y)))
h[h < np.finfo(h.dtype).eps * h.max()] = 0
return h
......@@ -92,7 +92,6 @@ class BaseOperator(object):
class DecodeImage(BaseOperator):
def __init__(self, to_rgb=True, with_mixup=False, with_cutmix=False):
""" Transform the image data to numpy format.
Args:
to_rgb (bool): whether to convert BGR to RGB
with_mixup (bool): whether or not to mixup image and gt_bbox/gt_score
......@@ -165,7 +164,6 @@ class MultiscaleTestResize(BaseOperator):
use_flip=True):
"""
Rescale image to each size in target size, capped at max_size.
Args:
origin_target_size(int): original target size of image's short side.
origin_max_size(int): original max size of image.
......@@ -274,7 +272,6 @@ class ResizeImage(BaseOperator):
if max_size != 0.
If target_size is a list, a scale is randomly selected as the
target size.
Args:
target_size (int|list): the target size of image's short side,
multi-scale training is adopted when type is list.
......@@ -1177,7 +1174,6 @@ class Permute(BaseOperator):
Args:
to_bgr (bool): whether to convert RGB to BGR
channel_first (bool): whether to put the channel dimension first
"""
super(Permute, self).__init__()
self.to_bgr = to_bgr
......@@ -1386,7 +1382,6 @@ class RandomInterpImage(BaseOperator):
@register_op
class Resize(BaseOperator):
"""Resize image and bbox.
Args:
target_dim (int or list): target size, can be a single number or a list
(for random shape).
......@@ -1419,6 +1414,7 @@ class Resize(BaseOperator):
scale_array = np.array([scale_x, scale_y] * 2, dtype=np.float32)
sample['gt_bbox'] = np.clip(sample['gt_bbox'] * scale_array, 0,
dim - 1)
sample['scale_factor'] = [scale_x, scale_y] * 2
sample['h'] = resize_h
sample['w'] = resize_w
......@@ -1430,7 +1426,6 @@ class Resize(BaseOperator):
@register_op
class ColorDistort(BaseOperator):
"""Random color distortion.
Args:
hue (list): hue settings.
in [lower, upper, probability] format.
......@@ -1442,6 +1437,8 @@ class ColorDistort(BaseOperator):
in [lower, upper, probability] format.
random_apply (bool): whether to apply in random (yolo) or fixed (SSD)
order.
hsv_format (bool): whether to apply hue/saturation in HSV color space (RGB to HSV)
random_channel (bool): whether to swap image channels randomly
"""
def __init__(self,
......@@ -1449,13 +1446,17 @@ class ColorDistort(BaseOperator):
saturation=[0.5, 1.5, 0.5],
contrast=[0.5, 1.5, 0.5],
brightness=[0.5, 1.5, 0.5],
random_apply=True,
hsv_format=False,
random_channel=False):
super(ColorDistort, self).__init__()
self.hue = hue
self.saturation = saturation
self.contrast = contrast
self.brightness = brightness
self.random_apply = random_apply
self.hsv_format = hsv_format
self.random_channel = random_channel
def apply_hue(self, img):
low, high, prob = self.hue
......@@ -1463,6 +1464,11 @@ class ColorDistort(BaseOperator):
return img
img = img.astype(np.float32)
if self.hsv_format:
img[..., 0] += random.uniform(low, high)
img[..., 0][img[..., 0] > 360] -= 360
img[..., 0][img[..., 0] < 0] += 360
return img
# XXX works, but results differ from the HSV version
delta = np.random.uniform(low, high)
......@@ -1482,8 +1488,10 @@ class ColorDistort(BaseOperator):
if np.random.uniform(0., 1.) < prob:
return img
delta = np.random.uniform(low, high)
img = img.astype(np.float32)
if self.hsv_format:
img[..., 1] *= delta
return img
gray = img * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
gray = gray.sum(axis=2, keepdims=True)
gray *= (1.0 - delta)
......@@ -1530,12 +1538,24 @@ class ColorDistort(BaseOperator):
if np.random.randint(0, 2):
img = self.apply_contrast(img)
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
img = self.apply_saturation(img)
img = self.apply_hue(img)
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
else:
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
img = self.apply_saturation(img)
img = self.apply_hue(img)
if self.hsv_format:
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
img = self.apply_contrast(img)
if self.random_channel:
if np.random.randint(0, 2):
img = img[..., np.random.permutation(3)]
sample['image'] = img
return sample
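A minimal usage sketch of the new options (assuming the op is applied to a decoded sample dict, as elsewhere in this module):

```python
import numpy as np

op = ColorDistort(random_apply=False, hsv_format=True, random_channel=True)
sample = {'image': np.random.rand(480, 640, 3).astype(np.float32) * 255}
out = op(sample)  # hue/saturation applied in HSV space; channels may be permuted
```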
......@@ -1603,7 +1623,6 @@ class CornerRandColor(ColorDistort):
@register_op
class NormalizePermute(BaseOperator):
"""Normalize and permute channel order.
Args:
mean (list): mean values in RGB order.
std (list): std values in RGB order.
......@@ -1633,7 +1652,6 @@ class NormalizePermute(BaseOperator):
@register_op
class RandomExpand(BaseOperator):
"""Random expand the canvas.
Args:
ratio (float): maximum expansion ratio.
prob (float): probability to expand.
......@@ -1725,7 +1743,6 @@ class RandomExpand(BaseOperator):
@register_op
class RandomCrop(BaseOperator):
"""Random crop image and bboxes.
Args:
aspect_ratio (list): aspect ratio of cropped region.
in [min, max] format.
......@@ -1852,11 +1869,23 @@ class RandomCrop(BaseOperator):
found = False
for i in range(self.num_attempts):
scale = np.random.uniform(*self.scaling)
if self.aspect_ratio is not None:
min_ar, max_ar = self.aspect_ratio
aspect_ratio = np.random.uniform(
max(min_ar, scale**2), min(max_ar, scale**-2))
h_scale = scale / np.sqrt(aspect_ratio)
w_scale = scale * np.sqrt(aspect_ratio)
else:
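# no aspect-ratio range given: sample the two scales independently; crops with an extreme h/w ratio are rejected below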
h_scale = np.random.uniform(*self.scaling)
w_scale = np.random.uniform(*self.scaling)
crop_h = h * h_scale
crop_w = w * w_scale
if self.aspect_ratio is None:
if crop_h / crop_w < 0.5 or crop_h / crop_w > 2.0:
continue
crop_h = int(crop_h)
crop_w = int(crop_w)
crop_y = np.random.randint(0, h - crop_h)
crop_x = np.random.randint(0, w - crop_w)
crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
......@@ -2008,7 +2037,6 @@ class BboxXYXY2XYWH(BaseOperator):
return sample
@register_op
class Lighting(BaseOperator):
"""
Lighting the image by eigenvalues and eigenvectors
......@@ -2248,7 +2276,6 @@ class CornerRatio(BaseOperator):
class RandomScaledCrop(BaseOperator):
"""Resize image and bbox based on long side (with optional random scaling),
then crop or pad image to target size.
Args:
target_dim (int): target size.
scale_range (list): random scale range.
......@@ -2303,7 +2330,6 @@ class RandomScaledCrop(BaseOperator):
@register_op
class ResizeAndPad(BaseOperator):
"""Resize image and bbox, then pad image to target size.
Args:
target_dim (int): target size
interp (int): interpolation method, default to `cv2.INTER_LINEAR`.
......@@ -2342,7 +2368,6 @@ class ResizeAndPad(BaseOperator):
@register_op
class TargetAssign(BaseOperator):
"""Assign regression target and labels.
Args:
image_size (int or list): input image size, a single integer or list of
[h, w]. Default: 512
......
......@@ -20,6 +20,7 @@ from . import retina_head
from . import fcos_head
from . import corner_head
from . import efficient_head
from . import ttf_head
from .rpn_head import *
from .yolo_head import *
......@@ -27,3 +28,4 @@ from .retina_head import *
from .fcos_head import *
from .corner_head import *
from .efficient_head import *
from .ttf_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.initializer import Normal, Constant, Uniform, Xavier
from paddle.fluid.regularizer import L2Decay
from ppdet.core.workspace import register
from ppdet.modeling.ops import DeformConv, DropBlock
from ppdet.modeling.losses import GiouLoss
__all__ = ['TTFHead']
@register
class TTFHead(object):
"""
TTFHead
Args:
head_conv(int): the default channel number of convolution in head.
128 by default.
num_classes(int): the number of classes, 80 by default.
hm_weight(float): the weight of heatmap branch. 1. by default.
wh_weight(float): the weight of wh branch. 5. by default.
wh_offset_base(float): the base offset of width and height.
16. by default.
planes(tuple): the channel number of convolution in each upsample.
(256, 128, 64) by default.
shortcut_num(tuple): the number of convolution layers in each shortcut.
(1, 2, 3) by default.
wh_head_conv_num(int): the number of convolution layers in wh head.
2 by default.
hm_head_conv_num(int): the number of convolution layers in hm head.
2 by default.
wh_conv(int): the channel number of convolution in wh head.
64 by default.
wh_planes(int): the output channel in wh head. 4 by default.
score_thresh(float): the score threshold to get prediction.
0.01 by default.
max_per_img(int): the maximum detection per image. 100 by default.
base_down_ratio(int): the base down_ratio, the actual down_ratio is
calculated by base_down_ratio and the number of upsample layers.
32 by default.
wh_loss(object): `GiouLoss` instance.
dcn_upsample(bool): whether upsample by dcn. True by default.
dcn_head(bool): whether use dcn in head. False by default.
drop_block(bool): whether use dropblock. False by default.
block_size(int): block_size parameter for drop_block. 3 by default.
keep_prob(float): keep_prob parameter for drop_block. 0.9 by default.
"""
__inject__ = ['wh_loss']
__shared__ = ['num_classes']
def __init__(self,
head_conv=128,
num_classes=80,
hm_weight=1.,
wh_weight=5.,
wh_offset_base=16.,
planes=(256, 128, 64),
shortcut_num=(1, 2, 3),
wh_head_conv_num=2,
hm_head_conv_num=2,
wh_conv=64,
wh_planes=4,
score_thresh=0.01,
max_per_img=100,
base_down_ratio=32,
wh_loss='GiouLoss',
dcn_upsample=True,
dcn_head=False,
drop_block=False,
block_size=3,
keep_prob=0.9):
super(TTFHead, self).__init__()
self.head_conv = head_conv
self.num_classes = num_classes
self.hm_weight = hm_weight
self.wh_weight = wh_weight
self.wh_offset_base = wh_offset_base
self.planes = planes
self.shortcut_num = shortcut_num
self.shortcut_len = len(shortcut_num)
self.wh_head_conv_num = wh_head_conv_num
self.hm_head_conv_num = hm_head_conv_num
self.wh_conv = wh_conv
self.wh_planes = wh_planes
self.score_thresh = score_thresh
self.max_per_img = max_per_img
self.down_ratio = base_down_ratio // 2**len(planes)
self.wh_loss = wh_loss
self.dcn_upsample = dcn_upsample
self.dcn_head = dcn_head
self.drop_block = drop_block
self.block_size = block_size
self.keep_prob = keep_prob
def shortcut(self, x, out_c, layer_num, kernel_size=3, padding=1,
name=None):
assert layer_num > 0
for i in range(layer_num):
act = 'relu' if i < layer_num - 1 else None
fan_out = kernel_size * kernel_size * out_c
std = math.sqrt(2. / fan_out)
param_name = name + '.layers.' + str(i * 2)
x = fluid.layers.conv2d(
x,
out_c,
kernel_size,
padding=padding,
act=act,
param_attr=ParamAttr(
initializer=Normal(0, std), name=param_name + '.weight'),
bias_attr=ParamAttr(
learning_rate=2.,
regularizer=L2Decay(0.),
name=param_name + '.bias'))
return x
def upsample(self, x, out_c, name=None):
fan_in = x.shape[1] * 3 * 3
stdv = 1. / math.sqrt(fan_in)
if self.dcn_upsample:
conv = DeformConv(
x,
out_c,
3,
initializer=Uniform(-stdv, stdv),
bias_attr=True,
name=name + '.0')
else:
conv = fluid.layers.conv2d(
x,
out_c,
3,
padding=1,
param_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr(
learning_rate=2., regularizer=L2Decay(0.)))
norm_name = name + '.1'
pattr = ParamAttr(name=norm_name + '.weight', initializer=Constant(1.))
battr = ParamAttr(name=norm_name + '.bias', initializer=Constant(0.))
bn = fluid.layers.batch_norm(
input=conv,
act='relu',
param_attr=pattr,
bias_attr=battr,
name=norm_name + '.output.1',
moving_mean_name=norm_name + '.running_mean',
moving_variance_name=norm_name + '.running_var')
up = fluid.layers.resize_bilinear(
bn, scale=2, name=name + '.2.upsample')
return up
def _head(self,
x,
out_c,
conv_num=1,
head_out_c=None,
name=None,
is_test=False):
head_out_c = self.head_conv if not head_out_c else head_out_c
conv_w_std = 0.01 if '.hm' in name else 0.001
conv_w_init = Normal(0, conv_w_std)
for i in range(conv_num):
conv_name = '{}.{}.conv'.format(name, i)
if self.dcn_head:
x = DeformConv(
x,
head_out_c,
3,
initializer=conv_w_init,
name=conv_name + '.dcn')
x = fluid.layers.relu(x)
else:
x = fluid.layers.conv2d(
x,
head_out_c,
3,
padding=1,
param_attr=ParamAttr(
initializer=conv_w_init, name=conv_name + '.weight'),
bias_attr=ParamAttr(
learning_rate=2.,
regularizer=L2Decay(0.),
name=conv_name + '.bias'),
act='relu')
if self.drop_block and '.hm' in name:
x = DropBlock(
x,
block_size=self.block_size,
keep_prob=self.keep_prob,
is_test=is_test)
bias_init = float(-np.log((1 - 0.01) / 0.01)) if '.hm' in name else 0.
conv_b_init = Constant(bias_init)
x = fluid.layers.conv2d(
x,
out_c,
1,
param_attr=ParamAttr(
initializer=conv_w_init,
name='{}.{}.weight'.format(name, conv_num)),
bias_attr=ParamAttr(
learning_rate=2.,
regularizer=L2Decay(0.),
name='{}.{}.bias'.format(name, conv_num),
initializer=conv_b_init))
return x
def hm_head(self, x, name=None, is_test=False):
hm = self._head(
x,
self.num_classes,
self.hm_head_conv_num,
name=name,
is_test=is_test)
return hm
def wh_head(self, x, name=None):
planes = self.wh_planes
wh = self._head(
x, planes, self.wh_head_conv_num, self.wh_conv, name=name)
return fluid.layers.relu(wh)
def get_output(self, input, name=None, is_test=False):
feat = input[-1]
for i, out_c in enumerate(self.planes):
feat = self.upsample(
feat, out_c, name=name + '.deconv_layers.' + str(i))
if i < self.shortcut_len:
shortcut = self.shortcut(
input[-i - 2],
out_c,
self.shortcut_num[i],
name=name + '.shortcut_layers.' + str(i))
feat = fluid.layers.elementwise_add(feat, shortcut)
hm = self.hm_head(feat, name=name + '.hm', is_test=is_test)
wh = self.wh_head(feat, name=name + '.wh') * self.wh_offset_base
return hm, wh
def _simple_nms(self, heat, kernel=3):
pad = (kernel - 1) // 2
hmax = fluid.layers.pool2d(heat, kernel, 'max', pool_padding=pad)
keep = fluid.layers.cast(hmax == heat, 'float32')
return heat * keep
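The same peak-keeping trick in plain numpy (a sketch assuming scipy is available; not part of the fluid graph above):

```python
import numpy as np
from scipy.ndimage import maximum_filter

heat = np.random.rand(80, 128, 128).astype('float32')
# a point survives only if it equals the max of its 3x3 neighborhood
keep = maximum_filter(heat, size=(1, 3, 3)) == heat
peaks = heat * keep
```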
def _topk(self, scores, k):
cat, height, width = scores.shape[1:]
# batch size is 1
scores_r = fluid.layers.reshape(scores, [cat, -1])
topk_scores, topk_inds = fluid.layers.topk(scores_r, k)
topk_ys = topk_inds / width
topk_xs = topk_inds % width
topk_score_r = fluid.layers.reshape(topk_scores, [-1])
topk_score, topk_ind = fluid.layers.topk(topk_score_r, k)
topk_clses = fluid.layers.cast(topk_ind / k, 'float32')
topk_inds = fluid.layers.reshape(topk_inds, [-1])
topk_ys = fluid.layers.reshape(topk_ys, [-1, 1])
topk_xs = fluid.layers.reshape(topk_xs, [-1, 1])
topk_inds = fluid.layers.gather(topk_inds, topk_ind)
topk_ys = fluid.layers.gather(topk_ys, topk_ind)
topk_xs = fluid.layers.gather(topk_xs, topk_ind)
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def get_bboxes(self, heatmap, wh, scale_factor):
heatmap = fluid.layers.sigmoid(heatmap)
heat = self._simple_nms(heatmap)
scores, inds, clses, ys, xs = self._topk(heat, self.max_per_img)
ys = fluid.layers.cast(ys, 'float32') * self.down_ratio
xs = fluid.layers.cast(xs, 'float32') * self.down_ratio
scores = fluid.layers.unsqueeze(scores, [1])
clses = fluid.layers.unsqueeze(clses, [1])
wh_t = fluid.layers.transpose(wh, [0, 2, 3, 1])
wh = fluid.layers.reshape(wh_t, [-1, wh_t.shape[-1]])
wh = fluid.layers.gather(wh, inds)
x1 = xs - wh[:, 0:1]
y1 = ys - wh[:, 1:2]
x2 = xs + wh[:, 2:3]
y2 = ys + wh[:, 3:4]
bboxes = fluid.layers.concat([x1, y1, x2, y2], axis=1)
bboxes = fluid.layers.elementwise_div(bboxes, scale_factor, axis=-1)
results = fluid.layers.concat([clses, scores, bboxes], axis=1)
# hack: append one result with cls=-1 and score=1. so at least one box
# passes score_thresh; otherwise the gather below would fail on an empty index.
fill_r = fluid.layers.assign(
np.array(
[[-1, 1., 0, 0, 0, 0]], dtype='float32'))
results = fluid.layers.concat([results, fill_r])
scores = results[:, 1]
valid_ind = fluid.layers.where(scores > self.score_thresh)
results = fluid.layers.gather(results, valid_ind)
return {'bbox': results}
def ct_focal_loss(self, pred_hm, target_hm, gamma=2.0):
fg_map = fluid.layers.cast(target_hm == 1, 'float32')
fg_map.stop_gradient = True
bg_map = fluid.layers.cast(target_hm < 1, 'float32')
bg_map.stop_gradient = True
neg_weights = fluid.layers.pow(1 - target_hm, 4) * bg_map
pos_loss = 0 - fluid.layers.log(pred_hm) * fluid.layers.pow(
1 - pred_hm, gamma) * fg_map
neg_loss = 0 - fluid.layers.log(1 - pred_hm) * fluid.layers.pow(
pred_hm, gamma) * neg_weights
pos_loss = fluid.layers.reduce_sum(pos_loss)
neg_loss = fluid.layers.reduce_sum(neg_loss)
fg_num = fluid.layers.reduce_sum(fg_map)
focal_loss = (pos_loss + neg_loss) / (
fg_num + fluid.layers.cast(fg_num == 0, 'float32'))
return focal_loss
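For reference, `ct_focal_loss` is the CenterNet-style modified focal loss. With $\hat{p}$ the predicted heatmap, $p$ the gaussian target, $\gamma = 2$, and $N$ the number of positive (peak) pixels:

$$
L_{hm} = -\frac{1}{\max(N, 1)} \sum_{xyc}
\begin{cases}
(1 - \hat{p}_{xyc})^{\gamma} \log \hat{p}_{xyc} & p_{xyc} = 1 \\
(1 - p_{xyc})^{4} \, \hat{p}_{xyc}^{\gamma} \log(1 - \hat{p}_{xyc}) & \text{otherwise}
\end{cases}
$$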
def filter_box_by_weight(self, pred, target, weight):
index = fluid.layers.where(weight > 0)
index.stop_gradient = True
weight = fluid.layers.gather_nd(weight, index)
pred = fluid.layers.gather_nd(pred, index)
target = fluid.layers.gather_nd(target, index)
return pred, target, weight
def get_loss(self, pred_hm, pred_wh, target_hm, box_target, target_weight):
pred_hm = paddle.tensor.clamp(
fluid.layers.sigmoid(pred_hm), 1e-4, 1 - 1e-4)
hm_loss = self.ct_focal_loss(pred_hm, target_hm) * self.hm_weight
shape = fluid.layers.shape(target_hm)
shape.stop_gradient = True
H, W = shape[2], shape[3]
mask = fluid.layers.reshape(target_weight, [-1, H, W])
avg_factor = fluid.layers.reduce_sum(mask) + 1e-4
base_step = self.down_ratio
zero = fluid.layers.fill_constant(shape=[1], value=0, dtype='int32')
shifts_x = paddle.arange(zero, W * base_step, base_step, dtype='int32')
shifts_y = paddle.arange(zero, H * base_step, base_step, dtype='int32')
shift_y, shift_x = paddle.tensor.meshgrid([shifts_y, shifts_x])
base_loc = fluid.layers.stack([shift_x, shift_y], axis=0)
base_loc.stop_gradient = True
pred_boxes = fluid.layers.concat(
[0 - pred_wh[:, 0:2, :, :] + base_loc, pred_wh[:, 2:4] + base_loc],
axis=1)
pred_boxes = fluid.layers.transpose(pred_boxes, [0, 2, 3, 1])
boxes = fluid.layers.transpose(box_target, [0, 2, 3, 1])
boxes.stop_gradient = True
pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes, mask)
mask.stop_gradient = True
wh_loss = self.wh_loss(
pred_boxes, boxes, outside_weight=mask, use_transform=False)
wh_loss = wh_loss / avg_factor
ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss}
return ttf_loss
......@@ -27,6 +27,7 @@ from . import blazeface
from . import faceboxes
from . import fcos
from . import cornernet_squeeze
from . import ttfnet
from .faster_rcnn import *
from .mask_rcnn import *
......@@ -41,3 +42,4 @@ from .blazeface import *
from .faceboxes import *
from .fcos import *
from .cornernet_squeeze import *
from .ttfnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
from paddle import fluid
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
__all__ = ['TTFNet']
@register
class TTFNet(object):
"""
TTFNet network, see https://arxiv.org/abs/1909.00700
Args:
backbone (object): backbone instance
ttf_head (object): `TTFHead` instance
num_classes (int): the number of classes, 80 by default.
"""
__category__ = 'architecture'
__inject__ = ['backbone', 'ttf_head']
__shared__ = ['num_classes']
def __init__(self, backbone, ttf_head='TTFHead', num_classes=80):
super(TTFNet, self).__init__()
self.backbone = backbone
self.ttf_head = ttf_head
self.num_classes = num_classes
def build(self, feed_vars, mode='train'):
im = feed_vars['image']
mixed_precision_enabled = mixed_precision_global_state() is not None
# cast inputs to FP16
if mixed_precision_enabled:
im = fluid.layers.cast(im, 'float16')
body_feats = self.backbone(im)
if isinstance(body_feats, OrderedDict):
body_feat_names = list(body_feats.keys())
body_feats = [body_feats[name] for name in body_feat_names]
# cast features back to FP32
if mixed_precision_enabled:
body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats]
predict_hm, predict_wh = self.ttf_head.get_output(
body_feats, 'ttf_head', is_test=mode == 'test')
if mode == 'train':
heatmap = feed_vars['ttf_heatmap']
box_target = feed_vars['ttf_box_target']
reg_weight = feed_vars['ttf_reg_weight']
loss = self.ttf_head.get_loss(predict_hm, predict_wh, heatmap,
box_target, reg_weight)
total_loss = fluid.layers.sum(list(loss.values()))
loss.update({'loss': total_loss})
return loss
else:
results = self.ttf_head.get_bboxes(predict_hm, predict_wh,
feed_vars['scale_factor'])
return results
def _inputs_def(self, image_shape, downsample):
im_shape = [None] + image_shape
H, W = im_shape[2:]
target_h = None if H is None else H // downsample
target_w = None if W is None else W // downsample
# yapf: disable
inputs_def = {
'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0},
'scale_factor': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 0},
'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0},
'ttf_heatmap': {'shape': [None, self.num_classes, target_h, target_w], 'dtype': 'float32', 'lod_level': 0},
'ttf_box_target': {'shape': [None, 4, target_h, target_w], 'dtype': 'float32', 'lod_level': 0},
'ttf_reg_weight': {'shape': [None, 1, target_h, target_w], 'dtype': 'float32', 'lod_level': 0},
}
# yapf: enable
return inputs_def
def build_inputs(
self,
image_shape=[3, None, None],
fields=[
'image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight'
], # for train
use_dataloader=True,
iterable=False,
downsample=4):
inputs_def = self._inputs_def(image_shape, downsample)
feed_vars = OrderedDict([(key, fluid.data(
name=key,
shape=inputs_def[key]['shape'],
dtype=inputs_def[key]['dtype'],
lod_level=inputs_def[key]['lod_level'])) for key in fields])
loader = fluid.io.DataLoader.from_generator(
feed_list=list(feed_vars.values()),
capacity=16,
use_double_buffer=True,
iterable=iterable) if use_dataloader else None
return feed_vars, loader
def train(self, feed_vars):
return self.build(feed_vars, mode='train')
def eval(self, feed_vars):
return self.build(feed_vars, mode='test')
def test(self, feed_vars):
return self.build(feed_vars, mode='test')
......@@ -42,13 +42,15 @@ class DarkNet(object):
depth=53,
norm_type='bn',
norm_decay=0.,
weight_prefix_name='',
freeze_at=-1):
assert depth in [53], "unsupported depth value"
self.depth = depth
self.norm_type = norm_type
self.norm_decay = norm_decay
self.depth_cfg = {53: ([1, 2, 8, 8, 4], self.basicblock)}
self.prefix_name = weight_prefix_name
self.freeze_at = freeze_at
def _conv_norm(self,
input,
......@@ -161,6 +163,8 @@ class DarkNet(object):
ch_out=32 * 2**i,
count=stage,
name=self.prefix_name + "stage.{}".format(i))
if i < self.freeze_at:
block.stop_gradient = True
blocks.append(block)
if i < len(stages) - 1: # do not downsample in the last stage
downsample_ = self._downsample(
......
......@@ -33,14 +33,24 @@ class GiouLoss(object):
loss_weight (float): giou loss weight, default as 10 in faster-rcnn
is_cls_agnostic (bool): flag of class-agnostic
num_classes (int): class num
do_average (bool): whether to average the loss
use_class_weight(bool): whether to use class weight
'''
__shared__ = ['num_classes']
def __init__(self,
loss_weight=10.,
is_cls_agnostic=False,
num_classes=81,
do_average=True,
use_class_weight=True):
super(GiouLoss, self).__init__()
self.loss_weight = loss_weight
self.is_cls_agnostic = is_cls_agnostic
self.num_classes = num_classes
self.do_average = do_average
self.class_weight = 2 if is_cls_agnostic else num_classes
self.use_class_weight = use_class_weight
# deltas: NxMx4
def bbox_transform(self, deltas, weights):
......@@ -78,10 +88,15 @@ class GiouLoss(object):
y,
inside_weight=None,
outside_weight=None,
bbox_reg_weight=[0.1, 0.1, 0.2, 0.2],
use_transform=True):
eps = 1.e-10
if use_transform:
x1, y1, x2, y2 = self.bbox_transform(x, bbox_reg_weight)
x1g, y1g, x2g, y2g = self.bbox_transform(y, bbox_reg_weight)
else:
x1, y1, x2, y2 = fluid.layers.split(x, num_or_sections=4, dim=1)
x1g, y1g, x2g, y2g = fluid.layers.split(y, num_or_sections=4, dim=1)
x2 = fluid.layers.elementwise_max(x1, x2)
y2 = fluid.layers.elementwise_max(y1, y2)
......@@ -99,9 +114,9 @@ class GiouLoss(object):
intsctk = (xkis2 - xkis1) * (ykis2 - ykis1)
intsctk = intsctk * fluid.layers.greater_than(
xkis2, xkis1) * fluid.layers.greater_than(ykis2, ykis1)
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + eps
iouk = intsctk / unionk
area_c = (xc2 - xc1) * (yc2 - yc1) + eps
......@@ -116,10 +131,17 @@ class GiouLoss(object):
outside_weight = fluid.layers.reduce_mean(outside_weight, dim=1)
iou_weights = inside_weight * outside_weight
elif outside_weight is not None:
iou_weights = outside_weight
if self.do_average:
miouk = fluid.layers.reduce_mean((1 - miouk) * iou_weights)
else:
iou_distance = fluid.layers.elementwise_mul(1 - miouk, iou_weights, axis=0)
miouk = fluid.layers.reduce_sum(iou_distance)
if self.use_class_weight:
miouk = miouk * self.class_weight
return miouk * self.loss_weight
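For reference, with predicted box $B$, ground-truth box $B^g$, and $C$ the smallest box enclosing both, the `miouk` computed above is:

$$
\mathrm{GIoU} = \mathrm{IoU} - \frac{|C| - |B \cup B^g|}{|C|}, \qquad
L = (1 - \mathrm{GIoU}) \cdot w
$$

where $w$ is `iou_weights`, reduced by mean or sum depending on `do_average`, and optionally scaled by `class_weight` and `loss_weight`.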
......@@ -30,7 +30,8 @@ __all__ = [
'GenerateProposals', 'MultiClassNMS', 'BBoxAssigner', 'MaskAssigner',
'RoIAlign', 'RoIPool', 'MultiBoxHead', 'SSDLiteMultiBoxHead',
'SSDOutputDecoder', 'RetinaTargetAssign', 'RetinaOutputDecoder', 'ConvNorm',
'DeformConvNorm', 'MultiClassSoftNMS', 'MatrixNMS', 'LibraBBoxAssigner',
'DeformConv'
]
......@@ -43,36 +44,32 @@ def _conv_offset(input, filter_size, stride, padding, act=None, name=None):
stride=stride,
padding=padding,
param_attr=ParamAttr(
initializer=fluid.initializer.Constant(0), name=name + ".w_0"),
bias_attr=ParamAttr(
initializer=fluid.initializer.Constant(0),
learning_rate=2.,
regularizer=L2Decay(0.),
name=name + ".b_0"),
act=act,
name=name)
return out
def DeformConv(input,
num_filters,
filter_size,
stride=1,
groups=1,
dilation=1,
lr_scale=1,
initializer=None,
bias_attr=False,
name=None):
if bias_attr:
bias_para = ParamAttr(
name=name + "_bias",
initializer=fluid.initializer.Constant(0),
regularizer=L2Decay(0.),
learning_rate=lr_scale * 2)
else:
bias_para = False
......@@ -109,6 +106,29 @@ def DeformConvNorm(input,
bias_attr=bias_para,
name=name + ".conv2d.output.1")
return conv
def DeformConvNorm(input,
num_filters,
filter_size,
stride=1,
groups=1,
norm_decay=0.,
norm_type='affine_channel',
norm_groups=32,
dilation=1,
lr_scale=1,
freeze_norm=False,
act=None,
norm_name=None,
initializer=None,
bias_attr=False,
name=None):
assert norm_type in ['bn', 'sync_bn', 'affine_channel']
conv = DeformConv(input, num_filters, filter_size, stride, groups, dilation,
lr_scale, initializer, bias_attr, name)
norm_lr = 0. if freeze_norm else 1.
pattr = ParamAttr(
name=norm_name + '_scale',
......@@ -330,7 +350,6 @@ class AnchorGenerator(object):
@serializable
class AnchorGrid(object):
"""Generate anchor grid
Args:
image_size (int or list): input image size, may be a single integer or
list of [h, w]. Default: 512
......
......@@ -261,6 +261,7 @@ def bbox2out(results, clsid2catid, is_bbox_normalized=False):
for j in range(num):
dt = bboxes[k]
clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
if clsid < 0: continue
catid = (clsid2catid[int(clsid)])
if is_bbox_normalized:
......
......@@ -161,6 +161,8 @@ def eval_run(exe,
if 'Corner' in cfg.architecture and post_config is not None:
from ppdet.utils.post_process import corner_post_process
corner_post_process(res, post_config, cfg.num_classes)
if 'TTFNet' in cfg.architecture:
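# TTFNet predicts boxes for the single-image batch without LoD info, so record the box count as the LoD entry the COCO evaluator expects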
res['bbox'][1].append([len(res['bbox'][0])])
results.append(res)
if iter_id % 100 == 0:
logger.info('Test iter {}'.format(iter_id))
......
......@@ -76,6 +76,8 @@ def parse_reader(reader_cfg, metric, arch):
params['max_size'] = max(image_shape[
1:]) if arch in scale_set else 0
params['image_shape'] = image_shape[1:]
if 'target_dim' in params:
params.pop('target_dim')
p.update(params)
preprocess_list.append(p)
batch_transforms = reader_cfg.get('batch_transforms', None)
......@@ -109,6 +111,7 @@ def dump_infer_config(FLAGS, config):
'RCNN': 40,
'RetinaNet': 40,
'Face': 3,
'TTFNet': 3,
}
infer_arch = config['architecture']
......