From 65dd23467522dfe3fcbee284cd7a2f2d9242d95a Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Thu, 8 Sep 2022 21:05:12 +0800 Subject: [PATCH] [PPYOLOE] fix proj_conv in ptq bug (#6900) --- configs/ppyoloe/_base_/ppyoloe_crn.yml | 1 + configs/ppyoloe/_base_/ppyoloe_plus_crn.yml | 1 + .../ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml | 2 + .../configs/ppyoloe_plus_m_qat_dis.yaml | 1 + .../configs/ppyoloe_plus_reader.yml | 4 +- .../configs/ppyoloe_plus_x_qat_dis.yaml | 1 + deploy/python/utils.py | 2 +- ppdet/engine/trainer.py | 6 ++- ppdet/modeling/assigners/atss_assigner.py | 2 +- .../assigners/task_aligned_assigner.py | 2 +- ppdet/modeling/heads/ppyoloe_head.py | 27 +++++----- ppdet/optimizer/ema.py | 50 ++++++++++++++----- 12 files changed, 64 insertions(+), 35 deletions(-) diff --git a/configs/ppyoloe/_base_/ppyoloe_crn.yml b/configs/ppyoloe/_base_/ppyoloe_crn.yml index e1c3cc857..118db7ee1 100644 --- a/configs/ppyoloe/_base_/ppyoloe_crn.yml +++ b/configs/ppyoloe/_base_/ppyoloe_crn.yml @@ -2,6 +2,7 @@ architecture: YOLOv3 norm_type: sync_bn use_ema: true ema_decay: 0.9998 +ema_black_list: ['proj_conv.weight'] custom_black_list: ['reduce_mean'] YOLOv3: diff --git a/configs/ppyoloe/_base_/ppyoloe_plus_crn.yml b/configs/ppyoloe/_base_/ppyoloe_plus_crn.yml index ba8ef992c..c8e6191fd 100644 --- a/configs/ppyoloe/_base_/ppyoloe_plus_crn.yml +++ b/configs/ppyoloe/_base_/ppyoloe_plus_crn.yml @@ -2,6 +2,7 @@ architecture: YOLOv3 norm_type: sync_bn use_ema: true ema_decay: 0.9998 +ema_black_list: ['proj_conv.weight'] custom_black_list: ['reduce_mean'] YOLOv3: diff --git a/configs/ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml b/configs/ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml index 379728886..21af7774c 100644 --- a/configs/ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml +++ b/configs/ppyoloe/ppyoloe_crn_l_36e_coco_xpu.yml @@ -26,6 +26,8 @@ architecture: YOLOv3 norm_type: sync_bn use_ema: true ema_decay: 0.9998 +ema_black_list: ['proj_conv.weight'] +custom_black_list: ['reduce_mean'] YOLOv3: backbone: CSPResNet diff --git a/deploy/auto_compression/configs/ppyoloe_plus_m_qat_dis.yaml b/deploy/auto_compression/configs/ppyoloe_plus_m_qat_dis.yaml index 65c1b9240..26ee8fe7f 100644 --- a/deploy/auto_compression/configs/ppyoloe_plus_m_qat_dis.yaml +++ b/deploy/auto_compression/configs/ppyoloe_plus_m_qat_dis.yaml @@ -14,6 +14,7 @@ Distillation: Quantization: use_pact: true + onnx_format: True activation_quantize_type: 'moving_average_abs_max' quantize_op_types: - conv2d diff --git a/deploy/auto_compression/configs/ppyoloe_plus_reader.yml b/deploy/auto_compression/configs/ppyoloe_plus_reader.yml index 632dd87c7..5f3795f29 100644 --- a/deploy/auto_compression/configs/ppyoloe_plus_reader.yml +++ b/deploy/auto_compression/configs/ppyoloe_plus_reader.yml @@ -1,5 +1,3 @@ - - metric: COCO num_classes: 80 @@ -23,6 +21,6 @@ EvalReader: sample_transforms: - Decode: {} - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} - - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: True} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} - Permute: {} batch_size: 4 diff --git a/deploy/auto_compression/configs/ppyoloe_plus_x_qat_dis.yaml b/deploy/auto_compression/configs/ppyoloe_plus_x_qat_dis.yaml index 5e6ba5850..5cbed405c 100644 --- a/deploy/auto_compression/configs/ppyoloe_plus_x_qat_dis.yaml +++ b/deploy/auto_compression/configs/ppyoloe_plus_x_qat_dis.yaml @@ -14,6 +14,7 @@ Distillation: Quantization: use_pact: true + onnx_format: True activation_quantize_type: 'moving_average_abs_max' quantize_op_types: - conv2d diff --git a/deploy/python/utils.py b/deploy/python/utils.py index c52ac184c..ac8a3f702 100644 --- a/deploy/python/utils.py +++ b/deploy/python/utils.py @@ -108,7 +108,7 @@ def argsparser(): "calibration, trt_calib_mode need to set True.") parser.add_argument( '--save_images', - type=bool, + type=ast.literal_eval, default=True, help='Save visualization image results.') parser.add_argument( diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index f46974aaf..53d7296a0 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -169,13 +169,15 @@ class Trainer(object): self.use_ema = ('use_ema' in cfg and cfg['use_ema']) if self.use_ema: ema_decay = self.cfg.get('ema_decay', 0.9998) - cycle_epoch = self.cfg.get('cycle_epoch', -1) ema_decay_type = self.cfg.get('ema_decay_type', 'threshold') + cycle_epoch = self.cfg.get('cycle_epoch', -1) + ema_black_list = self.cfg.get('ema_black_list', None) self.ema = ModelEMA( self.model, decay=ema_decay, ema_decay_type=ema_decay_type, - cycle_epoch=cycle_epoch) + cycle_epoch=cycle_epoch, + ema_black_list=ema_black_list) self._nranks = dist.get_world_size() self._local_rank = dist.get_rank() diff --git a/ppdet/modeling/assigners/atss_assigner.py b/ppdet/modeling/assigners/atss_assigner.py index 6406d7bce..409e040de 100644 --- a/ppdet/modeling/assigners/atss_assigner.py +++ b/ppdet/modeling/assigners/atss_assigner.py @@ -120,7 +120,7 @@ class ATSSAssigner(nn.Layer): # negative batch if num_max_boxes == 0: assigned_labels = paddle.full( - [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype) + [batch_size, num_anchors], bg_index, dtype='int32') assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4]) assigned_scores = paddle.zeros( [batch_size, num_anchors, self.num_classes]) diff --git a/ppdet/modeling/assigners/task_aligned_assigner.py b/ppdet/modeling/assigners/task_aligned_assigner.py index 5b3368e06..cb932c788 100644 --- a/ppdet/modeling/assigners/task_aligned_assigner.py +++ b/ppdet/modeling/assigners/task_aligned_assigner.py @@ -86,7 +86,7 @@ class TaskAlignedAssigner(nn.Layer): # negative batch if num_max_boxes == 0: assigned_labels = paddle.full( - [batch_size, num_anchors], bg_index, dtype=gt_labels.dtype) + [batch_size, num_anchors], bg_index, dtype='int32') assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4]) assigned_scores = paddle.zeros( [batch_size, num_anchors, num_classes]) diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py index 612a37288..cdcf2bce5 100644 --- a/ppdet/modeling/heads/ppyoloe_head.py +++ b/ppdet/modeling/heads/ppyoloe_head.py @@ -130,11 +130,10 @@ class PPYOLOEHead(nn.Layer): constant_(reg_.weight) constant_(reg_.bias, 1.0) - self.proj = paddle.linspace(0, self.reg_max, self.reg_max + 1) - self.proj_conv.weight.set_value( - self.proj.reshape([1, self.reg_max + 1, 1, 1])) + proj = paddle.linspace(0, self.reg_max, self.reg_max + 1).reshape( + [1, self.reg_max + 1, 1, 1]) + self.proj_conv.weight.set_value(proj) self.proj_conv.weight.stop_gradient = True - if self.eval_size: anchor_points, stride_tensor = self._generate_anchors() self.anchor_points = anchor_points @@ -200,15 +199,15 @@ class PPYOLOEHead(nn.Layer): feat) reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat)) reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose( - [0, 2, 1, 3]) - reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1)) + [0, 2, 3, 1]) + reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1)).squeeze(1) # cls and reg cls_score = F.sigmoid(cls_logit) cls_score_list.append(cls_score.reshape([b, self.num_classes, l])) - reg_dist_list.append(reg_dist.reshape([b, 4, l])) + reg_dist_list.append(reg_dist) cls_score_list = paddle.concat(cls_score_list, axis=-1) - reg_dist_list = paddle.concat(reg_dist_list, axis=-1) + reg_dist_list = paddle.concat(reg_dist_list, axis=1) return cls_score_list, reg_dist_list, anchor_points, stride_tensor @@ -240,8 +239,8 @@ class PPYOLOEHead(nn.Layer): def _bbox_decode(self, anchor_points, pred_dist): b, l, _ = get_static_shape(pred_dist) - pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1 - ])).matmul(self.proj) + pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1])) + pred_dist = self.proj_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1) return batch_distance2bbox(anchor_points, pred_dist) def _bbox2distance(self, points, bbox): @@ -347,9 +346,8 @@ class PPYOLOEHead(nn.Layer): assigned_scores_sum = assigned_scores.sum() if paddle.distributed.get_world_size() > 1: paddle.distributed.all_reduce(assigned_scores_sum) - assigned_scores_sum = paddle.clip( - assigned_scores_sum / paddle.distributed.get_world_size(), - min=1) + assigned_scores_sum /= paddle.distributed.get_world_size() + assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.) loss_cls /= assigned_scores_sum loss_l1, loss_iou, loss_dfl = \ @@ -370,8 +368,7 @@ class PPYOLOEHead(nn.Layer): def post_process(self, head_outs, scale_factor): pred_scores, pred_dist, anchor_points, stride_tensor = head_outs - pred_bboxes = batch_distance2bbox(anchor_points, - pred_dist.transpose([0, 2, 1])) + pred_bboxes = batch_distance2bbox(anchor_points, pred_dist) pred_bboxes *= stride_tensor if self.exclude_post_process: return paddle.concat( diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py index bd8cb825c..927d357b4 100644 --- a/ppdet/optimizer/ema.py +++ b/ppdet/optimizer/ema.py @@ -36,21 +36,30 @@ class ModelEMA(object): step. Defaults is -1, which means not reset. Its function is to add a regular effect to ema, which is set according to experience and is effective when the total training epoch is large. + ema_black_list (set|list|tuple, optional): The custom EMA black_list. + Blacklist of weight names that will not participate in EMA + calculation. Default: None. """ def __init__(self, model, decay=0.9998, ema_decay_type='threshold', - cycle_epoch=-1): + cycle_epoch=-1, + ema_black_list=None): self.step = 0 self.epoch = 0 self.decay = decay - self.state_dict = dict() - for k, v in model.state_dict().items(): - self.state_dict[k] = paddle.zeros_like(v) self.ema_decay_type = ema_decay_type self.cycle_epoch = cycle_epoch + self.ema_black_list = self._match_ema_black_list( + model.state_dict().keys(), ema_black_list) + self.state_dict = dict() + for k, v in model.state_dict().items(): + if k in self.ema_black_list: + self.state_dict[k] = v + else: + self.state_dict[k] = paddle.zeros_like(v) self._model_state = { k: weakref.ref(p) @@ -61,7 +70,10 @@ class ModelEMA(object): self.step = 0 self.epoch = 0 for k, v in self.state_dict.items(): - self.state_dict[k] = paddle.zeros_like(v) + if k in self.ema_black_list: + self.state_dict[k] = v + else: + self.state_dict[k] = paddle.zeros_like(v) def resume(self, state_dict, step=0): for k, v in state_dict.items(): @@ -89,9 +101,10 @@ class ModelEMA(object): [v is not None for _, v in model_dict.items()]), 'python gc.' for k, v in self.state_dict.items(): - v = decay * v + (1 - decay) * model_dict[k] - v.stop_gradient = True - self.state_dict[k] = v + if k not in self.ema_black_list: + v = decay * v + (1 - decay) * model_dict[k] + v.stop_gradient = True + self.state_dict[k] = v self.step += 1 def apply(self): @@ -99,12 +112,25 @@ class ModelEMA(object): return self.state_dict state_dict = dict() for k, v in self.state_dict.items(): - if self.ema_decay_type != 'exponential': - v = v / (1 - self._decay**self.step) - v.stop_gradient = True - state_dict[k] = v + if k in self.ema_black_list: + v.stop_gradient = True + state_dict[k] = v + else: + if self.ema_decay_type != 'exponential': + v = v / (1 - self._decay**self.step) + v.stop_gradient = True + state_dict[k] = v self.epoch += 1 if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch: self.reset() return state_dict + + def _match_ema_black_list(self, weight_name, ema_black_list=None): + out_list = set() + if ema_black_list: + for name in weight_name: + for key in ema_black_list: + if key in name: + out_list.add(name) + return out_list -- GitLab