未验证 提交 20ff7404 编写于 作者: F Feng Ni 提交者: GitHub

fix ppyoloe_contrast_head inherit (#7709)

上级 e3ec5d0f
...@@ -17,14 +17,10 @@ import paddle.nn as nn ...@@ -17,14 +17,10 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
from ppdet.core.workspace import register from ppdet.core.workspace import register
from ..bbox_utils import batch_distance2bbox from ..initializer import bias_init_with_prob, constant_
from ..losses import GIoULoss
from ..initializer import bias_init_with_prob, constant_, normal_
from ..assigners.utils import generate_anchors_for_grid_cell from ..assigners.utils import generate_anchors_for_grid_cell
from ppdet.modeling.backbones.cspresnet import ConvBNLayer
from ppdet.modeling.ops import get_static_shape, get_act_fn
from ppdet.modeling.layers import MultiClassNMS
from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead
__all__ = ['PPYOLOEContrastHead'] __all__ = ['PPYOLOEContrastHead']
...@@ -32,7 +28,7 @@ __all__ = ['PPYOLOEContrastHead'] ...@@ -32,7 +28,7 @@ __all__ = ['PPYOLOEContrastHead']
class PPYOLOEContrastHead(PPYOLOEHead): class PPYOLOEContrastHead(PPYOLOEHead):
__shared__ = [ __shared__ = [
'num_classes', 'eval_size', 'trt', 'exclude_nms', 'num_classes', 'eval_size', 'trt', 'exclude_nms',
'exclude_post_process', 'use_shared_conv' 'exclude_post_process', 'use_shared_conv', 'for_distill'
] ]
__inject__ = ['static_assigner', 'assigner', 'nms', 'contrast_loss'] __inject__ = ['static_assigner', 'assigner', 'nms', 'contrast_loss']
...@@ -58,15 +54,17 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -58,15 +54,17 @@ class PPYOLOEContrastHead(PPYOLOEHead):
'dfl': 0.5, 'dfl': 0.5,
}, },
trt=False, trt=False,
attn_conv='convbn',
exclude_nms=False, exclude_nms=False,
exclude_post_process=False, exclude_post_process=False,
use_shared_conv=True): use_shared_conv=True,
for_distill=False):
super().__init__(in_channels, num_classes, act, fpn_strides, super().__init__(in_channels, num_classes, act, fpn_strides,
grid_cell_scale, grid_cell_offset, reg_max, reg_range, grid_cell_scale, grid_cell_offset, reg_max, reg_range,
static_assigner_epoch, use_varifocal_loss, static_assigner_epoch, use_varifocal_loss,
static_assigner, assigner, nms, eval_size, loss_weight, static_assigner, assigner, nms, eval_size, loss_weight,
trt, exclude_nms, exclude_post_process, trt, attn_conv, exclude_nms, exclude_post_process,
use_shared_conv) use_shared_conv, for_distill)
assert len(in_channels) > 0, "len(in_channels) should > 0" assert len(in_channels) > 0, "len(in_channels) should > 0"
self.contrast_loss = contrast_loss self.contrast_loss = contrast_loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册