From b06d7f17ea48b8f318a560447ae9083b42030f75 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Thu, 8 Apr 2021 10:26:45 +0800 Subject: [PATCH] update slim model (#2427) * update slim model * fix skip quant --- configs/slim/README.md | 6 ++-- .../slim/quant/mask_rcnn_r50_fpn_1x_qat.yml | 22 +++++++++++++ configs/slim/quant/ssd_mobilenet_v1_qat.yml | 9 ++++++ configs/slim/quant/yolov3_darknet_qat.yml | 31 +++++++++++++++++++ .../slim/quant/yolov3_mobilenet_v1_qat.yml | 1 - .../slim/quant/yolov3_mobilenet_v3_qat.yml | 14 ++++++++- ppdet/modeling/heads/bbox_head.py | 4 +++ ppdet/modeling/heads/mask_head.py | 20 ++++++------ ppdet/modeling/heads/yolo_head.py | 20 ++++++------ ppdet/modeling/layers.py | 7 +++++ ppdet/modeling/proposal_generator/rpn_head.py | 3 ++ 11 files changed, 113 insertions(+), 24 deletions(-) create mode 100644 configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml create mode 100644 configs/slim/quant/ssd_mobilenet_v1_qat.yml create mode 100644 configs/slim/quant/yolov3_darknet_qat.yml diff --git a/configs/slim/README.md b/configs/slim/README.md index 67ba71585..c314f52b6 100755 --- a/configs/slim/README.md +++ b/configs/slim/README.md @@ -87,13 +87,13 @@ python tools/export_model.py -c configs/{MODEL.yml} --slim_config configs/slim/{ | 模型 | 压缩策略 | 输入尺寸 | Box AP | 下载 | 模型配置文件 | 压缩算法配置文件 | | ------------------ | ------------ | -------- | :---------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | | YOLOv3-MobileNetV1 | baseline | 608 | 28.8 | [下载链接](https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) | - | -| YOLOv3-MobileNetV1 | 普通在线量化 | 608 | 30.3 (+1.5) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v1_qat.yml) | +| YOLOv3-MobileNetV1 | 普通在线量化 | 608 | 30.5 (+1.7) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v1_qat.yml) | | YOLOv3-MobileNetV3 | baseline | 608 | 31.4 | [下载链接](https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v3_large_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v3_large_270e_coco.yml) | - | -| YOLOv3-MobileNetV3 | PACT在线量化 | 608 | 29.5 (-1.9) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v3_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v3_large_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v3_qat.yml) | +| YOLOv3-MobileNetV3 | PACT在线量化 | 608 | 29.1 (-2.3) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v3_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v3_large_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v3_qat.yml) | | YOLOv3-DarkNet53 | baseline | 608 | 39.0 | [下载链接](https://paddledet.bj.bcebos.com/models/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | - | | YOLOv3-DarkNet53 | 普通在线量化 | 608 | 38.7 (-0.3) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_darknet_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_darknet_qat.yml) | | SSD-MobileNet_v1 | baseline | 300 | 73.8 | [下载链接](https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml) | - | -| SSD-MobileNet_v1 | 普通在线量化 | 300 | 73.1(-0.7) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/ssd_mobilenet_v1_300_voc_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/ssd_mobilenet_v1_qat.yml) | +| SSD-MobileNet_v1 | 普通在线量化 | 300 | 72.9(-0.9) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/ssd_mobilenet_v1_300_voc_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/ssd_mobilenet_v1_qat.yml) | | Mask-ResNet50-FPN | baseline | (800, 1333) | 39.2/35.6 | [下载链接](https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.yml) | - | | Mask-ResNet50-FPN | 普通在线量化 | (800, 1333) | 39.7(+0.5)/35.9(+0.3) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/mask_rcnn_r50_fpn_1x_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml) | diff --git a/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml b/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml new file mode 100644 index 000000000..7363b4e55 --- /dev/null +++ b/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml @@ -0,0 +1,22 @@ +pretrain_weights: https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_fpn_1x_coco.pdparams +slim: QAT + +QAT: + quant_config: { + 'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max', + 'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9, + 'quantizable_layer_type': ['Conv2D', 'Linear']} + print_model: True + + +epoch: 5 + +LearningRate: + base_lr: 0.001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [3, 4] + - !LinearWarmup + start_factor: 0.001 + steps: 100 diff --git a/configs/slim/quant/ssd_mobilenet_v1_qat.yml b/configs/slim/quant/ssd_mobilenet_v1_qat.yml new file mode 100644 index 000000000..05e068368 --- /dev/null +++ b/configs/slim/quant/ssd_mobilenet_v1_qat.yml @@ -0,0 +1,9 @@ +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_mobilenet_v1_300_120e_voc.pdparams +slim: QAT + +QAT: + quant_config: { + 'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max', + 'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9, + 'quantizable_layer_type': ['Conv2D', 'Linear']} + print_model: True diff --git a/configs/slim/quant/yolov3_darknet_qat.yml b/configs/slim/quant/yolov3_darknet_qat.yml new file mode 100644 index 000000000..281b53418 --- /dev/null +++ b/configs/slim/quant/yolov3_darknet_qat.yml @@ -0,0 +1,31 @@ +pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_darknet53_270e_coco.pdparams +slim: QAT + +QAT: + quant_config: { + 'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max', + 'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9, + 'quantizable_layer_type': ['Conv2D', 'Linear']} + print_model: True + +epoch: 50 + +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 30 + - 45 + - !LinearWarmup + start_factor: 0. + steps: 1000 + +OptimizerBuilder: + optimizer: + momentum: 0.9 + type: Momentum + regularizer: + factor: 0.0005 + type: L2 diff --git a/configs/slim/quant/yolov3_mobilenet_v1_qat.yml b/configs/slim/quant/yolov3_mobilenet_v1_qat.yml index dfa365c10..d14520829 100644 --- a/configs/slim/quant/yolov3_mobilenet_v1_qat.yml +++ b/configs/slim/quant/yolov3_mobilenet_v1_qat.yml @@ -1,6 +1,5 @@ # Weights of yolov3_mobilenet_v1_coco pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams -weight_type: resume slim: QAT QAT: diff --git a/configs/slim/quant/yolov3_mobilenet_v3_qat.yml b/configs/slim/quant/yolov3_mobilenet_v3_qat.yml index 288e72a10..812690908 100644 --- a/configs/slim/quant/yolov3_mobilenet_v3_qat.yml +++ b/configs/slim/quant/yolov3_mobilenet_v3_qat.yml @@ -1,6 +1,5 @@ # Weights of yolov3_mobilenet_v3_coco pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v3_large_270e_coco.pdparams -weight_type: resume slim: QAT QAT: @@ -10,3 +9,16 @@ QAT: 'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9, 'quantizable_layer_type': ['Conv2D', 'Linear']} print_model: True + +epoch: 30 +LearningRate: + base_lr: 0.0001 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: + - 25 + - 28 + - !LinearWarmup + start_factor: 0. + steps: 2000 diff --git a/ppdet/modeling/heads/bbox_head.py b/ppdet/modeling/heads/bbox_head.py index 0c75f8f1d..58614c578 100644 --- a/ppdet/modeling/heads/bbox_head.py +++ b/ppdet/modeling/heads/bbox_head.py @@ -50,11 +50,13 @@ class TwoFCHead(nn.Layer): out_channel, weight_attr=paddle.ParamAttr( initializer=XavierUniform(fan_out=fan))) + self.fc6.skip_quant = True self.fc7 = nn.Linear( out_channel, out_channel, weight_attr=paddle.ParamAttr(initializer=XavierUniform())) + self.fc7.skip_quant = True @classmethod def from_config(cls, cfg, input_shape): @@ -199,12 +201,14 @@ class BBoxHead(nn.Layer): self.num_classes + 1, weight_attr=paddle.ParamAttr(initializer=Normal( mean=0.0, std=0.01))) + self.bbox_score.skip_quant = True self.bbox_delta = nn.Linear( in_channel, 4 * self.num_classes, weight_attr=paddle.ParamAttr(initializer=Normal( mean=0.0, std=0.001))) + self.bbox_delta.skip_quant = True self.assigned_label = None self.assigned_rois = None diff --git a/ppdet/modeling/heads/mask_head.py b/ppdet/modeling/heads/mask_head.py index eea70922a..2b2cbc7ff 100644 --- a/ppdet/modeling/heads/mask_head.py +++ b/ppdet/modeling/heads/mask_head.py @@ -65,20 +65,21 @@ class MaskFeat(nn.Layer): norm_type=self.norm_type, norm_name=conv_name + '_norm', initializer=KaimingNormal(fan_in=fan_conv), + skip_quant=True, name=conv_name)) mask_conv.add_sublayer(conv_name + 'act', nn.ReLU()) else: for i in range(self.num_convs): conv_name = 'mask_inter_feat_{}'.format(i + 1) - mask_conv.add_sublayer( - conv_name, - nn.Conv2D( - in_channels=in_channel if i == 0 else out_channel, - out_channels=out_channel, - kernel_size=3, - padding=1, - weight_attr=paddle.ParamAttr( - initializer=KaimingNormal(fan_in=fan_conv)))) + conv = nn.Conv2D( + in_channels=in_channel if i == 0 else out_channel, + out_channels=out_channel, + kernel_size=3, + padding=1, + weight_attr=paddle.ParamAttr( + initializer=KaimingNormal(fan_in=fan_conv))) + conv.skip_quant = True + mask_conv.add_sublayer(conv_name, conv) mask_conv.add_sublayer(conv_name + 'act', nn.ReLU()) mask_conv.add_sublayer( 'conv5_mask', @@ -146,6 +147,7 @@ class MaskHead(nn.Layer): kernel_size=1, weight_attr=paddle.ParamAttr(initializer=KaimingNormal( fan_in=self.num_classes))) + self.mask_fcn_logits.skip_quant = True @classmethod def from_config(cls, cfg, input_shape): diff --git a/ppdet/modeling/heads/yolo_head.py b/ppdet/modeling/heads/yolo_head.py index 3516da410..e2881f489 100644 --- a/ppdet/modeling/heads/yolo_head.py +++ b/ppdet/modeling/heads/yolo_head.py @@ -47,16 +47,16 @@ class YOLOv3Head(nn.Layer): else: num_filters = len(self.anchors[i]) * (self.num_classes + 5) name = 'yolo_output.{}'.format(i) - yolo_output = self.add_sublayer( - name, - nn.Conv2D( - in_channels=128 * (2**self.num_outputs) // (2**i), - out_channels=num_filters, - kernel_size=1, - stride=1, - padding=0, - data_format=data_format, - bias_attr=ParamAttr(regularizer=L2Decay(0.)))) + conv = nn.Conv2D( + in_channels=128 * (2**self.num_outputs) // (2**i), + out_channels=num_filters, + kernel_size=1, + stride=1, + padding=0, + data_format=data_format, + bias_attr=ParamAttr(regularizer=L2Decay(0.))) + conv.skip_quant = True + yolo_output = self.add_sublayer(name, conv) self.yolo_outputs.append(yolo_output) def parse_anchor(self, anchors, anchor_masks): diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 002ef63ca..df7e9fab3 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -52,6 +52,7 @@ class DeformableConvV2(nn.Layer): bias_attr=None, lr_scale=1, regularizer=None, + skip_quant=False, name=None): super(DeformableConvV2, self).__init__() self.offset_channel = 2 * kernel_size**2 @@ -77,6 +78,8 @@ class DeformableConvV2(nn.Layer): initializer=Constant(0.0), name='{}._conv_offset.weight'.format(name)), bias_attr=offset_bias_attr) + if skip_quant: + self.conv_offset.skip_quant = True if bias_attr: # in FCOS-DCN head, specifically need learning_rate and regularizer @@ -126,6 +129,7 @@ class ConvNormLayer(nn.Layer): freeze_norm=False, initializer=Normal( mean=0., std=0.01), + skip_quant=False, name=None): super(ConvNormLayer, self).__init__() assert norm_type in ['bn', 'sync_bn', 'gn'] @@ -151,6 +155,8 @@ class ConvNormLayer(nn.Layer): initializer=initializer, learning_rate=1.), bias_attr=bias_attr) + if skip_quant: + self.conv.skip_quant = True else: # in FCOS-DCN head, specifically need learning_rate and regularizer self.conv = DeformableConvV2( @@ -167,6 +173,7 @@ class ConvNormLayer(nn.Layer): bias_attr=True, lr_scale=2., regularizer=L2Decay(norm_decay), + skip_quant=skip_quant, name=name) norm_lr = 0. if freeze_norm else 1. diff --git a/ppdet/modeling/proposal_generator/rpn_head.py b/ppdet/modeling/proposal_generator/rpn_head.py index 2b1e6c77b..ea9fb851c 100644 --- a/ppdet/modeling/proposal_generator/rpn_head.py +++ b/ppdet/modeling/proposal_generator/rpn_head.py @@ -45,6 +45,7 @@ class RPNFeat(nn.Layer): padding=1, weight_attr=paddle.ParamAttr(initializer=Normal( mean=0., std=0.01))) + self.rpn_conv.skip_quant = True def forward(self, feats): rpn_feats = [] @@ -100,6 +101,7 @@ class RPNHead(nn.Layer): padding=0, weight_attr=paddle.ParamAttr(initializer=Normal( mean=0., std=0.01))) + self.rpn_rois_score.skip_quant = True # rpn roi bbox regression deltas self.rpn_rois_delta = nn.Conv2D( @@ -109,6 +111,7 @@ class RPNHead(nn.Layer): padding=0, weight_attr=paddle.ParamAttr(initializer=Normal( mean=0., std=0.01))) + self.rpn_rois_delta.skip_quant = True @classmethod def from_config(cls, cfg, input_shape): -- GitLab