From 20ff74042c78e8a399408bc967604802137ab408 Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Wed, 8 Feb 2023 20:13:32 +0800 Subject: [PATCH] fix ppyoloe_contrast_head inherit (#7709) --- ppdet/modeling/heads/ppyoloe_contrast_head.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ppdet/modeling/heads/ppyoloe_contrast_head.py b/ppdet/modeling/heads/ppyoloe_contrast_head.py index df61194ed..3b8e26e63 100644 --- a/ppdet/modeling/heads/ppyoloe_contrast_head.py +++ b/ppdet/modeling/heads/ppyoloe_contrast_head.py @@ -17,14 +17,10 @@ import paddle.nn as nn import paddle.nn.functional as F from ppdet.core.workspace import register -from ..bbox_utils import batch_distance2bbox -from ..losses import GIoULoss -from ..initializer import bias_init_with_prob, constant_, normal_ +from ..initializer import bias_init_with_prob, constant_ from ..assigners.utils import generate_anchors_for_grid_cell -from ppdet.modeling.backbones.cspresnet import ConvBNLayer -from ppdet.modeling.ops import get_static_shape, get_act_fn -from ppdet.modeling.layers import MultiClassNMS from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead + __all__ = ['PPYOLOEContrastHead'] @@ -32,7 +28,7 @@ __all__ = ['PPYOLOEContrastHead'] class PPYOLOEContrastHead(PPYOLOEHead): __shared__ = [ 'num_classes', 'eval_size', 'trt', 'exclude_nms', - 'exclude_post_process', 'use_shared_conv' + 'exclude_post_process', 'use_shared_conv', 'for_distill' ] __inject__ = ['static_assigner', 'assigner', 'nms', 'contrast_loss'] @@ -58,15 +54,17 @@ class PPYOLOEContrastHead(PPYOLOEHead): 'dfl': 0.5, }, trt=False, + attn_conv='convbn', exclude_nms=False, exclude_post_process=False, - use_shared_conv=True): + use_shared_conv=True, + for_distill=False): super().__init__(in_channels, num_classes, act, fpn_strides, grid_cell_scale, grid_cell_offset, reg_max, reg_range, static_assigner_epoch, use_varifocal_loss, static_assigner, assigner, nms, eval_size, loss_weight, - trt, exclude_nms, exclude_post_process, - use_shared_conv) + trt, attn_conv, exclude_nms, exclude_post_process, + use_shared_conv, for_distill) assert len(in_channels) > 0, "len(in_channels) should > 0" self.contrast_loss = contrast_loss -- GitLab