From 5d749951dcf8a995dd39bc3140bd138740b5b22d Mon Sep 17 00:00:00 2001 From: Wenyu Date: Tue, 5 Jul 2022 10:59:02 +0800 Subject: [PATCH] add stage_loss_weights (#6350) * add stage_loss_weights --- ppdet/modeling/heads/cascade_head.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/ppdet/modeling/heads/cascade_head.py b/ppdet/modeling/heads/cascade_head.py index fab310ef1..c07c22734 100644 --- a/ppdet/modeling/heads/cascade_head.py +++ b/ppdet/modeling/heads/cascade_head.py @@ -161,7 +161,9 @@ class CascadeHead(BBoxHead): [30.0, 30.0, 15.0, 15.0]], num_cascade_stages=3, bbox_loss=None, - reg_class_agnostic=True): + reg_class_agnostic=True, + stage_loss_weights=None): + nn.Layer.__init__(self, ) self.head = head self.roi_extractor = roi_extractor @@ -173,6 +175,12 @@ class CascadeHead(BBoxHead): self.bbox_weight = bbox_weight self.num_cascade_stages = num_cascade_stages self.bbox_loss = bbox_loss + self.stage_loss_weights = [ + 1. / num_cascade_stages for _ in range(num_cascade_stages) + ] if stage_loss_weights is None else stage_loss_weights + assert len( + self.stage_loss_weights + ) == num_cascade_stages, f'stage_loss_weights({len(self.stage_loss_weights)}) do not equal to num_cascade_stages({num_cascade_stages})' self.reg_class_agnostic = reg_class_agnostic num_bbox_delta = 4 if reg_class_agnostic else 4 * num_classes @@ -249,7 +257,7 @@ class CascadeHead(BBoxHead): self.bbox_weight[stage]) for k, v in loss_stage.items(): loss[k + "_stage{}".format( - stage)] = v / self.num_cascade_stages + stage)] = v * self.stage_loss_weights[stage] return loss, bbox_feat else: -- GitLab