未验证 提交 5d749951 编写于 作者: W Wenyu 提交者: GitHub

add stage_loss_weights (#6350)

* add stage_loss_weights
上级 dcadfc3e
......@@ -161,7 +161,9 @@ class CascadeHead(BBoxHead):
[30.0, 30.0, 15.0, 15.0]],
num_cascade_stages=3,
bbox_loss=None,
reg_class_agnostic=True):
reg_class_agnostic=True,
stage_loss_weights=None):
nn.Layer.__init__(self, )
self.head = head
self.roi_extractor = roi_extractor
......@@ -173,6 +175,12 @@ class CascadeHead(BBoxHead):
self.bbox_weight = bbox_weight
self.num_cascade_stages = num_cascade_stages
self.bbox_loss = bbox_loss
self.stage_loss_weights = [
1. / num_cascade_stages for _ in range(num_cascade_stages)
] if stage_loss_weights is None else stage_loss_weights
assert len(
self.stage_loss_weights
) == num_cascade_stages, f'stage_loss_weights({len(self.stage_loss_weights)}) do not equal to num_cascade_stages({num_cascade_stages})'
self.reg_class_agnostic = reg_class_agnostic
num_bbox_delta = 4 if reg_class_agnostic else 4 * num_classes
......@@ -249,7 +257,7 @@ class CascadeHead(BBoxHead):
self.bbox_weight[stage])
for k, v in loss_stage.items():
loss[k + "_stage{}".format(
stage)] = v / self.num_cascade_stages
stage)] = v * self.stage_loss_weights[stage]
return loss, bbox_feat
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册