未验证 提交 65dd2346 编写于 作者: S shangliang Xu 提交者: GitHub

[PPYOLOE] fix proj_conv in ptq bug (#6900)

上级 dafd365a
...@@ -2,6 +2,7 @@ architecture: YOLOv3 ...@@ -2,6 +2,7 @@ architecture: YOLOv3
norm_type: sync_bn norm_type: sync_bn
use_ema: true use_ema: true
ema_decay: 0.9998 ema_decay: 0.9998
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean'] custom_black_list: ['reduce_mean']
YOLOv3: YOLOv3:
......
...@@ -2,6 +2,7 @@ architecture: YOLOv3 ...@@ -2,6 +2,7 @@ architecture: YOLOv3
norm_type: sync_bn norm_type: sync_bn
use_ema: true use_ema: true
ema_decay: 0.9998 ema_decay: 0.9998
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean'] custom_black_list: ['reduce_mean']
YOLOv3: YOLOv3:
......
...@@ -26,6 +26,8 @@ architecture: YOLOv3 ...@@ -26,6 +26,8 @@ architecture: YOLOv3
norm_type: sync_bn norm_type: sync_bn
use_ema: true use_ema: true
ema_decay: 0.9998 ema_decay: 0.9998
ema_black_list: ['proj_conv.weight']
custom_black_list: ['reduce_mean']
YOLOv3: YOLOv3:
backbone: CSPResNet backbone: CSPResNet
......
...@@ -14,6 +14,7 @@ Distillation: ...@@ -14,6 +14,7 @@ Distillation:
Quantization: Quantization:
use_pact: true use_pact: true
onnx_format: True
activation_quantize_type: 'moving_average_abs_max' activation_quantize_type: 'moving_average_abs_max'
quantize_op_types: quantize_op_types:
- conv2d - conv2d
......
metric: COCO metric: COCO
num_classes: 80 num_classes: 80
...@@ -23,6 +21,6 @@ EvalReader: ...@@ -23,6 +21,6 @@ EvalReader:
sample_transforms: sample_transforms:
- Decode: {} - Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2} - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], is_scale: True} - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {} - Permute: {}
batch_size: 4 batch_size: 4
...@@ -14,6 +14,7 @@ Distillation: ...@@ -14,6 +14,7 @@ Distillation:
Quantization: Quantization:
use_pact: true use_pact: true
onnx_format: True
activation_quantize_type: 'moving_average_abs_max' activation_quantize_type: 'moving_average_abs_max'
quantize_op_types: quantize_op_types:
- conv2d - conv2d
......
...@@ -108,7 +108,7 @@ def argsparser(): ...@@ -108,7 +108,7 @@ def argsparser():
"calibration, trt_calib_mode need to set True.") "calibration, trt_calib_mode need to set True.")
parser.add_argument( parser.add_argument(
'--save_images', '--save_images',
type=bool, type=ast.literal_eval,
default=True, default=True,
help='Save visualization image results.') help='Save visualization image results.')
parser.add_argument( parser.add_argument(
......
...@@ -169,13 +169,15 @@ class Trainer(object): ...@@ -169,13 +169,15 @@ class Trainer(object):
self.use_ema = ('use_ema' in cfg and cfg['use_ema']) self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
if self.use_ema: if self.use_ema:
ema_decay = self.cfg.get('ema_decay', 0.9998) ema_decay = self.cfg.get('ema_decay', 0.9998)
cycle_epoch = self.cfg.get('cycle_epoch', -1)
ema_decay_type = self.cfg.get('ema_decay_type', 'threshold') ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
cycle_epoch = self.cfg.get('cycle_epoch', -1)
ema_black_list = self.cfg.get('ema_black_list', None)
self.ema = ModelEMA( self.ema = ModelEMA(
self.model, self.model,
decay=ema_decay, decay=ema_decay,
ema_decay_type=ema_decay_type, ema_decay_type=ema_decay_type,
cycle_epoch=cycle_epoch) cycle_epoch=cycle_epoch,
ema_black_list=ema_black_list)
self._nranks = dist.get_world_size() self._nranks = dist.get_world_size()
self._local_rank = dist.get_rank() self._local_rank = dist.get_rank()
......
...@@ -120,7 +120,7 @@ class ATSSAssigner(nn.Layer): ...@@ -120,7 +120,7 @@ class ATSSAssigner(nn.Layer):
# negative batch # negative batch
if num_max_boxes == 0: if num_max_boxes == 0:
assigned_labels = paddle.full( assigned_labels = paddle.full(
[batch_size, num_anchors], bg_index, dtype=gt_labels.dtype) [batch_size, num_anchors], bg_index, dtype='int32')
assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4]) assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
assigned_scores = paddle.zeros( assigned_scores = paddle.zeros(
[batch_size, num_anchors, self.num_classes]) [batch_size, num_anchors, self.num_classes])
......
...@@ -86,7 +86,7 @@ class TaskAlignedAssigner(nn.Layer): ...@@ -86,7 +86,7 @@ class TaskAlignedAssigner(nn.Layer):
# negative batch # negative batch
if num_max_boxes == 0: if num_max_boxes == 0:
assigned_labels = paddle.full( assigned_labels = paddle.full(
[batch_size, num_anchors], bg_index, dtype=gt_labels.dtype) [batch_size, num_anchors], bg_index, dtype='int32')
assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4]) assigned_bboxes = paddle.zeros([batch_size, num_anchors, 4])
assigned_scores = paddle.zeros( assigned_scores = paddle.zeros(
[batch_size, num_anchors, num_classes]) [batch_size, num_anchors, num_classes])
......
...@@ -130,11 +130,10 @@ class PPYOLOEHead(nn.Layer): ...@@ -130,11 +130,10 @@ class PPYOLOEHead(nn.Layer):
constant_(reg_.weight) constant_(reg_.weight)
constant_(reg_.bias, 1.0) constant_(reg_.bias, 1.0)
self.proj = paddle.linspace(0, self.reg_max, self.reg_max + 1) proj = paddle.linspace(0, self.reg_max, self.reg_max + 1).reshape(
self.proj_conv.weight.set_value( [1, self.reg_max + 1, 1, 1])
self.proj.reshape([1, self.reg_max + 1, 1, 1])) self.proj_conv.weight.set_value(proj)
self.proj_conv.weight.stop_gradient = True self.proj_conv.weight.stop_gradient = True
if self.eval_size: if self.eval_size:
anchor_points, stride_tensor = self._generate_anchors() anchor_points, stride_tensor = self._generate_anchors()
self.anchor_points = anchor_points self.anchor_points = anchor_points
...@@ -200,15 +199,15 @@ class PPYOLOEHead(nn.Layer): ...@@ -200,15 +199,15 @@ class PPYOLOEHead(nn.Layer):
feat) feat)
reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat)) reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose( reg_dist = reg_dist.reshape([-1, 4, self.reg_max + 1, l]).transpose(
[0, 2, 1, 3]) [0, 2, 3, 1])
reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1)) reg_dist = self.proj_conv(F.softmax(reg_dist, axis=1)).squeeze(1)
# cls and reg # cls and reg
cls_score = F.sigmoid(cls_logit) cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.reshape([b, self.num_classes, l])) cls_score_list.append(cls_score.reshape([b, self.num_classes, l]))
reg_dist_list.append(reg_dist.reshape([b, 4, l])) reg_dist_list.append(reg_dist)
cls_score_list = paddle.concat(cls_score_list, axis=-1) cls_score_list = paddle.concat(cls_score_list, axis=-1)
reg_dist_list = paddle.concat(reg_dist_list, axis=-1) reg_dist_list = paddle.concat(reg_dist_list, axis=1)
return cls_score_list, reg_dist_list, anchor_points, stride_tensor return cls_score_list, reg_dist_list, anchor_points, stride_tensor
...@@ -240,8 +239,8 @@ class PPYOLOEHead(nn.Layer): ...@@ -240,8 +239,8 @@ class PPYOLOEHead(nn.Layer):
def _bbox_decode(self, anchor_points, pred_dist): def _bbox_decode(self, anchor_points, pred_dist):
b, l, _ = get_static_shape(pred_dist) b, l, _ = get_static_shape(pred_dist)
pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1 pred_dist = F.softmax(pred_dist.reshape([b, l, 4, self.reg_max + 1]))
])).matmul(self.proj) pred_dist = self.proj_conv(pred_dist.transpose([0, 3, 1, 2])).squeeze(1)
return batch_distance2bbox(anchor_points, pred_dist) return batch_distance2bbox(anchor_points, pred_dist)
def _bbox2distance(self, points, bbox): def _bbox2distance(self, points, bbox):
...@@ -347,9 +346,8 @@ class PPYOLOEHead(nn.Layer): ...@@ -347,9 +346,8 @@ class PPYOLOEHead(nn.Layer):
assigned_scores_sum = assigned_scores.sum() assigned_scores_sum = assigned_scores.sum()
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
paddle.distributed.all_reduce(assigned_scores_sum) paddle.distributed.all_reduce(assigned_scores_sum)
assigned_scores_sum = paddle.clip( assigned_scores_sum /= paddle.distributed.get_world_size()
assigned_scores_sum / paddle.distributed.get_world_size(), assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
min=1)
loss_cls /= assigned_scores_sum loss_cls /= assigned_scores_sum
loss_l1, loss_iou, loss_dfl = \ loss_l1, loss_iou, loss_dfl = \
...@@ -370,8 +368,7 @@ class PPYOLOEHead(nn.Layer): ...@@ -370,8 +368,7 @@ class PPYOLOEHead(nn.Layer):
def post_process(self, head_outs, scale_factor): def post_process(self, head_outs, scale_factor):
pred_scores, pred_dist, anchor_points, stride_tensor = head_outs pred_scores, pred_dist, anchor_points, stride_tensor = head_outs
pred_bboxes = batch_distance2bbox(anchor_points, pred_bboxes = batch_distance2bbox(anchor_points, pred_dist)
pred_dist.transpose([0, 2, 1]))
pred_bboxes *= stride_tensor pred_bboxes *= stride_tensor
if self.exclude_post_process: if self.exclude_post_process:
return paddle.concat( return paddle.concat(
......
...@@ -36,21 +36,30 @@ class ModelEMA(object): ...@@ -36,21 +36,30 @@ class ModelEMA(object):
step. Defaults is -1, which means not reset. Its function is to step. Defaults is -1, which means not reset. Its function is to
add a regular effect to ema, which is set according to experience add a regular effect to ema, which is set according to experience
and is effective when the total training epoch is large. and is effective when the total training epoch is large.
ema_black_list (set|list|tuple, optional): The custom EMA black_list.
Blacklist of weight names that will not participate in EMA
calculation. Default: None.
""" """
def __init__(self, def __init__(self,
model, model,
decay=0.9998, decay=0.9998,
ema_decay_type='threshold', ema_decay_type='threshold',
cycle_epoch=-1): cycle_epoch=-1,
ema_black_list=None):
self.step = 0 self.step = 0
self.epoch = 0 self.epoch = 0
self.decay = decay self.decay = decay
self.state_dict = dict()
for k, v in model.state_dict().items():
self.state_dict[k] = paddle.zeros_like(v)
self.ema_decay_type = ema_decay_type self.ema_decay_type = ema_decay_type
self.cycle_epoch = cycle_epoch self.cycle_epoch = cycle_epoch
self.ema_black_list = self._match_ema_black_list(
model.state_dict().keys(), ema_black_list)
self.state_dict = dict()
for k, v in model.state_dict().items():
if k in self.ema_black_list:
self.state_dict[k] = v
else:
self.state_dict[k] = paddle.zeros_like(v)
self._model_state = { self._model_state = {
k: weakref.ref(p) k: weakref.ref(p)
...@@ -61,7 +70,10 @@ class ModelEMA(object): ...@@ -61,7 +70,10 @@ class ModelEMA(object):
self.step = 0 self.step = 0
self.epoch = 0 self.epoch = 0
for k, v in self.state_dict.items(): for k, v in self.state_dict.items():
self.state_dict[k] = paddle.zeros_like(v) if k in self.ema_black_list:
self.state_dict[k] = v
else:
self.state_dict[k] = paddle.zeros_like(v)
def resume(self, state_dict, step=0): def resume(self, state_dict, step=0):
for k, v in state_dict.items(): for k, v in state_dict.items():
...@@ -89,9 +101,10 @@ class ModelEMA(object): ...@@ -89,9 +101,10 @@ class ModelEMA(object):
[v is not None for _, v in model_dict.items()]), 'python gc.' [v is not None for _, v in model_dict.items()]), 'python gc.'
for k, v in self.state_dict.items(): for k, v in self.state_dict.items():
v = decay * v + (1 - decay) * model_dict[k] if k not in self.ema_black_list:
v.stop_gradient = True v = decay * v + (1 - decay) * model_dict[k]
self.state_dict[k] = v v.stop_gradient = True
self.state_dict[k] = v
self.step += 1 self.step += 1
def apply(self): def apply(self):
...@@ -99,12 +112,25 @@ class ModelEMA(object): ...@@ -99,12 +112,25 @@ class ModelEMA(object):
return self.state_dict return self.state_dict
state_dict = dict() state_dict = dict()
for k, v in self.state_dict.items(): for k, v in self.state_dict.items():
if self.ema_decay_type != 'exponential': if k in self.ema_black_list:
v = v / (1 - self._decay**self.step) v.stop_gradient = True
v.stop_gradient = True state_dict[k] = v
state_dict[k] = v else:
if self.ema_decay_type != 'exponential':
v = v / (1 - self._decay**self.step)
v.stop_gradient = True
state_dict[k] = v
self.epoch += 1 self.epoch += 1
if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch: if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
self.reset() self.reset()
return state_dict return state_dict
def _match_ema_black_list(self, weight_name, ema_black_list=None):
out_list = set()
if ema_black_list:
for name in weight_name:
for key in ema_black_list:
if key in name:
out_list.add(name)
return out_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册