未验证 提交 b06d7f17 编写于 作者: G Guanghua Yu 提交者: GitHub

update slim model (#2427)

* update slim model

* fix skip quant
上级 f048a21c
......@@ -87,13 +87,13 @@ python tools/export_model.py -c configs/{MODEL.yml} --slim_config configs/slim/{
| 模型 | 压缩策略 | 输入尺寸 | Box AP | 下载 | 模型配置文件 | 压缩算法配置文件 |
| ------------------ | ------------ | -------- | :---------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| YOLOv3-MobileNetV1 | baseline | 608 | 28.8 | [下载链接](https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) | - |
| YOLOv3-MobileNetV1 | 普通在线量化 | 608 | 30.3 (+1.5) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v1_qat.yml) |
| YOLOv3-MobileNetV1 | 普通在线量化 | 608 | 30.5 (+1.7) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v1_qat.yml) |
| YOLOv3-MobileNetV3 | baseline | 608 | 31.4 | [下载链接](https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v3_large_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v3_large_270e_coco.yml) | - |
| YOLOv3-MobileNetV3 | PACT在线量化 | 608 | 29.5 (-1.9) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v3_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v3_large_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v3_qat.yml) |
| YOLOv3-MobileNetV3 | PACT在线量化 | 608 | 29.1 (-2.3) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v3_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_mobilenet_v3_large_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_mobilenet_v3_qat.yml) |
| YOLOv3-DarkNet53 | baseline | 608 | 39.0 | [下载链接](https://paddledet.bj.bcebos.com/models/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | - |
| YOLOv3-DarkNet53 | 普通在线量化 | 608 | 38.7 (-0.3) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/yolov3_darknet_coco_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/yolov3/yolov3_darknet53_270e_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/yolov3_darknet_qat.yml) |
| SSD-MobileNet_v1 | baseline | 300 | 73.8 | [下载链接](https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml) | - |
| SSD-MobileNet_v1 | 普通在线量化 | 300 | 73.1(-0.7) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/ssd_mobilenet_v1_300_voc_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/ssd_mobilenet_v1_qat.yml) |
| SSD-MobileNet_v1 | 普通在线量化 | 300 | 72.9(-0.9) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/ssd_mobilenet_v1_300_voc_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/ssd_mobilenet_v1_qat.yml) |
| Mask-ResNet50-FPN | baseline | (800, 1333) | 39.2/35.6 | [下载链接](https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.yml) | - |
| Mask-ResNet50-FPN | 普通在线量化 | (800, 1333) | 39.7(+0.5)/35.9(+0.3) | [下载链接](https://paddledet.bj.bcebos.com/models/slim/mask_rcnn_r50_fpn_1x_qat.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.yml) | [slim配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/slim/quant/mask_rcnn_r50_fpn_1x_qat.yml) |
......
pretrain_weights: https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_fpn_1x_coco.pdparams
slim: QAT
QAT:
quant_config: {
'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9,
'quantizable_layer_type': ['Conv2D', 'Linear']}
print_model: True
epoch: 5
LearningRate:
base_lr: 0.001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [3, 4]
- !LinearWarmup
start_factor: 0.001
steps: 100
pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_mobilenet_v1_300_120e_voc.pdparams
slim: QAT
QAT:
quant_config: {
'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9,
'quantizable_layer_type': ['Conv2D', 'Linear']}
print_model: True
pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_darknet53_270e_coco.pdparams
slim: QAT
QAT:
quant_config: {
'weight_quantize_type': 'channel_wise_abs_max', 'activation_quantize_type': 'moving_average_abs_max',
'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9,
'quantizable_layer_type': ['Conv2D', 'Linear']}
print_model: True
epoch: 50
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 30
- 45
- !LinearWarmup
start_factor: 0.
steps: 1000
OptimizerBuilder:
optimizer:
momentum: 0.9
type: Momentum
regularizer:
factor: 0.0005
type: L2
# Weights of yolov3_mobilenet_v1_coco
pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams
weight_type: resume
slim: QAT
QAT:
......
# Weights of yolov3_mobilenet_v3_coco
pretrain_weights: https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v3_large_270e_coco.pdparams
weight_type: resume
slim: QAT
QAT:
......@@ -10,3 +9,16 @@ QAT:
'weight_bits': 8, 'activation_bits': 8, 'dtype': 'int8', 'window_size': 10000, 'moving_rate': 0.9,
'quantizable_layer_type': ['Conv2D', 'Linear']}
print_model: True
epoch: 30
LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones:
- 25
- 28
- !LinearWarmup
start_factor: 0.
steps: 2000
......@@ -50,11 +50,13 @@ class TwoFCHead(nn.Layer):
out_channel,
weight_attr=paddle.ParamAttr(
initializer=XavierUniform(fan_out=fan)))
self.fc6.skip_quant = True
self.fc7 = nn.Linear(
out_channel,
out_channel,
weight_attr=paddle.ParamAttr(initializer=XavierUniform()))
self.fc7.skip_quant = True
@classmethod
def from_config(cls, cfg, input_shape):
......@@ -199,12 +201,14 @@ class BBoxHead(nn.Layer):
self.num_classes + 1,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0.0, std=0.01)))
self.bbox_score.skip_quant = True
self.bbox_delta = nn.Linear(
in_channel,
4 * self.num_classes,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0.0, std=0.001)))
self.bbox_delta.skip_quant = True
self.assigned_label = None
self.assigned_rois = None
......
......@@ -65,20 +65,21 @@ class MaskFeat(nn.Layer):
norm_type=self.norm_type,
norm_name=conv_name + '_norm',
initializer=KaimingNormal(fan_in=fan_conv),
skip_quant=True,
name=conv_name))
mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
else:
for i in range(self.num_convs):
conv_name = 'mask_inter_feat_{}'.format(i + 1)
mask_conv.add_sublayer(
conv_name,
nn.Conv2D(
in_channels=in_channel if i == 0 else out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=paddle.ParamAttr(
initializer=KaimingNormal(fan_in=fan_conv))))
conv = nn.Conv2D(
in_channels=in_channel if i == 0 else out_channel,
out_channels=out_channel,
kernel_size=3,
padding=1,
weight_attr=paddle.ParamAttr(
initializer=KaimingNormal(fan_in=fan_conv)))
conv.skip_quant = True
mask_conv.add_sublayer(conv_name, conv)
mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
mask_conv.add_sublayer(
'conv5_mask',
......@@ -146,6 +147,7 @@ class MaskHead(nn.Layer):
kernel_size=1,
weight_attr=paddle.ParamAttr(initializer=KaimingNormal(
fan_in=self.num_classes)))
self.mask_fcn_logits.skip_quant = True
@classmethod
def from_config(cls, cfg, input_shape):
......
......@@ -47,16 +47,16 @@ class YOLOv3Head(nn.Layer):
else:
num_filters = len(self.anchors[i]) * (self.num_classes + 5)
name = 'yolo_output.{}'.format(i)
yolo_output = self.add_sublayer(
name,
nn.Conv2D(
in_channels=128 * (2**self.num_outputs) // (2**i),
out_channels=num_filters,
kernel_size=1,
stride=1,
padding=0,
data_format=data_format,
bias_attr=ParamAttr(regularizer=L2Decay(0.))))
conv = nn.Conv2D(
in_channels=128 * (2**self.num_outputs) // (2**i),
out_channels=num_filters,
kernel_size=1,
stride=1,
padding=0,
data_format=data_format,
bias_attr=ParamAttr(regularizer=L2Decay(0.)))
conv.skip_quant = True
yolo_output = self.add_sublayer(name, conv)
self.yolo_outputs.append(yolo_output)
def parse_anchor(self, anchors, anchor_masks):
......
......@@ -52,6 +52,7 @@ class DeformableConvV2(nn.Layer):
bias_attr=None,
lr_scale=1,
regularizer=None,
skip_quant=False,
name=None):
super(DeformableConvV2, self).__init__()
self.offset_channel = 2 * kernel_size**2
......@@ -77,6 +78,8 @@ class DeformableConvV2(nn.Layer):
initializer=Constant(0.0),
name='{}._conv_offset.weight'.format(name)),
bias_attr=offset_bias_attr)
if skip_quant:
self.conv_offset.skip_quant = True
if bias_attr:
# in FCOS-DCN head, specifically need learning_rate and regularizer
......@@ -126,6 +129,7 @@ class ConvNormLayer(nn.Layer):
freeze_norm=False,
initializer=Normal(
mean=0., std=0.01),
skip_quant=False,
name=None):
super(ConvNormLayer, self).__init__()
assert norm_type in ['bn', 'sync_bn', 'gn']
......@@ -151,6 +155,8 @@ class ConvNormLayer(nn.Layer):
initializer=initializer,
learning_rate=1.),
bias_attr=bias_attr)
if skip_quant:
self.conv.skip_quant = True
else:
# in FCOS-DCN head, specifically need learning_rate and regularizer
self.conv = DeformableConvV2(
......@@ -167,6 +173,7 @@ class ConvNormLayer(nn.Layer):
bias_attr=True,
lr_scale=2.,
regularizer=L2Decay(norm_decay),
skip_quant=skip_quant,
name=name)
norm_lr = 0. if freeze_norm else 1.
......
......@@ -45,6 +45,7 @@ class RPNFeat(nn.Layer):
padding=1,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0., std=0.01)))
self.rpn_conv.skip_quant = True
def forward(self, feats):
rpn_feats = []
......@@ -100,6 +101,7 @@ class RPNHead(nn.Layer):
padding=0,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0., std=0.01)))
self.rpn_rois_score.skip_quant = True
# rpn roi bbox regression deltas
self.rpn_rois_delta = nn.Conv2D(
......@@ -109,6 +111,7 @@ class RPNHead(nn.Layer):
padding=0,
weight_attr=paddle.ParamAttr(initializer=Normal(
mean=0., std=0.01)))
self.rpn_rois_delta.skip_quant = True
@classmethod
def from_config(cls, cfg, input_shape):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册