未验证 提交 93e2a447 编写于 作者: F Feng Ni 提交者: GitHub

[Dygraph] add TTFNet (#2099)

* ttfnet infer

* fix ttfnet deploy

* clean code

* add comment, fix config

* add comment, update modelzoo

* fix norm resize, update modelzoo
上级 31a8da43
# TTFNet
## 简介
TTFNet是一种用于实时目标检测且对训练时间友好的网络,对CenterNet收敛速度慢的问题进行改进,提出了利用高斯核生成训练样本的新方法,有效的消除了anchor-free head中存在的模糊性。同时简单轻量化的网络结构也易于进行任务扩展。
**特点:**
结构简单,仅需要两个head检测目标位置和大小,并且去除了耗时的后处理操作
训练时间短,基于DarkNet53的骨干网路,V100 8卡仅需要训练2个小时即可达到较好的模型效果
## Model Zoo
| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 |
| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
| DarkNet53 | TTFNet | 12 | 1x | ---- | 33.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ttfnet_darknet53_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ttfnet/ttfnet_darknet53_1x_coco.yml) |
## Citations
```
@article{liu2019training,
title = {Training-Time-Friendly Network for Real-Time Object Detection},
author = {Zili Liu, Tu Zheng, Guodong Xu, Zheng Yang, Haifeng Liu, Deng Cai},
journal = {arXiv preprint arXiv:1909.00700},
year = {2019}
}
```
epoch: 12
LearningRate:
base_lr: 0.015
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [8, 11]
- !LinearWarmup
start_factor: 0.2
steps: 500
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0004
type: L2
architecture: TTFNet
pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/DarkNet53_pretrained.tar
load_static_weights: True
TTFNet:
backbone: DarkNet
neck: TTFFPN
ttf_head: TTFHead
post_process: BBoxPostProcess
DarkNet:
depth: 53
freeze_at: 0
return_idx: [0, 1, 2, 3, 4]
norm_type: bn
norm_decay: 0.0004
TTFFPN:
planes: [256, 128, 64]
shortcut_num: [1, 2, 3]
ch_in: [1024, 256, 128]
TTFHead:
hm_head:
name: HMHead
ch_in: 64
ch_out: 128
conv_num: 2
wh_head:
name: WHHead
ch_in: 64
ch_out: 64
conv_num: 2
hm_loss:
name: CTFocalLoss
loss_weight: 1.
wh_loss:
name: GIoULoss
loss_weight: 5.
reduction: sum
BBoxPostProcess:
decode:
name: TTFBox
max_per_img: 100
score_thresh: 0.01
down_ratio: 4
worker_num: 2
TrainReader:
sample_transforms:
- DecodeOp: {}
- RandomFlipOp: {prob: 0.5}
- ResizeOp: {target_size: [512, 512], keep_ratio: False}
- NormalizeImageOp: {mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375], is_scale: false}
- PermuteOp: {}
batch_transforms:
- Gt2TTFTargetOp: {down_ratio: 4}
- PadBatchOp: {pad_to_stride: 32, pad_gt: true}
batch_size: 12
shuffle: true
drop_last: true
EvalReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [512, 512], keep_ratio: False}
- NormalizeImageOp: {is_scale: false, mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375]}
- PermuteOp: {}
batch_size: 1
drop_last: false
drop_empty: false
TestReader:
sample_transforms:
- DecodeOp: {}
- ResizeOp: {target_size: [512, 512], keep_ratio: False}
- NormalizeImageOp: {is_scale: false, mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375]}
- PermuteOp: {}
batch_size: 1
drop_last: false
drop_empty: false
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1x.yml',
'_base_/ttfnet_darknet53.yml',
'_base_/ttfnet_reader.yml',
]
weights: output/ttfnet_darknet53_1x_coco/model_final
...@@ -35,6 +35,7 @@ SUPPORT_MODELS = { ...@@ -35,6 +35,7 @@ SUPPORT_MODELS = {
'SSD', 'SSD',
'FCOS', 'FCOS',
'SOLOv2', 'SOLOv2',
'TTFNet',
} }
......
...@@ -99,7 +99,7 @@ def _load_config_with_base(file_path): ...@@ -99,7 +99,7 @@ def _load_config_with_base(file_path):
return file_cfg return file_cfg
WITHOUT_BACKGROUND_ARCHS = ['YOLOv3', 'FCOS'] WITHOUT_BACKGROUND_ARCHS = ['YOLOv3', 'FCOS', 'TTFNet']
def _parse_with_background(): def _parse_with_background():
......
...@@ -514,6 +514,7 @@ class Gt2FCOSTargetOp(BaseOperator): ...@@ -514,6 +514,7 @@ class Gt2FCOSTargetOp(BaseOperator):
@register_op @register_op
class Gt2TTFTargetOp(BaseOperator): class Gt2TTFTargetOp(BaseOperator):
__shared__ = ['num_classes']
""" """
Gt2TTFTarget Gt2TTFTarget
Generate TTFNet targets by ground truth data Generate TTFNet targets by ground truth data
...@@ -525,7 +526,7 @@ class Gt2TTFTargetOp(BaseOperator): ...@@ -525,7 +526,7 @@ class Gt2TTFTargetOp(BaseOperator):
0.54 by default. 0.54 by default.
""" """
def __init__(self, num_classes, down_ratio=4, alpha=0.54): def __init__(self, num_classes=80, down_ratio=4, alpha=0.54):
super(Gt2TTFTargetOp, self).__init__() super(Gt2TTFTargetOp, self).__init__()
self.down_ratio = down_ratio self.down_ratio = down_ratio
self.num_classes = num_classes self.num_classes = num_classes
......
...@@ -13,6 +13,7 @@ from . import cascade_rcnn ...@@ -13,6 +13,7 @@ from . import cascade_rcnn
from . import ssd from . import ssd
from . import fcos from . import fcos
from . import solov2 from . import solov2
from . import ttfnet
from .meta_arch import * from .meta_arch import *
from .faster_rcnn import * from .faster_rcnn import *
...@@ -22,3 +23,4 @@ from .cascade_rcnn import * ...@@ -22,3 +23,4 @@ from .cascade_rcnn import *
from .ssd import * from .ssd import *
from .fcos import * from .fcos import *
from .solov2 import * from .solov2 import *
from .ttfnet import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch
__all__ = ['TTFNet']
@register
class TTFNet(BaseArch):
"""
TTFNet network, see https://arxiv.org/abs/1909.00700
Args:
backbone (object): backbone instance
neck (object): 'TTFFPN' instance
ttf_head (object): 'TTFHead' instance
post_process (object): 'BBoxPostProcess' instance
"""
__category__ = 'architecture'
__inject__ = [
'backbone',
'neck',
'ttf_head',
'post_process',
]
def __init__(self,
backbone='DarkNet',
neck='TTFFPN',
ttf_head='TTFHead',
post_process='BBoxPostProcess'):
super(TTFNet, self).__init__()
self.backbone = backbone
self.neck = neck
self.ttf_head = ttf_head
self.post_process = post_process
def model_arch(self, ):
# Backbone
body_feats = self.backbone(self.inputs)
# neck
body_feats = self.neck(body_feats)
# TTF Head
self.hm, self.wh = self.ttf_head(body_feats)
def get_loss(self, ):
loss = {}
heatmap = self.inputs['ttf_heatmap']
box_target = self.inputs['ttf_box_target']
reg_weight = self.inputs['ttf_reg_weight']
head_loss = self.ttf_head.get_loss(self.hm, self.wh, heatmap,
box_target, reg_weight)
loss.update(head_loss)
total_loss = paddle.add_n(list(loss.values()))
loss.update({'loss': total_loss})
return loss
def get_pred(self):
bbox, bbox_num = self.post_process(self.hm, self.wh,
self.inputs['im_shape'],
self.inputs['scale_factor'])
outs = {
"bbox": bbox,
"bbox_num": bbox_num,
}
return outs
...@@ -32,6 +32,7 @@ class ConvBNLayer(nn.Layer): ...@@ -32,6 +32,7 @@ class ConvBNLayer(nn.Layer):
groups=1, groups=1,
padding=0, padding=0,
norm_type='bn', norm_type='bn',
norm_decay=0.,
act="leaky", act="leaky",
name=None): name=None):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
...@@ -45,7 +46,8 @@ class ConvBNLayer(nn.Layer): ...@@ -45,7 +46,8 @@ class ConvBNLayer(nn.Layer):
groups=groups, groups=groups,
weight_attr=ParamAttr(name=name + '.conv.weights'), weight_attr=ParamAttr(name=name + '.conv.weights'),
bias_attr=False) bias_attr=False)
self.batch_norm = batch_norm(ch_out, norm_type=norm_type, name=name) self.batch_norm = batch_norm(
ch_out, norm_type=norm_type, norm_decay=norm_decay, name=name)
self.act = act self.act = act
def forward(self, inputs): def forward(self, inputs):
...@@ -64,6 +66,7 @@ class DownSample(nn.Layer): ...@@ -64,6 +66,7 @@ class DownSample(nn.Layer):
stride=2, stride=2,
padding=1, padding=1,
norm_type='bn', norm_type='bn',
norm_decay=0.,
name=None): name=None):
super(DownSample, self).__init__() super(DownSample, self).__init__()
...@@ -75,6 +78,7 @@ class DownSample(nn.Layer): ...@@ -75,6 +78,7 @@ class DownSample(nn.Layer):
stride=stride, stride=stride,
padding=padding, padding=padding,
norm_type=norm_type, norm_type=norm_type,
norm_decay=norm_decay,
name=name) name=name)
self.ch_out = ch_out self.ch_out = ch_out
...@@ -84,7 +88,7 @@ class DownSample(nn.Layer): ...@@ -84,7 +88,7 @@ class DownSample(nn.Layer):
class BasicBlock(nn.Layer): class BasicBlock(nn.Layer):
def __init__(self, ch_in, ch_out, norm_type='bn', name=None): def __init__(self, ch_in, ch_out, norm_type='bn', norm_decay=0., name=None):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
...@@ -94,6 +98,7 @@ class BasicBlock(nn.Layer): ...@@ -94,6 +98,7 @@ class BasicBlock(nn.Layer):
stride=1, stride=1,
padding=0, padding=0,
norm_type=norm_type, norm_type=norm_type,
norm_decay=norm_decay,
name=name + '.0') name=name + '.0')
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
ch_in=ch_out, ch_in=ch_out,
...@@ -102,6 +107,7 @@ class BasicBlock(nn.Layer): ...@@ -102,6 +107,7 @@ class BasicBlock(nn.Layer):
stride=1, stride=1,
padding=1, padding=1,
norm_type=norm_type, norm_type=norm_type,
norm_decay=norm_decay,
name=name + '.1') name=name + '.1')
def forward(self, inputs): def forward(self, inputs):
...@@ -112,18 +118,32 @@ class BasicBlock(nn.Layer): ...@@ -112,18 +118,32 @@ class BasicBlock(nn.Layer):
class Blocks(nn.Layer): class Blocks(nn.Layer):
def __init__(self, ch_in, ch_out, count, norm_type='bn', name=None): def __init__(self,
ch_in,
ch_out,
count,
norm_type='bn',
norm_decay=0.,
name=None):
super(Blocks, self).__init__() super(Blocks, self).__init__()
self.basicblock0 = BasicBlock( self.basicblock0 = BasicBlock(
ch_in, ch_out, norm_type=norm_type, name=name + '.0') ch_in,
ch_out,
norm_type=norm_type,
norm_decay=norm_decay,
name=name + '.0')
self.res_out_list = [] self.res_out_list = []
for i in range(1, count): for i in range(1, count):
block_name = '{}.{}'.format(name, i) block_name = '{}.{}'.format(name, i)
res_out = self.add_sublayer( res_out = self.add_sublayer(
block_name, block_name,
BasicBlock( BasicBlock(
ch_out * 2, ch_out, norm_type=norm_type, name=block_name)) ch_out * 2,
ch_out,
norm_type=norm_type,
norm_decay=norm_decay,
name=block_name))
self.res_out_list.append(res_out) self.res_out_list.append(res_out)
self.ch_out = ch_out self.ch_out = ch_out
...@@ -147,7 +167,8 @@ class DarkNet(nn.Layer): ...@@ -147,7 +167,8 @@ class DarkNet(nn.Layer):
freeze_at=-1, freeze_at=-1,
return_idx=[2, 3, 4], return_idx=[2, 3, 4],
num_stages=5, num_stages=5,
norm_type='bn'): norm_type='bn',
norm_decay=0.):
super(DarkNet, self).__init__() super(DarkNet, self).__init__()
self.depth = depth self.depth = depth
self.freeze_at = freeze_at self.freeze_at = freeze_at
...@@ -162,12 +183,14 @@ class DarkNet(nn.Layer): ...@@ -162,12 +183,14 @@ class DarkNet(nn.Layer):
stride=1, stride=1,
padding=1, padding=1,
norm_type=norm_type, norm_type=norm_type,
norm_decay=norm_decay,
name='yolo_input') name='yolo_input')
self.downsample0 = DownSample( self.downsample0 = DownSample(
ch_in=32, ch_in=32,
ch_out=32 * 2, ch_out=32 * 2,
norm_type=norm_type, norm_type=norm_type,
norm_decay=norm_decay,
name='yolo_input.downsample') name='yolo_input.downsample')
self.darknet_conv_block_list = [] self.darknet_conv_block_list = []
...@@ -182,6 +205,7 @@ class DarkNet(nn.Layer): ...@@ -182,6 +205,7 @@ class DarkNet(nn.Layer):
32 * (2**i), 32 * (2**i),
stage, stage,
norm_type=norm_type, norm_type=norm_type,
norm_decay=norm_decay,
name=name)) name=name))
self.darknet_conv_block_list.append(conv_block) self.darknet_conv_block_list.append(conv_block)
for i in range(num_stages - 1): for i in range(num_stages - 1):
...@@ -192,6 +216,7 @@ class DarkNet(nn.Layer): ...@@ -192,6 +216,7 @@ class DarkNet(nn.Layer):
ch_in=32 * (2**(i + 1)), ch_in=32 * (2**(i + 1)),
ch_out=32 * (2**(i + 2)), ch_out=32 * (2**(i + 2)),
norm_type=norm_type, norm_type=norm_type,
norm_decay=norm_decay,
name=down_name)) name=down_name))
self.downsample_list.append(downsample) self.downsample_list.append(downsample)
......
...@@ -20,6 +20,7 @@ from . import roi_extractor ...@@ -20,6 +20,7 @@ from . import roi_extractor
from . import ssd_head from . import ssd_head
from . import fcos_head from . import fcos_head
from . import solov2_head from . import solov2_head
from . import ttf_head
from .rpn_head import * from .rpn_head import *
from .bbox_head import * from .bbox_head import *
...@@ -29,3 +30,4 @@ from .roi_extractor import * ...@@ -29,3 +30,4 @@ from .roi_extractor import *
from .ssd_head import * from .ssd_head import *
from .fcos_head import * from .fcos_head import *
from .solov2_head import * from .solov2_head import *
from .ttf_head import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant, Uniform, Normal
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
import numpy as np
@register
class HMHead(nn.Layer):
__shared__ = ['num_classes']
def __init__(self, ch_in, ch_out=128, num_classes=80, conv_num=2):
super(HMHead, self).__init__()
head_conv = nn.Sequential()
for i in range(conv_num):
name = 'conv.{}'.format(i)
head_conv.add_sublayer(
name,
nn.Conv2D(
in_channels=ch_in if i == 0 else ch_out,
out_channels=ch_out,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
bias_attr=ParamAttr(
learning_rate=2., regularizer=L2Decay(0.))))
head_conv.add_sublayer(name + '.act', nn.ReLU())
self.feat = self.add_sublayer('hm_feat', head_conv)
bias_init = float(-np.log((1 - 0.01) / 0.01))
self.head = self.add_sublayer(
'hm_head',
nn.Conv2D(
in_channels=ch_out,
out_channels=num_classes,
kernel_size=1,
weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
bias_attr=ParamAttr(
learning_rate=2.,
regularizer=L2Decay(0.),
initializer=Constant(bias_init))))
def forward(self, feat):
out = self.feat(feat)
out = self.head(out)
return out
@register
class WHHead(nn.Layer):
def __init__(self, ch_in, ch_out=64, conv_num=2):
super(WHHead, self).__init__()
head_conv = nn.Sequential()
for i in range(conv_num):
name = 'conv.{}'.format(i)
head_conv.add_sublayer(
name,
nn.Conv2D(
in_channels=ch_in if i == 0 else ch_out,
out_channels=ch_out,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0, 0.001)),
bias_attr=ParamAttr(
learning_rate=2., regularizer=L2Decay(0.))))
head_conv.add_sublayer(name + '.act', nn.ReLU())
self.feat = self.add_sublayer('wh_feat', head_conv)
self.head = self.add_sublayer(
'wh_head',
nn.Conv2D(
in_channels=ch_out,
out_channels=4,
kernel_size=1,
weight_attr=ParamAttr(initializer=Normal(0, 0.001)),
bias_attr=ParamAttr(
learning_rate=2., regularizer=L2Decay(0.))))
def forward(self, feat):
out = self.feat(feat)
out = self.head(out)
out = F.relu(out)
return out
@register
class TTFHead(nn.Layer):
"""
TTFHead
Args:
hm_head(object): Instance of 'HMHead', heatmap branch.
wh_head(object): Instance of 'WHHead', wh branch.
hm_loss(object): Instance of 'CTFocalLoss'.
wh_loss(object): Instance of 'GIoULoss'.
wh_offset_base(flaot): the base offset of width and height, 16. by default.
down_ratio(int): the actual down_ratio is calculated by base_down_ratio(default 16) a
nd the number of upsample layers.
"""
__shared__ = ['down_ratio']
__inject__ = ['hm_head', 'wh_head', 'hm_loss', 'wh_loss']
def __init__(self,
hm_head='HMHead',
wh_head='WHHead',
hm_loss='CTFocalLoss',
wh_loss='GIoULoss',
wh_offset_base=16.,
down_ratio=4):
super(TTFHead, self).__init__()
self.hm_head = hm_head
self.wh_head = wh_head
self.hm_loss = hm_loss
self.wh_loss = wh_loss
self.wh_offset_base = wh_offset_base
self.down_ratio = down_ratio
def forward(self, feats):
hm = self.hm_head(feats)
wh = self.wh_head(feats) * self.wh_offset_base
return hm, wh
def filter_box_by_weight(self, pred, target, weight):
index = paddle.nonzero(weight > 0)
index.stop_gradient = True
weight = paddle.gather_nd(weight, index)
pred = paddle.gather_nd(pred, index)
target = paddle.gather_nd(target, index)
return pred, target, weight
def get_loss(self, pred_hm, pred_wh, target_hm, box_target, target_weight):
pred_hm = paddle.clip(F.sigmoid(pred_hm), 1e-4, 1 - 1e-4)
hm_loss = self.hm_loss(pred_hm, target_hm)
H, W = target_hm.shape[2:]
mask = paddle.reshape(target_weight, [-1, H, W])
avg_factor = paddle.sum(mask) + 1e-4
base_step = self.down_ratio
shifts_x = paddle.arange(0, W * base_step, base_step, dtype='int32')
shifts_y = paddle.arange(0, H * base_step, base_step, dtype='int32')
shift_y, shift_x = paddle.tensor.meshgrid([shifts_y, shifts_x])
base_loc = paddle.stack([shift_x, shift_y], axis=0)
base_loc.stop_gradient = True
pred_boxes = paddle.concat(
[0 - pred_wh[:, 0:2, :, :] + base_loc, pred_wh[:, 2:4] + base_loc],
axis=1)
pred_boxes = paddle.transpose(pred_boxes, [0, 2, 3, 1])
boxes = paddle.transpose(box_target, [0, 2, 3, 1])
boxes.stop_gradient = True
pred_boxes, boxes, mask = self.filter_box_by_weight(pred_boxes, boxes,
mask)
mask.stop_gradient = True
wh_loss = self.wh_loss(pred_boxes, boxes, iou_weight=mask.unsqueeze(1))
wh_loss = wh_loss / avg_factor
ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss}
return ttf_loss
...@@ -913,6 +913,87 @@ class FCOSBox(object): ...@@ -913,6 +913,87 @@ class FCOSBox(object):
@register @register
class TTFBox(object):
__shared__ = ['down_ratio']
def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
super(TTFBox, self).__init__()
self.max_per_img = max_per_img
self.score_thresh = score_thresh
self.down_ratio = down_ratio
def _simple_nms(self, heat, kernel=3):
pad = (kernel - 1) // 2
hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
keep = paddle.cast(hmax == heat, 'float32')
return heat * keep
def _topk(self, scores):
k = self.max_per_img
shape_fm = paddle.shape(scores)
shape_fm.stop_gradient = True
cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
# batch size is 1
scores_r = paddle.reshape(scores, [cat, -1])
topk_scores, topk_inds = paddle.topk(scores_r, k)
topk_scores, topk_inds = paddle.topk(scores_r, k)
topk_ys = topk_inds // width
topk_xs = topk_inds % width
topk_score_r = paddle.reshape(topk_scores, [-1])
topk_score, topk_ind = paddle.topk(topk_score_r, k)
k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')
topk_inds = paddle.reshape(topk_inds, [-1])
topk_ys = paddle.reshape(topk_ys, [-1, 1])
topk_xs = paddle.reshape(topk_xs, [-1, 1])
topk_inds = paddle.gather(topk_inds, topk_ind)
topk_ys = paddle.gather(topk_ys, topk_ind)
topk_xs = paddle.gather(topk_xs, topk_ind)
return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
def __call__(self, hm, wh, im_shape, scale_factor):
heatmap = F.sigmoid(hm)
heat = self._simple_nms(heatmap)
scores, inds, clses, ys, xs = self._topk(heat)
ys = paddle.cast(ys, 'float32') * self.down_ratio
xs = paddle.cast(xs, 'float32') * self.down_ratio
scores = paddle.tensor.unsqueeze(scores, [1])
clses = paddle.tensor.unsqueeze(clses, [1])
wh_t = paddle.transpose(wh, [0, 2, 3, 1])
wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
wh = paddle.gather(wh, inds)
x1 = xs - wh[:, 0:1]
y1 = ys - wh[:, 1:2]
x2 = xs + wh[:, 2:3]
y2 = ys + wh[:, 3:4]
bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
scale_y = scale_factor[:, 0:1]
scale_x = scale_factor[:, 1:2]
scale_expand = paddle.concat(
[scale_x, scale_y, scale_x, scale_y], axis=1)
boxes_shape = paddle.shape(bboxes)
boxes_shape.stop_gradient = True
scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
bboxes = paddle.divide(bboxes, scale_expand)
results = paddle.concat([clses, scores, bboxes], axis=1)
# hack: append result with cls=-1 and score=1. to avoid all scores
# are less than score_thresh which may cause error in gather.
fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
fill_r = paddle.cast(fill_r, results.dtype)
results = paddle.concat([results, fill_r])
scores = results[:, 1]
valid_ind = paddle.nonzero(scores > self.score_thresh)
results = paddle.gather(results, valid_ind)
return results, paddle.shape(results)[0:1]
@serializable @serializable
class MaskMatrixNMS(object): class MaskMatrixNMS(object):
""" """
......
...@@ -18,6 +18,7 @@ from . import iou_loss ...@@ -18,6 +18,7 @@ from . import iou_loss
from . import ssd_loss from . import ssd_loss
from . import fcos_loss from . import fcos_loss
from . import solov2_loss from . import solov2_loss
from . import ctfocal_loss
from .yolo_loss import * from .yolo_loss import *
from .iou_aware_loss import * from .iou_aware_loss import *
...@@ -25,3 +26,4 @@ from .iou_loss import * ...@@ -25,3 +26,4 @@ from .iou_loss import *
from .ssd_loss import * from .ssd_loss import *
from .fcos_loss import * from .fcos_loss import *
from .solov2_loss import * from .solov2_loss import *
from .ctfocal_loss import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
__all__ = ['CTFocalLoss']
@register
@serializable
class CTFocalLoss(object):
"""
CTFocalLoss
Args:
loss_weight (float): loss weight
gamma (float): gamma parameter for Focal Loss
"""
def __init__(self, loss_weight=1., gamma=2.0):
self.loss_weight = loss_weight
self.gamma = gamma
def __call__(self, pred, target):
"""
Calculate the loss
Args:
pred(Tensor): heatmap prediction
target(Tensor): target for positive samples
Return:
ct_focal_loss (Tensor): Focal Loss used in CornerNet & CenterNet.
Note that the values in target are in [0, 1] since gaussian is
used to reduce the punishment and we treat [0, 1) as neg example.
"""
fg_map = paddle.cast(target == 1, 'float32')
fg_map.stop_gradient = True
bg_map = paddle.cast(target < 1, 'float32')
bg_map.stop_gradient = True
neg_weights = paddle.pow(1 - target, 4) * bg_map
pos_loss = 0 - paddle.log(pred) * paddle.pow(1 - pred,
self.gamma) * fg_map
neg_loss = 0 - paddle.log(1 - pred) * paddle.pow(
pred, self.gamma) * neg_weights
pos_loss = paddle.sum(pos_loss)
neg_loss = paddle.sum(neg_loss)
fg_num = paddle.sum(fg_map)
ct_focal_loss = (pos_loss + neg_loss) / (
fg_num + paddle.cast(fg_num == 0, 'float32'))
return ct_focal_loss * self.loss_weight
...@@ -21,7 +21,7 @@ import paddle.nn.functional as F ...@@ -21,7 +21,7 @@ import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ..utils import xywh2xyxy, bbox_iou, decode_yolo from ..utils import xywh2xyxy, bbox_iou, decode_yolo
__all__ = ['IouLoss'] __all__ = ['IouLoss', 'GIoULoss']
@register @register
...@@ -60,3 +60,72 @@ class IouLoss(object): ...@@ -60,3 +60,72 @@ class IouLoss(object):
loss_iou = loss_iou * self.loss_weight loss_iou = loss_iou * self.loss_weight
return loss_iou return loss_iou
@register
@serializable
class GIoULoss(object):
"""
Generalized Intersection over Union, see https://arxiv.org/abs/1902.09630
Args:
loss_weight (float): giou loss weight, default as 1
eps (float): epsilon to avoid divide by zero, default as 1e-10
reduction (string): Options are "none", "mean" and "sum". default as none
"""
def __init__(self, loss_weight=1., eps=1e-10, reduction='none'):
self.loss_weight = loss_weight
self.eps = eps
assert reduction in ('none', 'mean', 'sum')
self.reduction = reduction
def bbox_overlap(self, box1, box2, eps=1e-10):
"""calculate the iou of box1 and box2
Args:
box1 (Tensor): box1 with the shape (..., 4)
box2 (Tensor): box1 with the shape (..., 4)
eps (float): epsilon to avoid divide by zero
Return:
iou (Tensor): iou of box1 and box2
overlap (Tensor): overlap of box1 and box2
union (Tensor): union of box1 and box2
"""
x1, y1, x2, y2 = box1
x1g, y1g, x2g, y2g = box2
xkis1 = paddle.maximum(x1, x1g)
ykis1 = paddle.maximum(y1, y1g)
xkis2 = paddle.minimum(x2, x2g)
ykis2 = paddle.minimum(y2, y2g)
w_inter = (xkis2 - xkis1).clip(0)
h_inter = (ykis2 - ykis1).clip(0)
overlap = w_inter * h_inter
area1 = (x2 - x1) * (y2 - y1)
area2 = (x2g - x1g) * (y2g - y1g)
union = area1 + area2 - overlap + eps
iou = overlap / union
return iou, overlap, union
def __call__(self, pbox, gbox, iou_weight=1.):
x1, y1, x2, y2 = paddle.split(pbox, num_or_sections=4, axis=-1)
x1g, y1g, x2g, y2g = paddle.split(gbox, num_or_sections=4, axis=-1)
box1 = [x1, y1, x2, y2]
box2 = [x1g, y1g, x2g, y2g]
iou, overlap, union = self.bbox_overlap(box1, box2, self.eps)
xc1 = paddle.minimum(x1, x1g)
yc1 = paddle.minimum(y1, y1g)
xc2 = paddle.maximum(x2, x2g)
yc2 = paddle.maximum(y2, y2g)
area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps
miou = iou - ((area_c - union) / area_c)
giou = 1 - miou
if self.reduction == 'none':
loss = giou
elif self.reduction == 'sum':
loss = paddle.sum(giou * iou_weight)
else:
loss = paddle.mean(giou * iou_weight)
return loss * self.loss_weight
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
from . import fpn from . import fpn
from . import yolo_fpn from . import yolo_fpn
from . import hrfpn from . import hrfpn
from . import ttf_fpn
from .fpn import * from .fpn import *
from .yolo_fpn import * from .yolo_fpn import *
from .hrfpn import * from .hrfpn import *
from .ttf_fpn import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Constant, Uniform, Normal
from paddle.nn import Conv2D, ReLU, Sequential
from paddle import ParamAttr
from ppdet.core.workspace import register, serializable
from paddle.regularizer import L2Decay
from ppdet.modeling.layers import DeformableConvV2
import math
from ppdet.modeling.ops import batch_norm
class Upsample(nn.Layer):
def __init__(self, ch_in, ch_out, name=None):
super(Upsample, self).__init__()
fan_in = ch_in * 3 * 3
stdv = 1. / math.sqrt(fan_in)
self.dcn = DeformableConvV2(
ch_in,
ch_out,
kernel_size=3,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)),
bias_attr=ParamAttr(
initializer=Constant(0),
regularizer=L2Decay(0.),
learning_rate=2.),
lr_scale=2.,
regularizer=L2Decay(0.),
name=name)
self.bn = batch_norm(
ch_out, norm_type='bn', initializer=Constant(1.), name=name)
def forward(self, feat):
dcn = self.dcn(feat)
bn = self.bn(dcn)
relu = F.relu(bn)
out = F.interpolate(relu, scale_factor=2., mode='bilinear')
return out
class ShortCut(nn.Layer):
def __init__(self, layer_num, ch_out, name=None):
super(ShortCut, self).__init__()
shortcut_conv = Sequential()
ch_in = ch_out * 2
for i in range(layer_num):
fan_out = 3 * 3 * ch_out
std = math.sqrt(2. / fan_out)
in_channels = ch_in if i == 0 else ch_out
shortcut_name = name + '.conv.{}'.format(i)
shortcut_conv.add_sublayer(
shortcut_name,
Conv2D(
in_channels=in_channels,
out_channels=ch_out,
kernel_size=3,
padding=1,
weight_attr=ParamAttr(initializer=Normal(0, std)),
bias_attr=ParamAttr(
learning_rate=2., regularizer=L2Decay(0.))))
if i < layer_num - 1:
shortcut_conv.add_sublayer(shortcut_name + '.act', ReLU())
self.shortcut = self.add_sublayer('short', shortcut_conv)
def forward(self, feat):
out = self.shortcut(feat)
return out
@register
@serializable
class TTFFPN(nn.Layer):
def __init__(self,
planes=[256, 128, 64],
shortcut_num=[1, 2, 3],
ch_in=[1024, 256, 128]):
super(TTFFPN, self).__init__()
self.planes = planes
self.shortcut_num = shortcut_num
self.shortcut_len = len(shortcut_num)
self.ch_in = ch_in
self.upsample_list = []
self.shortcut_list = []
for i, out_c in enumerate(self.planes):
upsample = self.add_sublayer(
'upsample.' + str(i),
Upsample(
self.ch_in[i], out_c, name='upsample.' + str(i)))
self.upsample_list.append(upsample)
if i < self.shortcut_len:
shortcut = self.add_sublayer(
'shortcut.' + str(i),
ShortCut(
self.shortcut_num[i], out_c, name='shortcut.' + str(i)))
self.shortcut_list.append(shortcut)
def forward(self, inputs):
feat = inputs[-1]
for i, out_c in enumerate(self.planes):
feat = self.upsample_list[i](feat)
if i < self.shortcut_len:
shortcut = self.shortcut_list[i](inputs[-i - 2])
feat = feat + shortcut
return feat
...@@ -45,7 +45,7 @@ __all__ = [ ...@@ -45,7 +45,7 @@ __all__ = [
] ]
def batch_norm(ch, norm_type='bn', name=None): def batch_norm(ch, norm_type='bn', norm_decay=0., initializer=None, name=None):
bn_name = name + '.bn' bn_name = name + '.bn'
if norm_type == 'sync_bn': if norm_type == 'sync_bn':
batch_norm = nn.SyncBatchNorm batch_norm = nn.SyncBatchNorm
...@@ -55,9 +55,11 @@ def batch_norm(ch, norm_type='bn', name=None): ...@@ -55,9 +55,11 @@ def batch_norm(ch, norm_type='bn', name=None):
return batch_norm( return batch_norm(
ch, ch,
weight_attr=ParamAttr( weight_attr=ParamAttr(
name=bn_name + '.scale', regularizer=L2Decay(0.)), name=bn_name + '.scale',
initializer=initializer,
regularizer=L2Decay(norm_decay)),
bias_attr=ParamAttr( bias_attr=ParamAttr(
name=bn_name + '.offset', regularizer=L2Decay(0.))) name=bn_name + '.offset', regularizer=L2Decay(norm_decay)))
@paddle.jit.not_to_static @paddle.jit.not_to_static
......
...@@ -17,8 +17,12 @@ class BBoxPostProcess(object): ...@@ -17,8 +17,12 @@ class BBoxPostProcess(object):
self.nms = nms self.nms = nms
def __call__(self, head_out, rois, im_shape, scale_factor=None): def __call__(self, head_out, rois, im_shape, scale_factor=None):
if self.nms is not None:
bboxes, score = self.decode(head_out, rois, im_shape, scale_factor) bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
bbox_pred, bbox_num, _ = self.nms(bboxes, score) bbox_pred, bbox_num, _ = self.nms(bboxes, score)
else:
bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
scale_factor)
return bbox_pred, bbox_num return bbox_pred, bbox_num
......
...@@ -147,6 +147,8 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map): ...@@ -147,6 +147,8 @@ def get_det_res(bboxes, bbox_nums, image_id, num_id_to_cat_id_map):
dt = bboxes[k] dt = bboxes[k]
k = k + 1 k = k + 1
num_id, score, xmin, ymin, xmax, ymax = dt.tolist() num_id, score, xmin, ymin, xmax, ymax = dt.tolist()
if num_id < 0:
continue
category_id = num_id_to_cat_id_map[num_id] category_id = num_id_to_cat_id_map[num_id]
w = xmax - xmin + 1 w = xmax - xmin + 1
h = ymax - ymin + 1 h = ymax - ymin + 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册