From 866332ed600479980bbde31a0bd909435be5280a Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 15 Apr 2021 14:54:25 +0800 Subject: [PATCH] add ttfnet enhance (#2609) * add ttfnet enhance * add doc * fix pafnet training --- configs/ttfnet/README.md | 39 +++- configs/ttfnet/_base_/optimizer_10x.yml | 19 ++ configs/ttfnet/_base_/optimizer_20x.yml | 20 ++ configs/ttfnet/_base_/pafnet.yml | 41 ++++ configs/ttfnet/_base_/pafnet_lite.yml | 44 ++++ configs/ttfnet/_base_/pafnet_lite_reader.yml | 40 ++++ configs/ttfnet/_base_/pafnet_reader.yml | 40 ++++ configs/ttfnet/_base_/ttfnet_darknet53.yml | 5 +- configs/ttfnet/pafnet_10x_coco.yml | 8 + .../pafnet_lite_mobilenet_v3_20x_coco.yml | 8 + ppdet/data/source/dataset.py | 3 + ppdet/data/transform/batch_operators.py | 2 + ppdet/data/transform/gridmask_utils.py | 4 +- ppdet/data/transform/operators.py | 43 ++-- ppdet/modeling/backbones/mobilenet_v3.py | 12 +- ppdet/modeling/heads/ttf_head.py | 133 +++++++++--- ppdet/modeling/layers.py | 72 ++++++- ppdet/modeling/necks/ttf_fpn.py | 167 ++++++++++++--- .../configs/anchor_free/pafnet_10x_coco.yml | 170 +++++++++++++++ .../pafnet_lite_mobilenet_v3_20x_coco.yml | 171 ++++++++++++++++ .../ppdet/modeling/anchor_heads/ttf_head.py | 193 +++++++++++++++++- 21 files changed, 1142 insertions(+), 92 deletions(-) create mode 100644 configs/ttfnet/_base_/optimizer_10x.yml create mode 100644 configs/ttfnet/_base_/optimizer_20x.yml create mode 100644 configs/ttfnet/_base_/pafnet.yml create mode 100644 configs/ttfnet/_base_/pafnet_lite.yml create mode 100644 configs/ttfnet/_base_/pafnet_lite_reader.yml create mode 100644 configs/ttfnet/_base_/pafnet_reader.yml create mode 100644 configs/ttfnet/pafnet_10x_coco.yml create mode 100644 configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml create mode 100644 static/configs/anchor_free/pafnet_10x_coco.yml create mode 100644 static/configs/anchor_free/pafnet_lite_mobilenet_v3_20x_coco.yml diff --git a/configs/ttfnet/README.md b/configs/ttfnet/README.md index b5a6cfd41..7c89ca2b0 100644 --- a/configs/ttfnet/README.md +++ b/configs/ttfnet/README.md @@ -1,4 +1,4 @@ -# TTFNet +# 1. TTFNet ## 简介 @@ -15,6 +15,43 @@ TTFNet是一种用于实时目标检测且对训练时间友好的网络,对Ce | :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | | DarkNet53 | TTFNet | 12 | 1x | ---- | 33.5 | [下载链接](https://paddledet.bj.bcebos.com/models/ttfnet_darknet53_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/ttfnet_darknet53_1x_coco.yml) | + + + + +# 2. PAFNet + +## 简介 + +PAFNet(Paddle Anchor Free)是PaddleDetection基于TTFNet的优化模型,精度达到anchor free领域SOTA水平,同时产出移动端轻量级模型PAFNet-Lite + +PAFNet系列模型从如下方面优化TTFNet模型: + +- [CutMix](https://arxiv.org/abs/1905.04899) +- 更优的骨干网络: ResNet50vd-DCN +- 更大的训练batch size: 8 GPUs,每GPU batch_size=18 +- Synchronized Batch Normalization +- [Deformable Convolution](https://arxiv.org/abs/1703.06211) +- [Exponential Moving Average](https://www.investopedia.com/terms/e/ema.asp) +- 更优的预训练模型 + + +## 模型库 + +| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 | +| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | +| ResNet50vd | PAFNet | 18 | 10x | ---- | 42.2 | [下载链接](https://paddledet.bj.bcebos.com/models/pafnet_10x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/pafnet_10x_coco.yml) | + + + +### PAFNet-Lite + +| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 | Box AP | 麒麟990延时(ms) | 体积(M) | 下载 | 配置文件 | +| :-------------- | :------------- | :-----: | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | +| MobileNetv3 | PAFNet-Lite | 12 | 20x | 23.9 | 26.00 | 14 | [下载链接](https://paddledet.bj.bcebos.com/models/pafnet_lite_mobilenet_v3_20x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml) | + + + ## Citations ``` @article{liu2019training, diff --git a/configs/ttfnet/_base_/optimizer_10x.yml b/configs/ttfnet/_base_/optimizer_10x.yml new file mode 100644 index 000000000..dd2c29d96 --- /dev/null +++ b/configs/ttfnet/_base_/optimizer_10x.yml @@ -0,0 +1,19 @@ +epoch: 120 + +LearningRate: + base_lr: 0.015 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [80, 110] + - !LinearWarmup + start_factor: 0.2 + steps: 500 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0004 + type: L2 diff --git a/configs/ttfnet/_base_/optimizer_20x.yml b/configs/ttfnet/_base_/optimizer_20x.yml new file mode 100644 index 000000000..4dd349220 --- /dev/null +++ b/configs/ttfnet/_base_/optimizer_20x.yml @@ -0,0 +1,20 @@ +epoch: 240 + +LearningRate: + base_lr: 0.015 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [160, 220] + - !LinearWarmup + start_factor: 0.2 + steps: 1000 + +OptimizerBuilder: + clip_grad_by_norm: 35 + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0004 + type: L2 diff --git a/configs/ttfnet/_base_/pafnet.yml b/configs/ttfnet/_base_/pafnet.yml new file mode 100644 index 000000000..5319fe6c8 --- /dev/null +++ b/configs/ttfnet/_base_/pafnet.yml @@ -0,0 +1,41 @@ +architecture: TTFNet +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_pretrained.pdparams +norm_type: sync_bn +use_ema: true +ema_decay: 0.9998 + +TTFNet: + backbone: ResNet + neck: TTFFPN + ttf_head: TTFHead + post_process: BBoxPostProcess + +ResNet: + depth: 50 + variant: d + return_idx: [0, 1, 2, 3] + freeze_at: -1 + norm_decay: 0. + variant: d + dcn_v2_stages: [1, 2, 3] + +TTFFPN: + planes: [256, 128, 64] + shortcut_num: [3, 2, 1] + +TTFHead: + dcn_head: true + hm_loss: + name: CTFocalLoss + loss_weight: 1. + wh_loss: + name: GIoULoss + loss_weight: 5. + reduction: sum + +BBoxPostProcess: + decode: + name: TTFBox + max_per_img: 100 + score_thresh: 0.01 + down_ratio: 4 diff --git a/configs/ttfnet/_base_/pafnet_lite.yml b/configs/ttfnet/_base_/pafnet_lite.yml new file mode 100644 index 000000000..5ed2fa235 --- /dev/null +++ b/configs/ttfnet/_base_/pafnet_lite.yml @@ -0,0 +1,44 @@ +architecture: TTFNet +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/MobileNetV3_large_x1_0_ssld_pretrained.pdparams +norm_type: sync_bn + +TTFNet: + backbone: MobileNetV3 + neck: TTFFPN + ttf_head: TTFHead + post_process: BBoxPostProcess + +MobileNetV3: + scale: 1.0 + model_name: large + feature_maps: [5, 8, 14, 17] + with_extra_blocks: true + lr_mult_list: [0.25, 0.25, 0.5, 0.5, 0.75] + conv_decay: 0.00001 + norm_decay: 0.0 + extra_block_filters: [] + +TTFFPN: + planes: [96, 48, 24] + shortcut_num: [2, 2, 1] + lite_neck: true + fusion_method: concat + +TTFHead: + hm_head_planes: 48 + wh_head_planes: 24 + lite_head: true + hm_loss: + name: CTFocalLoss + loss_weight: 1. + wh_loss: + name: GIoULoss + loss_weight: 5. + reduction: sum + +BBoxPostProcess: + decode: + name: TTFBox + max_per_img: 100 + score_thresh: 0.01 + down_ratio: 4 diff --git a/configs/ttfnet/_base_/pafnet_lite_reader.yml b/configs/ttfnet/_base_/pafnet_lite_reader.yml new file mode 100644 index 000000000..446a13a3c --- /dev/null +++ b/configs/ttfnet/_base_/pafnet_lite_reader.yml @@ -0,0 +1,40 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - Cutmix: {alpha: 1.5, beta: 1.5} + - RandomDistort: {} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {aspect_ratio: NULL, cover_all_box: True} + - RandomFlip: {} + - GridMask: {upper_iter: 300000} + batch_transforms: + - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512], random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375], is_scale: false} + - Permute: {} + - Gt2TTFTarget: {down_ratio: 4} + - PadBatch: {pad_to_stride: 32} + batch_size: 12 + shuffle: true + drop_last: true + use_shared_memory: true + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [320, 320], keep_ratio: False} + - NormalizeImage: {is_scale: false, mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375]} + - Permute: {} + batch_size: 1 + drop_last: false + drop_empty: false + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [320, 320], keep_ratio: False} + - NormalizeImage: {is_scale: false, mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375]} + - Permute: {} + batch_size: 1 + drop_last: false + drop_empty: false diff --git a/configs/ttfnet/_base_/pafnet_reader.yml b/configs/ttfnet/_base_/pafnet_reader.yml new file mode 100644 index 000000000..ea90a134f --- /dev/null +++ b/configs/ttfnet/_base_/pafnet_reader.yml @@ -0,0 +1,40 @@ +worker_num: 2 +TrainReader: + sample_transforms: + - Decode: {} + - Cutmix: {alpha: 1.5, beta: 1.5} + - RandomDistort: {random_apply: false, random_channel: true} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {aspect_ratio: NULL, cover_all_box: True} + - RandomFlip: {prob: 0.5} + batch_transforms: + - BatchRandomResize: {target_size: [416, 448, 480, 512, 544, 576, 608, 640, 672], keep_ratio: false} + - NormalizeImage: {mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375], is_scale: false} + - Permute: {} + - Gt2TTFTarget: {down_ratio: 4} + - PadBatch: {pad_to_stride: 32} + batch_size: 18 + shuffle: true + drop_last: true + use_shared_memory: true + mixup_epoch: 100 + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [512, 512], keep_ratio: False} + - NormalizeImage: {is_scale: false, mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375]} + - Permute: {} + batch_size: 1 + drop_last: false + drop_empty: false + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {interp: 1, target_size: [512, 512], keep_ratio: False} + - NormalizeImage: {is_scale: false, mean: [123.675, 116.28, 103.53], std: [58.395, 57.12, 57.375]} + - Permute: {} + batch_size: 1 + drop_last: false + drop_empty: false diff --git a/configs/ttfnet/_base_/ttfnet_darknet53.yml b/configs/ttfnet/_base_/ttfnet_darknet53.yml index 90b0c3361..05c7dce65 100644 --- a/configs/ttfnet/_base_/ttfnet_darknet53.yml +++ b/configs/ttfnet/_base_/ttfnet_darknet53.yml @@ -14,8 +14,9 @@ DarkNet: norm_type: bn norm_decay: 0.0004 -# use default config -# TTFFPN: +TTFFPN: + planes: [256, 128, 64] + shortcut_num: [3, 2, 1] TTFHead: hm_loss: diff --git a/configs/ttfnet/pafnet_10x_coco.yml b/configs/ttfnet/pafnet_10x_coco.yml new file mode 100644 index 000000000..b14a2bc91 --- /dev/null +++ b/configs/ttfnet/pafnet_10x_coco.yml @@ -0,0 +1,8 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_10x.yml', + '_base_/pafnet.yml', + '_base_/pafnet_reader.yml', +] +weights: output/pafnet_10x_coco/model_final diff --git a/configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml b/configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml new file mode 100644 index 000000000..577af1635 --- /dev/null +++ b/configs/ttfnet/pafnet_lite_mobilenet_v3_20x_coco.yml @@ -0,0 +1,8 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/optimizer_20x.yml', + '_base_/pafnet_lite.yml', + '_base_/pafnet_lite_reader.yml', +] +weights: output/pafnet_lite_mobilenet_v3_10x_coco/model_final diff --git a/ppdet/data/source/dataset.py b/ppdet/data/source/dataset.py index 07330a6b2..6ca027301 100644 --- a/ppdet/data/source/dataset.py +++ b/ppdet/data/source/dataset.py @@ -55,6 +55,7 @@ class DetDataset(Dataset): self.sample_num = sample_num self.use_default_label = use_default_label self._epoch = 0 + self._curr_iter = 0 def __len__(self, ): return len(self.roidbs) @@ -76,6 +77,8 @@ class DetDataset(Dataset): copy.deepcopy(self.roidbs[np.random.randint(n)]) for _ in range(3) ] + roidb['curr_iter'] = self._curr_iter + self._curr_iter += 1 return self.transform(roidb) diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index f65c0dec1..e09c04796 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -533,6 +533,8 @@ class Gt2TTFTarget(BaseOperator): sample.pop('is_crowd') sample.pop('gt_class') sample.pop('gt_bbox') + if 'gt_score' in sample: + sample.pop('gt_score') return samples def draw_truncate_gaussian(self, heatmap, center, h_radius, w_radius): diff --git a/ppdet/data/transform/gridmask_utils.py b/ppdet/data/transform/gridmask_utils.py index 115cb1e9d..2b3e72408 100644 --- a/ppdet/data/transform/gridmask_utils.py +++ b/ppdet/data/transform/gridmask_utils.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image -class GridMask(object): +class Gridmask(object): def __init__(self, use_h=True, use_w=True, @@ -30,7 +30,7 @@ class GridMask(object): mode=1, prob=0.7, upper_iter=360000): - super(GridMask, self).__init__() + super(Gridmask, self).__init__() self.use_h = use_h self.use_w = use_w self.rotate = rotate diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index eb21ee5a1..65608f367 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -308,8 +308,8 @@ class GridMask(BaseOperator): self.prob = prob self.upper_iter = upper_iter - from .gridmask_utils import GridMask - self.gridmask_op = GridMask( + from .gridmask_utils import Gridmask + self.gridmask_op = Gridmask( use_h, use_w, rotate=rotate, @@ -1516,14 +1516,14 @@ class Cutmix(BaseOperator): bbx2 = np.clip(cx + cut_w // 2, 0, w - 1) bby2 = np.clip(cy + cut_h // 2, 0, h - 1) - img_1 = np.zeros((h, w, img1.shape[2]), 'float32') - img_1[:img1.shape[0], :img1.shape[1], :] = \ + img_1_pad = np.zeros((h, w, img1.shape[2]), 'float32') + img_1_pad[:img1.shape[0], :img1.shape[1], :] = \ img1.astype('float32') - img_2 = np.zeros((h, w, img2.shape[2]), 'float32') - img_2[:img2.shape[0], :img2.shape[1], :] = \ + img_2_pad = np.zeros((h, w, img2.shape[2]), 'float32') + img_2_pad[:img2.shape[0], :img2.shape[1], :] = \ img2.astype('float32') - img_1[bby1:bby2, bbx1:bbx2, :] = img2[bby1:bby2, bbx1:bbx2, :] - return img_1 + img_1_pad[bby1:bby2, bbx1:bbx2, :] = img_2_pad[bby1:bby2, bbx1:bbx2, :] + return img_1_pad def __call__(self, sample, context=None): if not isinstance(sample, Sequence): @@ -1546,16 +1546,27 @@ class Cutmix(BaseOperator): gt_class1 = sample[0]['gt_class'] gt_class2 = sample[1]['gt_class'] gt_class = np.concatenate((gt_class1, gt_class2), axis=0) - gt_score1 = sample[0]['gt_score'] - gt_score2 = sample[1]['gt_score'] + gt_score1 = np.ones_like(sample[0]['gt_class']) + gt_score2 = np.ones_like(sample[1]['gt_class']) gt_score = np.concatenate( (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0) - sample = sample[0] - sample['image'] = img - sample['gt_bbox'] = gt_bbox - sample['gt_score'] = gt_score - sample['gt_class'] = gt_class - return sample + result = copy.deepcopy(sample[0]) + result['image'] = img + result['gt_bbox'] = gt_bbox + result['gt_score'] = gt_score + result['gt_class'] = gt_class + if 'is_crowd' in sample[0]: + is_crowd1 = sample[0]['is_crowd'] + is_crowd2 = sample[1]['is_crowd'] + is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0) + result['is_crowd'] = is_crowd + if 'difficult' in sample[0]: + is_difficult1 = sample[0]['difficult'] + is_difficult2 = sample[1]['difficult'] + is_difficult = np.concatenate( + (is_difficult1, is_difficult2), axis=0) + result['difficult'] = is_difficult + return result @register_op diff --git a/ppdet/modeling/backbones/mobilenet_v3.py b/ppdet/modeling/backbones/mobilenet_v3.py index 1cebf5ef1..d7178c913 100644 --- a/ppdet/modeling/backbones/mobilenet_v3.py +++ b/ppdet/modeling/backbones/mobilenet_v3.py @@ -330,16 +330,16 @@ class MobileNetV3(nn.Layer): [3, 16, 16, False, "relu", 1], [3, 64, 24, False, "relu", 2], [3, 72, 24, False, "relu", 1], - [5, 72, 40, True, "relu", 2], + [5, 72, 40, True, "relu", 2], # RCNN output [5, 120, 40, True, "relu", 1], [5, 120, 40, True, "relu", 1], # YOLOv3 output - [3, 240, 80, False, "hard_swish", 2], + [3, 240, 80, False, "hard_swish", 2], # RCNN output [3, 200, 80, False, "hard_swish", 1], [3, 184, 80, False, "hard_swish", 1], [3, 184, 80, False, "hard_swish", 1], [3, 480, 112, True, "hard_swish", 1], [3, 672, 112, True, "hard_swish", 1], # YOLOv3 output - [5, 672, 160, True, "hard_swish", 2], # SSD/SSDLite output + [5, 672, 160, True, "hard_swish", 2], # SSD/SSDLite/RCNN output [5, 960, 160, True, "hard_swish", 1], [5, 960, 160, True, "hard_swish", 1], # YOLOv3 output ] @@ -347,14 +347,14 @@ class MobileNetV3(nn.Layer): self.cfg = [ # k, exp, c, se, nl, s, [3, 16, 16, True, "relu", 2], - [3, 72, 24, False, "relu", 2], + [3, 72, 24, False, "relu", 2], # RCNN output [3, 88, 24, False, "relu", 1], # YOLOv3 output - [5, 96, 40, True, "hard_swish", 2], + [5, 96, 40, True, "hard_swish", 2], # RCNN output [5, 240, 40, True, "hard_swish", 1], [5, 240, 40, True, "hard_swish", 1], [5, 120, 48, True, "hard_swish", 1], [5, 144, 48, True, "hard_swish", 1], # YOLOv3 output - [5, 288, 96, True, "hard_swish", 2], # SSD/SSDLite output + [5, 288, 96, True, "hard_swish", 2], # SSD/SSDLite/RCNN output [5, 576, 96, True, "hard_swish", 1], [5, 576, 96, True, "hard_swish", 1], # YOLOv3 output ] diff --git a/ppdet/modeling/heads/ttf_head.py b/ppdet/modeling/heads/ttf_head.py index de030dd3b..005918295 100644 --- a/ppdet/modeling/heads/ttf_head.py +++ b/ppdet/modeling/heads/ttf_head.py @@ -19,6 +19,7 @@ from paddle import ParamAttr from paddle.nn.initializer import Constant, Uniform, Normal from paddle.regularizer import L2Decay from ppdet.core.workspace import register +from ppdet.modeling.layers import DeformableConvV2, LiteConv import numpy as np @@ -30,27 +31,61 @@ class HMHead(nn.Layer): ch_out (int): The channel number of output Tensor. num_classes (int): Number of classes. conv_num (int): The convolution number of hm_feat. + dcn_head(bool): whether use dcn in head. False by default. + lite_head(bool): whether use lite version. False by default. + norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional. + bn by default + Return: Heatmap head output """ - __shared__ = ['num_classes'] + __shared__ = ['num_classes', 'norm_type'] - def __init__(self, ch_in, ch_out=128, num_classes=80, conv_num=2): + def __init__( + self, + ch_in, + ch_out=128, + num_classes=80, + conv_num=2, + dcn_head=False, + lite_head=False, + norm_type='bn', ): super(HMHead, self).__init__() head_conv = nn.Sequential() for i in range(conv_num): name = 'conv.{}'.format(i) - head_conv.add_sublayer( - name, - nn.Conv2D( - in_channels=ch_in if i == 0 else ch_out, - out_channels=ch_out, - kernel_size=3, - padding=1, - weight_attr=ParamAttr(initializer=Normal(0, 0.01)), - bias_attr=ParamAttr( - learning_rate=2., regularizer=L2Decay(0.)))) - head_conv.add_sublayer(name + '.act', nn.ReLU()) + if lite_head: + lite_name = 'hm.' + name + head_conv.add_sublayer( + lite_name, + LiteConv( + in_channels=ch_in if i == 0 else ch_out, + out_channels=ch_out, + norm_type=norm_type, + name=lite_name)) + head_conv.add_sublayer(lite_name + '.act', nn.ReLU6()) + else: + if dcn_head: + head_conv.add_sublayer( + name, + DeformableConvV2( + in_channels=ch_in if i == 0 else ch_out, + out_channels=ch_out, + kernel_size=3, + weight_attr=ParamAttr(initializer=Normal(0, 0.01)), + name='hm.' + name)) + else: + head_conv.add_sublayer( + name, + nn.Conv2D( + in_channels=ch_in if i == 0 else ch_out, + out_channels=ch_out, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0, 0.01)), + bias_attr=ParamAttr( + learning_rate=2., regularizer=L2Decay(0.)))) + head_conv.add_sublayer(name + '.act', nn.ReLU()) self.feat = self.add_sublayer('hm_feat', head_conv) bias_init = float(-np.log((1 - 0.01) / 0.01)) self.head = self.add_sublayer( @@ -78,26 +113,59 @@ class WHHead(nn.Layer): ch_in (int): The channel number of input Tensor. ch_out (int): The channel number of output Tensor. conv_num (int): The convolution number of wh_feat. + dcn_head(bool): whether use dcn in head. False by default. + lite_head(bool): whether use lite version. False by default. + norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional. + bn by default Return: Width & Height head output """ + __shared__ = ['norm_type'] - def __init__(self, ch_in, ch_out=64, conv_num=2): + def __init__(self, + ch_in, + ch_out=64, + conv_num=2, + dcn_head=False, + lite_head=False, + norm_type='bn'): super(WHHead, self).__init__() head_conv = nn.Sequential() for i in range(conv_num): name = 'conv.{}'.format(i) - head_conv.add_sublayer( - name, - nn.Conv2D( - in_channels=ch_in if i == 0 else ch_out, - out_channels=ch_out, - kernel_size=3, - padding=1, - weight_attr=ParamAttr(initializer=Normal(0, 0.001)), - bias_attr=ParamAttr( - learning_rate=2., regularizer=L2Decay(0.)))) - head_conv.add_sublayer(name + '.act', nn.ReLU()) + if lite_head: + lite_name = 'wh.' + name + head_conv.add_sublayer( + lite_name, + LiteConv( + in_channels=ch_in if i == 0 else ch_out, + out_channels=ch_out, + norm_type=norm_type, + name=lite_name)) + head_conv.add_sublayer(lite_name + '.act', nn.ReLU6()) + else: + if dcn_head: + head_conv.add_sublayer( + name, + DeformableConvV2( + in_channels=ch_in if i == 0 else ch_out, + out_channels=ch_out, + kernel_size=3, + weight_attr=ParamAttr(initializer=Normal(0, 0.01)), + name='wh.' + name)) + else: + head_conv.add_sublayer( + name, + nn.Conv2D( + in_channels=ch_in if i == 0 else ch_out, + out_channels=ch_out, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0, 0.01)), + bias_attr=ParamAttr( + learning_rate=2., regularizer=L2Decay(0.)))) + head_conv.add_sublayer(name + '.act', nn.ReLU()) + self.feat = self.add_sublayer('wh_feat', head_conv) self.head = self.add_sublayer( 'wh_head', @@ -137,9 +205,12 @@ class TTFHead(nn.Layer): 16.0 by default. down_ratio (int): the actual down_ratio is calculated by base_down_ratio (default 16) and the number of upsample layers. + lite_head(bool): whether use lite version. False by default. + norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional. + bn by default """ - __shared__ = ['num_classes', 'down_ratio'] + __shared__ = ['num_classes', 'down_ratio', 'norm_type'] __inject__ = ['hm_loss', 'wh_loss'] def __init__(self, @@ -152,12 +223,16 @@ class TTFHead(nn.Layer): hm_loss='CTFocalLoss', wh_loss='GIoULoss', wh_offset_base=16., - down_ratio=4): + down_ratio=4, + dcn_head=False, + lite_head=False, + norm_type='bn'): super(TTFHead, self).__init__() self.in_channels = in_channels self.hm_head = HMHead(in_channels, hm_head_planes, num_classes, - hm_head_conv_num) - self.wh_head = WHHead(in_channels, wh_head_planes, wh_head_conv_num) + hm_head_conv_num, dcn_head, lite_head, norm_type) + self.wh_head = WHHead(in_channels, wh_head_planes, wh_head_conv_num, + dcn_head, lite_head, norm_type) self.hm_loss = hm_loss self.wh_loss = wh_loss diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 99703112d..4924a2e18 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -23,7 +23,7 @@ from paddle import ParamAttr from paddle import to_tensor from paddle.nn import Conv2D, BatchNorm2D, GroupNorm import paddle.nn.functional as F -from paddle.nn.initializer import Normal, Constant +from paddle.nn.initializer import Normal, Constant, XavierUniform from paddle.regularizer import L2Decay from ppdet.core.workspace import register, serializable @@ -112,6 +112,7 @@ class ConvNormLayer(nn.Layer): ch_out, filter_size, stride, + groups=1, norm_type='bn', norm_decay=0., norm_groups=32, @@ -142,7 +143,7 @@ class ConvNormLayer(nn.Layer): kernel_size=filter_size, stride=stride, padding=(filter_size - 1) // 2, - groups=1, + groups=groups, weight_attr=ParamAttr( name=name + "_weight", initializer=initializer, @@ -158,7 +159,7 @@ class ConvNormLayer(nn.Layer): kernel_size=filter_size, stride=stride, padding=(filter_size - 1) // 2, - groups=1, + groups=groups, weight_attr=ParamAttr( name=name + "_weight", initializer=initializer, @@ -197,6 +198,71 @@ class ConvNormLayer(nn.Layer): return out +class LiteConv(nn.Layer): + def __init__(self, + in_channels, + out_channels, + stride=1, + with_act=True, + norm_type='sync_bn', + name=None): + super(LiteConv, self).__init__() + self.lite_conv = nn.Sequential() + conv1 = ConvNormLayer( + in_channels, + in_channels, + filter_size=5, + stride=stride, + groups=in_channels, + norm_type=norm_type, + initializer=XavierUniform(), + norm_name=name + '.conv1.norm', + name=name + '.conv1') + conv2 = ConvNormLayer( + in_channels, + out_channels, + filter_size=1, + stride=stride, + norm_type=norm_type, + initializer=XavierUniform(), + norm_name=name + '.conv2.norm', + name=name + '.conv2') + conv3 = ConvNormLayer( + out_channels, + out_channels, + filter_size=1, + stride=stride, + norm_type=norm_type, + initializer=XavierUniform(), + norm_name=name + '.conv3.norm', + name=name + '.conv3') + conv4 = ConvNormLayer( + out_channels, + out_channels, + filter_size=5, + stride=stride, + groups=out_channels, + norm_type=norm_type, + initializer=XavierUniform(), + norm_name=name + '.conv4.norm', + name=name + '.conv4') + conv_list = [conv1, conv2, conv3, conv4] + self.lite_conv.add_sublayer('conv1', conv1) + self.lite_conv.add_sublayer('relu6_1', nn.ReLU6()) + self.lite_conv.add_sublayer('conv2', conv2) + if with_act: + self.lite_conv.add_sublayer('relu6_2', nn.ReLU6()) + self.lite_conv.add_sublayer('conv3', conv3) + self.lite_conv.add_sublayer('relu6_3', nn.ReLU6()) + self.lite_conv.add_sublayer('conv4', conv4) + if with_act: + self.lite_conv.add_sublayer('relu6_4', nn.ReLU6()) + + def forward(self, inputs): + out = self.lite_conv(inputs) + return out + + @register @serializable class AnchorGeneratorRPN(object): diff --git a/ppdet/modeling/necks/ttf_fpn.py b/ppdet/modeling/necks/ttf_fpn.py index 1ae67bb8e..395bc1dc9 100644 --- a/ppdet/modeling/necks/ttf_fpn.py +++ b/ppdet/modeling/necks/ttf_fpn.py @@ -16,11 +16,11 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle import ParamAttr -from paddle.nn.initializer import Constant, Uniform, Normal +from paddle.nn.initializer import Constant, Uniform, Normal, XavierUniform from paddle import ParamAttr from ppdet.core.workspace import register, serializable from paddle.regularizer import L2Decay -from ppdet.modeling.layers import DeformableConvV2 +from ppdet.modeling.layers import DeformableConvV2, ConvNormLayer, LiteConv import math from ppdet.modeling.ops import batch_norm from ..shape_spec import ShapeSpec @@ -29,7 +29,7 @@ __all__ = ['TTFFPN'] class Upsample(nn.Layer): - def __init__(self, ch_in, ch_out, name=None): + def __init__(self, ch_in, ch_out, norm_type='bn', name=None): super(Upsample, self).__init__() fan_in = ch_in * 3 * 3 stdv = 1. / math.sqrt(fan_in) @@ -46,7 +46,7 @@ class Upsample(nn.Layer): regularizer=L2Decay(0.)) self.bn = batch_norm( - ch_out, norm_type='bn', initializer=Constant(1.), name=name) + ch_out, norm_type=norm_type, initializer=Constant(1.), name=name) def forward(self, feat): dcn = self.dcn(feat) @@ -56,28 +56,105 @@ class Upsample(nn.Layer): return out +class DeConv(nn.Layer): + def __init__(self, ch_in, ch_out, norm_type='bn', name=None): + super(DeConv, self).__init__() + self.deconv = nn.Sequential() + conv1 = ConvNormLayer( + ch_in=ch_in, + ch_out=ch_out, + stride=1, + filter_size=1, + norm_type=norm_type, + initializer=XavierUniform(), + norm_name=name + '.conv1.norm', + name=name + '.conv1') + conv2 = nn.Conv2DTranspose( + in_channels=ch_out, + out_channels=ch_out, + kernel_size=4, + padding=1, + stride=2, + groups=ch_out, + weight_attr=ParamAttr(initializer=XavierUniform()), + bias_attr=False) + bn = batch_norm( + ch_out, norm_type=norm_type, norm_decay=0., name=name + '.bn') + conv3 = ConvNormLayer( + ch_in=ch_out, + ch_out=ch_out, + stride=1, + filter_size=1, + norm_type=norm_type, + initializer=XavierUniform(), + norm_name=name + '.conv3.norm', + name=name + '.conv3') + + self.deconv.add_sublayer('conv1', conv1) + self.deconv.add_sublayer('relu6_1', nn.ReLU6()) + self.deconv.add_sublayer('conv2', conv2) + self.deconv.add_sublayer('bn', bn) + self.deconv.add_sublayer('relu6_2', nn.ReLU6()) + self.deconv.add_sublayer('conv3', conv3) + self.deconv.add_sublayer('relu6_3', nn.ReLU6()) + + def forward(self, inputs): + return self.deconv(inputs) + + +class LiteUpsample(nn.Layer): + def __init__(self, ch_in, ch_out, norm_type='bn', name=None): + super(LiteUpsample, self).__init__() + self.deconv = DeConv( + ch_in, ch_out, norm_type=norm_type, name=name + '.deconv') + self.conv = LiteConv( + ch_in, ch_out, norm_type=norm_type, name=name + '.liteconv') + + def forward(self, inputs): + deconv_up = self.deconv(inputs) + conv = self.conv(inputs) + interp_up = F.interpolate(conv, scale_factor=2., mode='bilinear') + return deconv_up + interp_up + + class ShortCut(nn.Layer): - def __init__(self, layer_num, ch_out, name=None): + def __init__(self, + layer_num, + ch_in, + ch_out, + norm_type='bn', + lite_neck=False, + name=None): super(ShortCut, self).__init__() shortcut_conv = nn.Sequential() - ch_in = ch_out * 2 for i in range(layer_num): fan_out = 3 * 3 * ch_out std = math.sqrt(2. / fan_out) in_channels = ch_in if i == 0 else ch_out shortcut_name = name + '.conv.{}'.format(i) - shortcut_conv.add_sublayer( - shortcut_name, - nn.Conv2D( - in_channels=in_channels, - out_channels=ch_out, - kernel_size=3, - padding=1, - weight_attr=ParamAttr(initializer=Normal(0, std)), - bias_attr=ParamAttr( - learning_rate=2., regularizer=L2Decay(0.)))) - if i < layer_num - 1: - shortcut_conv.add_sublayer(shortcut_name + '.act', nn.ReLU()) + if lite_neck: + shortcut_conv.add_sublayer( + shortcut_name, + LiteConv( + in_channels=in_channels, + out_channels=ch_out, + with_act=i < layer_num - 1, + norm_type=norm_type, + name=shortcut_name)) + else: + shortcut_conv.add_sublayer( + shortcut_name, + nn.Conv2D( + in_channels=in_channels, + out_channels=ch_out, + kernel_size=3, + padding=1, + weight_attr=ParamAttr(initializer=Normal(0, std)), + bias_attr=ParamAttr( + learning_rate=2., regularizer=L2Decay(0.)))) + if i < layer_num - 1: + shortcut_conv.add_sublayer(shortcut_name + '.act', + nn.ReLU()) self.shortcut = self.add_sublayer('short', shortcut_conv) def forward(self, feat): @@ -93,35 +170,68 @@ class TTFFPN(nn.Layer): in_channels (list): number of input feature channels from backbone. [128,256,512,1024] by default, means the channels of DarkNet53 backbone return_idx [1,2,3,4]. + planes (list): the number of output feature channels of FPN. + [256, 128, 64] by default shortcut_num (list): the number of convolution layers in each shortcut. [3,2,1] by default, means DarkNet53 backbone return_idx_1 has 3 convs in its shortcut, return_idx_2 has 2 convs and return_idx_3 has 1 conv. + norm_type (string): norm type, 'sync_bn', 'bn', 'gn' are optional. + bn by default + lite_neck (bool): whether to use lite conv in TTFNet FPN, + False by default + fusion_method (string): the method to fusion upsample and lateral layer. + 'add' and 'concat' are optional, add by default """ + __shared__ = ['norm_type'] + def __init__(self, - in_channels=[128, 256, 512, 1024], - shortcut_num=[3, 2, 1]): + in_channels, + planes=[256, 128, 64], + shortcut_num=[3, 2, 1], + norm_type='bn', + lite_neck=False, + fusion_method='add'): super(TTFFPN, self).__init__() - self.planes = [c // 2 for c in in_channels[:-1]][::-1] + self.planes = planes self.shortcut_num = shortcut_num[::-1] self.shortcut_len = len(shortcut_num) self.ch_in = in_channels[::-1] + self.fusion_method = fusion_method self.upsample_list = [] self.shortcut_list = [] + self.upper_list = [] for i, out_c in enumerate(self.planes): - in_c = self.ch_in[i] if i == 0 else self.ch_in[i] // 2 + in_c = self.ch_in[i] if i == 0 else self.upper_list[-1] + upsample_module = LiteUpsample if lite_neck else Upsample upsample = self.add_sublayer( 'upsample.' + str(i), - Upsample( - in_c, out_c, name='upsample.' + str(i))) + upsample_module( + in_c, + out_c, + norm_type=norm_type, + name='deconv_layers.' + str(i))) self.upsample_list.append(upsample) if i < self.shortcut_len: shortcut = self.add_sublayer( 'shortcut.' + str(i), ShortCut( - self.shortcut_num[i], out_c, name='shortcut.' + str(i))) + self.shortcut_num[i], + self.ch_in[i + 1], + out_c, + norm_type=norm_type, + lite_neck=lite_neck, + name='shortcut.' + str(i))) self.shortcut_list.append(shortcut) + if self.fusion_method == 'add': + upper_c = out_c + elif self.fusion_method == 'concat': + upper_c = out_c * 2 + else: + raise ValueError('Illegal fusion method. Expected add or\ + concat, but received {}'.format(self.fusion_method)) + self.upper_list.append(upper_c) def forward(self, inputs): feat = inputs[-1] @@ -129,7 +239,10 @@ class TTFFPN(nn.Layer): feat = self.upsample_list[i](feat) if i < self.shortcut_len: shortcut = self.shortcut_list[i](inputs[-i - 2]) - feat = feat + shortcut + if self.fusion_method == 'add': + feat = feat + shortcut + else: + feat = paddle.concat([feat, shortcut], axis=1) return feat @classmethod @@ -138,4 +251,4 @@ class TTFFPN(nn.Layer): @property def out_shape(self): - return [ShapeSpec(channels=self.planes[-1], )] + return [ShapeSpec(channels=self.upper_list[-1], )] diff --git a/static/configs/anchor_free/pafnet_10x_coco.yml b/static/configs/anchor_free/pafnet_10x_coco.yml new file mode 100644 index 000000000..4c6728bcd --- /dev/null +++ b/static/configs/anchor_free/pafnet_10x_coco.yml @@ -0,0 +1,170 @@ +architecture: TTFNet +use_gpu: true +max_iters: 150000 +log_smooth_window: 20 +save_dir: output +snapshot_iter: 10000 +metric: COCO +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_ssld_pretrained.tar +weights: output/pafnet_10x_coco/model_final +num_classes: 80 +use_ema: true +ema_decay: 0.9998 + +TTFNet: + backbone: ResNet + ttf_head: TTFHead + +ResNet: + norm_type: sync_bn + freeze_at: 0 + freeze_norm: false + norm_decay: 0. + depth: 50 + feature_maps: [2, 3, 4, 5] + variant: d + dcn_v2_stages: [3, 4, 5] + +TTFHead: + head_conv: 128 + wh_conv: 64 + hm_head_conv_num: 2 + wh_head_conv_num: 2 + wh_offset_base: 16 + wh_loss: GiouLoss + dcn_head: True + +GiouLoss: + loss_weight: 5. + do_average: false + use_class_weight: false + +LearningRate: + base_lr: 0.015 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 112500 + - 137500 + - !LinearWarmup + start_factor: 0.2 + steps: 500 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0004 + type: L2 + +TrainReader: + inputs_def: + fields: ['image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight'] + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_cutmix: True + - !CutmixImage + alpha: 1.5 + beta: 1.5 + - !ColorDistort + hue: [-18., 18., 0.5] + saturation: [0.5, 1.5, 0.5] + contrast: [0.5, 1.5, 0.5] + brightness: [-32., 32., 0.5] + random_apply: False + hsv_format: True + random_channel: True + - !RandomExpand + ratio: 4 + prob: 0.5 + fill_value: [123.675, 116.28, 103.53] + - !RandomCrop + aspect_ratio: NULL + cover_all_box: True + - !RandomFlipImage + prob: 0.5 + batch_transforms: + - !RandomShape + sizes: [416, 448, 480, 512, 544, 576, 608, 640, 672] + random_inter: True + resize_box: True + - !NormalizeImage + is_channel_first: false + is_scale: false + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + - !Permute + to_bgr: false + channel_first: true + - !Gt2TTFTarget + num_classes: 80 + down_ratio: 4 + - !PadBatch + pad_to_stride: 32 + batch_size: 12 + shuffle: true + worker_num: 8 + bufsize: 2 + use_process: false + cutmix_epoch: 100 + +EvalReader: + inputs_def: + image_shape: [3, 512, 512] + fields: ['image', 'im_id', 'scale_factor'] + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !Resize + target_dim: 512 + - !NormalizeImage + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + is_scale: false + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 + drop_empty: false + worker_num: 8 + bufsize: 16 + +TestReader: + inputs_def: + image_shape: [3, 512, 512] + fields: ['image', 'im_id', 'scale_factor'] + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !Resize + interp: 1 + target_dim: 512 + - !NormalizeImage + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + is_scale: false + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 diff --git a/static/configs/anchor_free/pafnet_lite_mobilenet_v3_20x_coco.yml b/static/configs/anchor_free/pafnet_lite_mobilenet_v3_20x_coco.yml new file mode 100644 index 000000000..1b1423883 --- /dev/null +++ b/static/configs/anchor_free/pafnet_lite_mobilenet_v3_20x_coco.yml @@ -0,0 +1,171 @@ +architecture: TTFNet +use_gpu: true +max_iters: 300000 +log_smooth_window: 20 +save_dir: output +snapshot_iter: 50000 +metric: COCO +pretrain_weights: https://paddle-imagenet-models-name.bj.bcebos.com/MobileNetV3_large_x1_0_ssld_pretrained.tar +weights: output/pafnet_lite_mobilenet_v3_20x_coco/model_final +num_classes: 80 + +TTFNet: + backbone: MobileNetV3RCNN + ttf_head: TTFLiteHead + +MobileNetV3RCNN: + norm_type: sync_bn + norm_decay: 0.0 + model_name: large + scale: 1.0 + conv_decay: 0.00001 + lr_mult_list: [0.25, 0.25, 0.5, 0.5, 0.75] + freeze_norm: false + +TTFLiteHead: + head_conv: 48 + +GiouLoss: + loss_weight: 5. + do_average: false + use_class_weight: false + +LearningRate: + base_lr: 0.015 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 225000 + - 275000 + - !LinearWarmup + start_factor: 0.2 + steps: 1000 + +OptimizerBuilder: + clip_grad_by_norm: 35 + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0004 + type: L2 + +TrainReader: + inputs_def: + fields: ['image', 'ttf_heatmap', 'ttf_box_target', 'ttf_reg_weight'] + dataset: + !COCODataSet + image_dir: train2017 + anno_path: annotations/instances_train2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: true + with_cutmix: True + - !ColorDistort + hue: [-18., 18., 0.5] + saturation: [0.5, 1.5, 0.5] + contrast: [0.5, 1.5, 0.5] + brightness: [-32., 32., 0.5] + random_apply: False + hsv_format: False + random_channel: True + - !RandomExpand + ratio: 4 + prob: 0.5 + fill_value: [123.675, 116.28, 103.53] + - !RandomCrop + aspect_ratio: NULL + cover_all_box: True + - !CutmixImage + alpha: 1.5 + beta: 1.5 + - !RandomFlipImage + prob: 0.5 + - !GridMaskOp + use_h: true + use_w: true + rotate: 1 + offset: false + ratio: 0.5 + mode: 1 + prob: 0.7 + upper_iter: 300000 + batch_transforms: + - !RandomShape + sizes: [320, 352, 384, 416, 448, 480, 512] + random_inter: True + resize_box: True + - !NormalizeImage + is_channel_first: false + is_scale: false + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + - !Permute + to_bgr: false + channel_first: true + - !Gt2TTFTarget + num_classes: 80 + down_ratio: 4 + - !PadBatch + pad_to_stride: 32 + batch_size: 12 + shuffle: true + worker_num: 8 + bufsize: 2 + use_process: false + cutmix_epoch: 200 + +EvalReader: + inputs_def: + image_shape: [3, 320, 320] + fields: ['image', 'im_id', 'scale_factor'] + dataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !Resize + target_dim: 320 + - !NormalizeImage + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + is_scale: false + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 + drop_empty: false + worker_num: 2 + bufsize: 2 + +TestReader: + inputs_def: + image_shape: [3, 320, 320] + fields: ['image', 'im_id', 'scale_factor'] + dataset: + !ImageFolder + anno_path: annotations/instances_val2017.json + with_background: false + sample_transforms: + - !DecodeImage + to_rgb: True + - !Resize + interp: 1 + target_dim: 320 + - !NormalizeImage + mean: [123.675, 116.28, 103.53] + std: [58.395, 57.12, 57.375] + is_scale: false + is_channel_first: false + - !Permute + to_bgr: false + channel_first: True + batch_size: 1 diff --git a/static/ppdet/modeling/anchor_heads/ttf_head.py b/static/ppdet/modeling/anchor_heads/ttf_head.py index ba9ec802e..31add344d 100644 --- a/static/ppdet/modeling/anchor_heads/ttf_head.py +++ b/static/ppdet/modeling/anchor_heads/ttf_head.py @@ -24,10 +24,10 @@ from paddle.fluid.param_attr import ParamAttr from paddle.fluid.initializer import Normal, Constant, Uniform, Xavier from paddle.fluid.regularizer import L2Decay from ppdet.core.workspace import register -from ppdet.modeling.ops import DeformConv, DropBlock +from ppdet.modeling.ops import DeformConv, DropBlock, ConvNorm from ppdet.modeling.losses import GiouLoss -__all__ = ['TTFHead'] +__all__ = ['TTFHead', 'TTFLiteHead'] @register @@ -65,6 +65,8 @@ class TTFHead(object): drop_block(bool): whether use dropblock. False by default. block_size(int): block_size parameter for drop_block. 3 by default. keep_prob(float): keep_prob parameter for drop_block. 0.9 by default. + fusion_method (string): Method to fusion upsample and lateral branch. + 'add' and 'concat' are optional, add by default """ __inject__ = ['wh_loss'] @@ -90,7 +92,8 @@ class TTFHead(object): dcn_head=False, drop_block=False, block_size=3, - keep_prob=0.9): + keep_prob=0.9, + fusion_method='add'): super(TTFHead, self).__init__() self.head_conv = head_conv self.num_classes = num_classes @@ -115,6 +118,7 @@ class TTFHead(object): self.drop_block = drop_block self.block_size = block_size self.keep_prob = keep_prob + self.fusion_method = fusion_method def shortcut(self, x, out_c, layer_num, kernel_size=3, padding=1, name=None): @@ -255,7 +259,14 @@ class TTFHead(object): out_c, self.shortcut_num[i], name=name + '.shortcut_layers.' + str(i)) - feat = fluid.layers.elementwise_add(feat, shortcut) + if self.fusion_method == 'add': + feat = fluid.layers.elementwise_add(feat, shortcut) + elif self.fusion_method == 'concat': + feat = fluid.layers.concat([feat, shortcut], axis=1) + else: + raise ValueError( + "Illegal fusion method, expected 'add' or 'concat', but received {}". + format(self.fusion_method)) hm = self.hm_head(feat, name=name + '.hm', is_test=is_test) wh = self.wh_head(feat, name=name + '.wh') * self.wh_offset_base @@ -273,12 +284,13 @@ class TTFHead(object): # batch size is 1 scores_r = fluid.layers.reshape(scores, [cat, -1]) topk_scores, topk_inds = fluid.layers.topk(scores_r, k) - topk_ys = topk_inds / width + topk_ys = topk_inds // width topk_xs = topk_inds % width topk_score_r = fluid.layers.reshape(topk_scores, [-1]) topk_score, topk_ind = fluid.layers.topk(topk_score_r, k) - topk_clses = fluid.layers.cast(topk_ind / k, 'float32') + k_t = fluid.layers.assign(np.array([k], dtype='int64')) + topk_clses = fluid.layers.cast(topk_ind / k_t, 'float32') topk_inds = fluid.layers.reshape(topk_inds, [-1]) topk_ys = fluid.layers.reshape(topk_ys, [-1, 1]) @@ -384,3 +396,172 @@ class TTFHead(object): ttf_loss = {'hm_loss': hm_loss, 'wh_loss': wh_loss} return ttf_loss + + +@register +class TTFLiteHead(TTFHead): + """ + TTFLiteHead + + Lite version for TTFNet + Args: + head_conv(int): the default channel number of convolution in head. + 32 by default. + num_classes(int): the number of classes, 80 by default. + planes(tuple): the channel number of convolution in each upsample. + (96, 48, 24) by default. + wh_conv(int): the channel number of convolution in wh head. + 24 by default. + wh_loss(object): `GiouLoss` instance. + shortcut_num(tuple): the number of convolution layers in each shortcut. + (1, 2, 2) by default. + fusion_method (string): Method to fusion upsample and lateral branch. + 'add' and 'concat' are optional, add by default + """ + __inject__ = ['wh_loss'] + __shared__ = ['num_classes'] + + def __init__(self, + head_conv=32, + num_classes=80, + planes=(96, 48, 24), + wh_conv=24, + wh_loss='GiouLoss', + shortcut_num=(1, 2, 2), + fusion_method='concat'): + super(TTFLiteHead, self).__init__( + head_conv=head_conv, + num_classes=num_classes, + planes=planes, + wh_conv=wh_conv, + wh_loss=wh_loss, + shortcut_num=shortcut_num, + fusion_method=fusion_method) + + def _lite_conv(self, x, out_c, act=None, name=None): + conv1 = ConvNorm( + input=x, + num_filters=x.shape[1], + filter_size=5, + groups=x.shape[1], + norm_type='bn', + act='relu6', + initializer=Xavier(), + name=name + '.depthwise', + norm_name=name + '.depthwise.bn') + + conv2 = ConvNorm( + input=conv1, + num_filters=out_c, + filter_size=1, + norm_type='bn', + act=act, + initializer=Xavier(), + name=name + '.pointwise_linear', + norm_name=name + '.pointwise_linear.bn') + + conv3 = ConvNorm( + input=conv2, + num_filters=out_c, + filter_size=1, + norm_type='bn', + act='relu6', + initializer=Xavier(), + name=name + '.pointwise', + norm_name=name + '.pointwise.bn') + + conv4 = ConvNorm( + input=conv3, + num_filters=out_c, + filter_size=5, + groups=out_c, + norm_type='bn', + act=act, + initializer=Xavier(), + name=name + '.depthwise_linear', + norm_name=name + '.depthwise_linear.bn') + + return conv4 + + def shortcut(self, x, out_c, layer_num, name=None): + assert layer_num > 0 + for i in range(layer_num): + param_name = name + '.layers.' + str(i * 2) + act = 'relu6' if i < layer_num - 1 else None + x = self._lite_conv(x, out_c, act, param_name) + return x + + def _deconv_upsample(self, x, out_c, name=None): + conv1 = ConvNorm( + input=x, + num_filters=out_c, + filter_size=1, + norm_type='bn', + act='relu6', + name=name + '.pointwise', + initializer=Xavier(), + norm_name=name + '.pointwise.bn') + conv2 = fluid.layers.conv2d_transpose( + input=conv1, + num_filters=out_c, + filter_size=4, + padding=1, + stride=2, + groups=out_c, + param_attr=ParamAttr( + name=name + '.deconv.weights', initializer=Xavier()), + bias_attr=False) + bn = fluid.layers.batch_norm( + input=conv2, + act='relu6', + param_attr=ParamAttr( + name=name + '.deconv.bn.scale', regularizer=L2Decay(0.)), + bias_attr=ParamAttr( + name=name + '.deconv.bn.offset', regularizer=L2Decay(0.)), + moving_mean_name=name + '.deconv.bn.mean', + moving_variance_name=name + '.deconv.bn.variance') + conv3 = ConvNorm( + input=bn, + num_filters=out_c, + filter_size=1, + norm_type='bn', + act='relu6', + name=name + '.normal', + initializer=Xavier(), + norm_name=name + '.normal.bn') + return conv3 + + def _interp_upsample(self, x, out_c, name=None): + conv = self._lite_conv(x, out_c, 'relu6', name) + up = fluid.layers.resize_bilinear(conv, scale=2) + return up + + def upsample(self, x, out_c, name=None): + deconv_up = self._deconv_upsample(x, out_c, name=name + '.dilation_up') + interp_up = self._interp_upsample(x, out_c, name=name + '.interp_up') + return deconv_up + interp_up + + def _head(self, + x, + out_c, + conv_num=1, + head_out_c=None, + name=None, + is_test=False): + head_out_c = self.head_conv if not head_out_c else head_out_c + for i in range(conv_num): + conv_name = '{}.{}.conv'.format(name, i) + x = self._lite_conv(x, head_out_c, 'relu6', conv_name) + bias_init = float(-np.log((1 - 0.01) / 0.01)) if '.hm' in name else 0. + conv_b_init = Constant(bias_init) + x = fluid.layers.conv2d( + x, + out_c, + 1, + param_attr=ParamAttr(name='{}.{}.weight'.format(name, conv_num)), + bias_attr=ParamAttr( + learning_rate=2., + regularizer=L2Decay(0.), + name='{}.{}.bias'.format(name, conv_num), + initializer=conv_b_init)) + return x -- GitLab