diff --git a/dygraph/configs/hrnet/README.md b/dygraph/configs/hrnet/README.md index 5b8cb5cab1038775ee5e27ac3187c8c478f8fa60..beb17d1e0dd08bf3bea82a38421d8ca8e8362e9b 100644 --- a/dygraph/configs/hrnet/README.md +++ b/dygraph/configs/hrnet/README.md @@ -28,7 +28,7 @@ ## Model Zoo -| Backbone | Type | deformable Conv | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs | -| :---------------------- | :------------- | :---: | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: | -| HRNetV2p_W18 | Faster | False | 2 | 1x | - | 35.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml) | -| HRNetV2p_W18 | Faster | False | 2 | 2x | - | 37.7 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml) | +| Backbone | Type | Image/gpu | Lr schd | Inf time (fps) | Box AP | Mask AP | Download | Configs | +| :---------------------- | :------------- | :-------: | :-----: | :------------: | :----: | :-----: | :----------------------------------------------------------: | :-----: | +| HRNetV2p_W18 | Faster | 1 | 1x | - | 36.8 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml) | +| HRNetV2p_W18 | Faster | 1 | 2x | - | 39.0 | - | [model](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/faster_rcnn_hrnetv2p_w18_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_2x_coco.yml) | diff --git a/dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml b/dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml index 0f6fb8f11d7173d0725354f661d4596bfedf8064..ee2476800bf4e84579c6f9711ffcfb579d78dabb 100644 --- a/dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml +++ b/dygraph/configs/hrnet/_base_/faster_rcnn_hrnetv2p_w18.yml @@ -1,14 +1,8 @@ architecture: FasterRCNN pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/HRNet_W18_C_pretrained.tar -weights: output/faster_rcnn_hrnetv2p_w18_1x_coco/model_final load_static_weights: True -# Model Achitecture FasterRCNN: - # model anchor info flow - anchor: Anchor - proposal: Proposal - # model feat info flow backbone: HRNet neck: HRFPN rpn_head: RPNHead @@ -26,65 +20,51 @@ HRFPN: share_conv: false RPNHead: - rpn_feat: - name: RPNFeat - feat_in: 256 - feat_out: 256 - anchor_per_position: 3 - rpn_channel: 256 - -Anchor: anchor_generator: - name: AnchorGeneratorRPN aspect_ratios: [0.5, 1.0, 2.0] - anchor_start_size: 32 - stride: [4., 4.] - anchor_target_generator: - name: AnchorTargetGeneratorRPN + anchor_sizes: [[32], [64], [128], [256], [512]] + strides: [4, 8, 16, 32, 64] + rpn_target_assign: batch_size_per_im: 256 fg_fraction: 0.5 negative_overlap: 0.3 positive_overlap: 0.7 - straddle_thresh: 0.0 - -Proposal: - proposal_generator: - name: ProposalGenerator + use_random: True + train_proposal: + min_size: 0.0 + nms_thresh: 0.7 + pre_nms_top_n: 2000 + post_nms_top_n: 2000 + topk_after_collect: True + test_proposal: min_size: 0.0 nms_thresh: 0.7 - train_pre_nms_top_n: 2000 - train_post_nms_top_n: 2000 - infer_pre_nms_top_n: 1000 - infer_post_nms_top_n: 1000 - proposal_target_generator: - name: ProposalTargetGenerator - batch_size_per_im: 512 - bbox_reg_weights: [0.1, 0.1, 0.2, 0.2] - bg_thresh_hi: [0.5,] - bg_thresh_lo: [0.0,] - fg_thresh: [0.5,] - fg_fraction: 0.25 + pre_nms_top_n: 1000 + post_nms_top_n: 1000 BBoxHead: - bbox_feat: - name: BBoxFeat - roi_extractor: - name: RoIAlign - resolution: 7 - sampling_ratio: 2 - head_feat: - name: TwoFCHead - in_dim: 256 - mlp_dim: 1024 - in_feat: 1024 + head: TwoFCHead + roi_extractor: + resolution: 7 + sampling_ratio: 0 + aligned: True + bbox_assigner: BBoxAssigner + +BBoxAssigner: + batch_size_per_im: 512 + bg_thresh: 0.5 + fg_thresh: 0.5 + fg_fraction: 0.25 + use_random: True + +TwoFCHead: + mlp_dim: 1024 BBoxPostProcess: - decode: - name: RCNNBox - num_classes: 81 - batch_size: 1 + decode: RCNNBox nms: name: MultiClassNMS keep_top_k: 100 score_threshold: 0.05 nms_threshold: 0.5 + normalized: true diff --git a/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml b/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml index f68bac5feae6ecfce1d807290fd1b320cf1d83b9..6ff05964c41e05b2d7aaee9bf6ef330cee2337c0 100644 --- a/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml +++ b/dygraph/configs/hrnet/faster_rcnn_hrnetv2p_w18_1x_coco.yml @@ -6,6 +6,9 @@ _BASE_: [ '../runtime.yml', ] +weights: output/faster_rcnn_hrnetv2p_w18_1x_coco/model_final +epoch: 12 + LearningRate: base_lr: 0.02 schedulers: diff --git a/dygraph/ppdet/modeling/backbones/hrnet.py b/dygraph/ppdet/modeling/backbones/hrnet.py index 71715a4b5a36d067887097a7e202da27bc2ca400..4450bd9a597cf8f0edda241b2ac1201bd4906684 100644 --- a/dygraph/ppdet/modeling/backbones/hrnet.py +++ b/dygraph/ppdet/modeling/backbones/hrnet.py @@ -22,6 +22,7 @@ from numbers import Integral import math from ppdet.core.workspace import register, serializable +from ..shape_spec import ShapeSpec __all__ = ['HRNet'] @@ -577,6 +578,8 @@ class HRNet(nn.Layer): channels_2, channels_3, channels_4 = self.channels[width] num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3 + self._out_channels = channels_4 + self._out_strides = [4, 8, 16, 32] self.conv_layer1_1 = ConvNormLayer( ch_in=3, @@ -666,3 +669,11 @@ class HRNet(nn.Layer): res.append(layer) return res + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self._out_channels[i], stride=self._out_strides[i]) + for i in self.return_idx + ] diff --git a/dygraph/ppdet/modeling/necks/hrfpn.py b/dygraph/ppdet/modeling/necks/hrfpn.py index f06b3cacec7231ecd72914e5b1ae493c27cf50df..7afbbc0ea2cf25584a234ed731da25628b36c29b 100644 --- a/dygraph/ppdet/modeling/necks/hrfpn.py +++ b/dygraph/ppdet/modeling/necks/hrfpn.py @@ -18,6 +18,7 @@ from paddle import ParamAttr import paddle.nn as nn from paddle.regularizer import L2Decay from ppdet.core.workspace import register, serializable +from ..shape_spec import ShapeSpec __all__ = ['HRFPN'] @@ -26,23 +27,28 @@ __all__ = ['HRFPN'] class HRFPN(nn.Layer): """ Args: - in_channel (int): number of input feature channels from backbone + in_channels (list): number of input feature channels from backbone out_channel (int): number of output feature channels share_conv (bool): whether to share conv for different layers' reduction - spatial_scale (list): feature map scaling factor + spatial_scales (list): feature map scaling factor + extra_stage (int): add extra stage for returning HRFPN fpn_feats """ - def __init__( - self, - in_channel=270, - out_channel=256, - share_conv=False, - spatial_scale=[1. / 4, 1. / 8, 1. / 16, 1. / 32, 1. / 64], ): + def __init__(self, + in_channels=[18, 36, 72, 144], + out_channel=256, + share_conv=False, + extra_stage=1, + spatial_scales=[1. / 4, 1. / 8, 1. / 16, 1. / 32]): super(HRFPN, self).__init__() + in_channel = sum(in_channels) self.in_channel = in_channel self.out_channel = out_channel self.share_conv = share_conv - self.spatial_scale = spatial_scale + for i in range(extra_stage): + spatial_scales = spatial_scales + [spatial_scales[-1] / 2.] + self.spatial_scales = spatial_scales + self.num_out = len(self.spatial_scales) self.reduction = nn.Conv2D( in_channels=in_channel, @@ -50,7 +56,7 @@ class HRFPN(nn.Layer): kernel_size=1, weight_attr=ParamAttr(name='hrfpn_reduction_weights'), bias_attr=False) - self.num_out = len(self.spatial_scale) + if share_conv: self.fpn_conv = nn.Conv2D( in_channels=out_channel, @@ -106,5 +112,20 @@ class HRFPN(nn.Layer): conv = conv_func(outs[i]) outputs.append(conv) - fpn_feat = [outputs[k] for k in range(self.num_out)] - return fpn_feat, self.spatial_scale + fpn_feats = [outputs[k] for k in range(self.num_out)] + return fpn_feats + + @classmethod + def from_config(cls, cfg, input_shape): + return { + 'in_channels': [i.channels for i in input_shape], + 'spatial_scales': [1.0 / i.stride for i in input_shape], + } + + @property + def out_shape(self): + return [ + ShapeSpec( + channels=self.out_channel, stride=1. / s) + for s in self.spatial_scales + ]