未验证 提交 1c923ffe 编写于 作者: F Feng Ni 提交者: GitHub

fix ppyoloe_contrast_head aux_pred input (#7705)

上级 9fafde8f
...@@ -27,6 +27,7 @@ from ppdet.modeling.layers import MultiClassNMS ...@@ -27,6 +27,7 @@ from ppdet.modeling.layers import MultiClassNMS
from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead from ppdet.modeling.heads.ppyoloe_head import PPYOLOEHead
__all__ = ['PPYOLOEContrastHead'] __all__ = ['PPYOLOEContrastHead']
@register @register
class PPYOLOEContrastHead(PPYOLOEHead): class PPYOLOEContrastHead(PPYOLOEHead):
__shared__ = [ __shared__ = [
...@@ -60,33 +61,18 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -60,33 +61,18 @@ class PPYOLOEContrastHead(PPYOLOEHead):
exclude_nms=False, exclude_nms=False,
exclude_post_process=False, exclude_post_process=False,
use_shared_conv=True): use_shared_conv=True):
super().__init__(in_channels, super().__init__(in_channels, num_classes, act, fpn_strides,
num_classes, grid_cell_scale, grid_cell_offset, reg_max, reg_range,
act, static_assigner_epoch, use_varifocal_loss,
fpn_strides, static_assigner, assigner, nms, eval_size, loss_weight,
grid_cell_scale, trt, exclude_nms, exclude_post_process,
grid_cell_offset,
reg_max,
reg_range,
static_assigner_epoch,
use_varifocal_loss,
static_assigner,
assigner,
nms,
eval_size,
loss_weight,
trt,
exclude_nms,
exclude_post_process,
use_shared_conv) use_shared_conv)
assert len(in_channels) > 0, "len(in_channels) should > 0" assert len(in_channels) > 0, "len(in_channels) should > 0"
self.contrast_loss = contrast_loss self.contrast_loss = contrast_loss
self.contrast_encoder = nn.LayerList() self.contrast_encoder = nn.LayerList()
for in_c in self.in_channels: for in_c in self.in_channels:
self.contrast_encoder.append( self.contrast_encoder.append(nn.Conv2D(in_c, 128, 3, padding=1))
nn.Conv2D(
in_c, 128, 3, padding=1))
self._init_contrast_encoder() self._init_contrast_encoder()
def _init_contrast_encoder(self): def _init_contrast_encoder(self):
...@@ -95,7 +81,7 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -95,7 +81,7 @@ class PPYOLOEContrastHead(PPYOLOEHead):
constant_(en_.weight) constant_(en_.weight)
constant_(en_.bias, bias_en) constant_(en_.bias, bias_en)
def forward_train(self, feats, targets): def forward_train(self, feats, targets, aux_pred=None):
anchors, anchor_points, num_anchors_list, stride_tensor = \ anchors, anchor_points, num_anchors_list, stride_tensor = \
generate_anchors_for_grid_cell( generate_anchors_for_grid_cell(
feats, self.fpn_strides, self.grid_cell_scale, feats, self.fpn_strides, self.grid_cell_scale,
...@@ -108,9 +94,10 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -108,9 +94,10 @@ class PPYOLOEContrastHead(PPYOLOEHead):
cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) + cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
feat) feat)
reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat)) reg_distri = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
contrast_logit = self.contrast_encoder[i](self.stem_cls[i](feat, avg_feat) + contrast_logit = self.contrast_encoder[i](self.stem_cls[i](
feat) feat, avg_feat) + feat)
contrast_encoder_list.append(contrast_logit.flatten(2).transpose([0, 2, 1])) contrast_encoder_list.append(
contrast_logit.flatten(2).transpose([0, 2, 1]))
# cls and reg # cls and reg
cls_score = F.sigmoid(cls_logit) cls_score = F.sigmoid(cls_logit)
cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1])) cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
...@@ -120,8 +107,8 @@ class PPYOLOEContrastHead(PPYOLOEHead): ...@@ -120,8 +107,8 @@ class PPYOLOEContrastHead(PPYOLOEHead):
contrast_encoder_list = paddle.concat(contrast_encoder_list, axis=1) contrast_encoder_list = paddle.concat(contrast_encoder_list, axis=1)
return self.get_loss([ return self.get_loss([
cls_score_list, reg_distri_list, contrast_encoder_list, anchors, anchor_points, cls_score_list, reg_distri_list, contrast_encoder_list, anchors,
num_anchors_list, stride_tensor anchor_points, num_anchors_list, stride_tensor
], targets) ], targets)
def get_loss(self, head_outs, gt_meta): def get_loss(self, head_outs, gt_meta):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册