diff --git a/ppdet/modeling/heads/cascade_head.py b/ppdet/modeling/heads/cascade_head.py index fab310ef11931ff3a8ce4eb6014428b4998b339b..c07c22734d5e1d3229648658884b7ec7ec0741ee 100644 --- a/ppdet/modeling/heads/cascade_head.py +++ b/ppdet/modeling/heads/cascade_head.py @@ -161,7 +161,9 @@ class CascadeHead(BBoxHead): [30.0, 30.0, 15.0, 15.0]], num_cascade_stages=3, bbox_loss=None, - reg_class_agnostic=True): + reg_class_agnostic=True, + stage_loss_weights=None): + nn.Layer.__init__(self, ) self.head = head self.roi_extractor = roi_extractor @@ -173,6 +175,12 @@ class CascadeHead(BBoxHead): self.bbox_weight = bbox_weight self.num_cascade_stages = num_cascade_stages self.bbox_loss = bbox_loss + self.stage_loss_weights = [ + 1. / num_cascade_stages for _ in range(num_cascade_stages) + ] if stage_loss_weights is None else stage_loss_weights + assert len( + self.stage_loss_weights + ) == num_cascade_stages, f'stage_loss_weights({len(self.stage_loss_weights)}) do not equal to num_cascade_stages({num_cascade_stages})' self.reg_class_agnostic = reg_class_agnostic num_bbox_delta = 4 if reg_class_agnostic else 4 * num_classes @@ -249,7 +257,7 @@ class CascadeHead(BBoxHead): self.bbox_weight[stage]) for k, v in loss_stage.items(): loss[k + "_stage{}".format( - stage)] = v / self.num_cascade_stages + stage)] = v * self.stage_loss_weights[stage] return loss, bbox_feat else: