diff --git a/dygraph/configs/ttfnet/README.md b/dygraph/configs/ttfnet/README.md index 67b848933496d52548befd5c5ca3bc79e048a5f7..d0bf4148141f9c5ade7dbe1bd3a88aa0443622e4 100644 --- a/dygraph/configs/ttfnet/README.md +++ b/dygraph/configs/ttfnet/README.md @@ -13,7 +13,7 @@ TTFNet是一种用于实时目标检测且对训练时间友好的网络,对Ce | 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 | | :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | -| DarkNet53 | TTFNet | 12 | 1x | ---- | 33.6 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ttfnet_darknet53_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ttfnet/ttfnet_darknet53_1x_coco.yml) | +| DarkNet53 | TTFNet | 12 | 1x | ---- | 33.5 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ttfnet_darknet53_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/ttfnet/ttfnet_darknet53_1x_coco.yml) | ## Citations ``` diff --git a/dygraph/configs/ttfnet/_base_/ttfnet_darknet53.yml b/dygraph/configs/ttfnet/_base_/ttfnet_darknet53.yml index 59b173bc43562b8eeadeedd97a1ddf1938038860..d175cadb2d907ca70e27824f88a7093b6e193ac1 100644 --- a/dygraph/configs/ttfnet/_base_/ttfnet_darknet53.yml +++ b/dygraph/configs/ttfnet/_base_/ttfnet_darknet53.yml @@ -11,26 +11,14 @@ TTFNet: DarkNet: depth: 53 freeze_at: 0 - return_idx: [0, 1, 2, 3, 4] + return_idx: [1, 2, 3, 4] norm_type: bn norm_decay: 0.0004 -TTFFPN: - planes: [256, 128, 64] - shortcut_num: [1, 2, 3] - ch_in: [1024, 256, 128] +# use default config +# TTFFPN: TTFHead: - hm_head: - name: HMHead - ch_in: 64 - ch_out: 128 - conv_num: 2 - wh_head: - name: WHHead - ch_in: 64 - ch_out: 64 - conv_num: 2 hm_loss: name: CTFocalLoss loss_weight: 1. @@ -39,7 +27,6 @@ TTFHead: loss_weight: 5. 
reduction: sum - BBoxPostProcess: decode: name: TTFBox diff --git a/dygraph/ppdet/modeling/architectures/ttfnet.py b/dygraph/ppdet/modeling/architectures/ttfnet.py index 70b8761de5a034028e86c0e095cdb03ecea05c8f..181aa7f0d88f05332cd406da58b13d9d4fe0612a 100644 --- a/dygraph/ppdet/modeling/architectures/ttfnet.py +++ b/dygraph/ppdet/modeling/architectures/ttfnet.py @@ -17,7 +17,7 @@ from __future__ import division from __future__ import print_function import paddle -from ppdet.core.workspace import register +from ppdet.core.workspace import register, create from .meta_arch import BaseArch __all__ = ['TTFNet'] @@ -36,12 +36,7 @@ class TTFNet(BaseArch): """ __category__ = 'architecture' - __inject__ = [ - 'backbone', - 'neck', - 'ttf_head', - 'post_process', - ] + __inject__ = ['post_process'] def __init__(self, backbone='DarkNet', @@ -54,32 +49,55 @@ class TTFNet(BaseArch): self.ttf_head = ttf_head self.post_process = post_process - def model_arch(self, ): - # Backbone + @classmethod + def from_config(cls, cfg, *args, **kwargs): + backbone = create(cfg['backbone']) + + kwargs = {'input_shape': backbone.out_shape} + neck = create(cfg['neck'], **kwargs) + + kwargs = {'input_shape': neck.out_shape} + ttf_head = create(cfg['ttf_head'], **kwargs) + + return { + 'backbone': backbone, + 'neck': neck, + "ttf_head": ttf_head, + } + + def _forward(self): body_feats = self.backbone(self.inputs) - # neck body_feats = self.neck(body_feats) - # TTF Head - self.hm, self.wh = self.ttf_head(body_feats) + hm, wh = self.ttf_head(body_feats) + if self.training: + return hm, wh + else: + bbox, bbox_num = self.post_process(hm, wh, self.inputs['im_shape'], + self.inputs['scale_factor']) + return bbox, bbox_num def get_loss(self, ): loss = {} heatmap = self.inputs['ttf_heatmap'] box_target = self.inputs['ttf_box_target'] reg_weight = self.inputs['ttf_reg_weight'] - head_loss = self.ttf_head.get_loss(self.hm, self.wh, heatmap, - box_target, reg_weight) + hm, wh = self._forward() + head_loss = 
self.ttf_head.get_loss(hm, wh, heatmap, box_target, + reg_weight) loss.update(head_loss) total_loss = paddle.add_n(list(loss.values())) loss.update({'loss': total_loss}) return loss def get_pred(self): - bbox, bbox_num = self.post_process(self.hm, self.wh, - self.inputs['im_shape'], - self.inputs['scale_factor']) - outs = { + bbox_pred, bbox_num = self._forward() + label = bbox_pred[:, 0] + score = bbox_pred[:, 1] + bbox = bbox_pred[:, 2:] + output = { "bbox": bbox, + 'score': score, + 'label': label, "bbox_num": bbox_num, } - return outs + return output diff --git a/dygraph/ppdet/modeling/heads/ttf_head.py b/dygraph/ppdet/modeling/heads/ttf_head.py index a4c4b644e599c931109b72885c7e5f22e6aef3ac..632fd06306ec515deec3a8355dd422600c960ebc 100644 --- a/dygraph/ppdet/modeling/heads/ttf_head.py +++ b/dygraph/ppdet/modeling/heads/ttf_head.py @@ -104,34 +104,50 @@ class TTFHead(nn.Layer): """ TTFHead Args: - hm_head(object): Instance of 'HMHead', heatmap branch. - wh_head(object): Instance of 'WHHead', wh branch. + in_channels(int): the channel number of input to TTFHead. + num_classes(int): the number of classes, 80 by default. + hm_head_planes(int): the channel number in hm head, 128 by default. + wh_head_planes(int): the channel number in wh head, 64 by default. + hm_head_conv_num(int): the number of convolution in hm head, 2 by default. + wh_head_conv_num(int): the number of convolution in wh head, 2 by default. hm_loss(object): Instance of 'CTFocalLoss'. wh_loss(object): Instance of 'GIoULoss'. wh_offset_base(flaot): the base offset of width and height, 16. by default. - down_ratio(int): the actual down_ratio is calculated by base_down_ratio(default 16) a - nd the number of upsample layers. + down_ratio(int): the actual down_ratio is calculated by base_down_ratio(default 16) + and the number of upsample layers. 
""" - __shared__ = ['down_ratio'] - __inject__ = ['hm_head', 'wh_head', 'hm_loss', 'wh_loss'] + __shared__ = ['num_classes', 'down_ratio'] + __inject__ = ['hm_loss', 'wh_loss'] def __init__(self, - hm_head='HMHead', - wh_head='WHHead', + in_channels, + num_classes=80, + hm_head_planes=128, + wh_head_planes=64, + hm_head_conv_num=2, + wh_head_conv_num=2, hm_loss='CTFocalLoss', wh_loss='GIoULoss', wh_offset_base=16., down_ratio=4): super(TTFHead, self).__init__() - self.hm_head = hm_head - self.wh_head = wh_head + self.in_channels = in_channels + self.hm_head = HMHead(in_channels, hm_head_planes, num_classes, + hm_head_conv_num) + self.wh_head = WHHead(in_channels, wh_head_planes, wh_head_conv_num) self.hm_loss = hm_loss self.wh_loss = wh_loss self.wh_offset_base = wh_offset_base self.down_ratio = down_ratio + @classmethod + def from_config(cls, cfg, input_shape): + if isinstance(input_shape, (list, tuple)): + input_shape = input_shape[0] + return {'in_channels': input_shape.channels, } + def forward(self, feats): hm = self.hm_head(feats) wh = self.wh_head(feats) * self.wh_offset_base diff --git a/dygraph/ppdet/modeling/necks/ttf_fpn.py b/dygraph/ppdet/modeling/necks/ttf_fpn.py index 92ada3925db1f94723649d39ee9ae2b75b58578c..16f808240d2f05217e048ce89f3a3effcbaffe94 100644 --- a/dygraph/ppdet/modeling/necks/ttf_fpn.py +++ b/dygraph/ppdet/modeling/necks/ttf_fpn.py @@ -24,6 +24,9 @@ from paddle.regularizer import L2Decay from ppdet.modeling.layers import DeformableConvV2 import math from ppdet.modeling.ops import batch_norm +from ..shape_spec import ShapeSpec + +__all__ = ['TTFFPN'] __all__ = ['TTFFPN'] @@ -89,22 +92,33 @@ class ShortCut(nn.Layer): @register @serializable class TTFFPN(nn.Layer): + """ + Args: + in_channels (list): number of input feature channels from backbone. + [128,256,512,1024] by default, means the channels of DarkNet53 + backbone return_idx [1,2,3,4]. + shortcut_num (list): the number of convolution layers in each shortcut. 
+ [3,2,1] by default, means DarkNet53 backbone return_idx_1 has 3 convs + in its shortcut, return_idx_2 has 2 convs and return_idx_3 has 1 conv. + """ + def __init__(self, - planes=[256, 128, 64], - shortcut_num=[1, 2, 3], - ch_in=[1024, 256, 128]): + in_channels=[128, 256, 512, 1024], + shortcut_num=[3, 2, 1]): super(TTFFPN, self).__init__() - self.planes = planes - self.shortcut_num = shortcut_num + self.planes = [c // 2 for c in in_channels[:-1]][::-1] + self.shortcut_num = shortcut_num[::-1] self.shortcut_len = len(shortcut_num) - self.ch_in = ch_in + self.ch_in = in_channels[::-1] + self.upsample_list = [] self.shortcut_list = [] for i, out_c in enumerate(self.planes): + in_c = self.ch_in[i] if i == 0 else self.ch_in[i] // 2 upsample = self.add_sublayer( 'upsample.' + str(i), Upsample( - self.ch_in[i], out_c, name='upsample.' + str(i))) + in_c, out_c, name='upsample.' + str(i))) self.upsample_list.append(upsample) if i < self.shortcut_len: shortcut = self.add_sublayer( @@ -121,3 +135,11 @@ class TTFFPN(nn.Layer): shortcut = self.shortcut_list[i](inputs[-i - 2]) feat = feat + shortcut return feat + + @classmethod + def from_config(cls, cfg, input_shape): + return {'in_channels': [i.channels for i in input_shape], } + + @property + def out_shape(self): + return [ShapeSpec(channels=self.planes[-1], )]