From e34ef6b61168f1feb9e12305c544d73b04f585cb Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Fri, 19 Nov 2021 10:12:16 +0800 Subject: [PATCH] fix centernet_head size_target (#4626) --- ppdet/modeling/heads/centernet_head.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ppdet/modeling/heads/centernet_head.py b/ppdet/modeling/heads/centernet_head.py index 349661aff..ce8b5c15d 100755 --- a/ppdet/modeling/heads/centernet_head.py +++ b/ppdet/modeling/heads/centernet_head.py @@ -201,9 +201,14 @@ class CenterNetHead(nn.Layer): size_target = inputs['size'] # shape: [bs, max_per_img, 4] else: - size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :, - 2:] - # shape: [bs, max_per_img, 2] + if inputs['size'].shape[-1] == 2: + # inputs['size'] is wh, and regress as wh + # shape: [bs, max_per_img, 2] + size_target = inputs['size'] + else: + # inputs['size'] is ltrb, but regress as wh + # shape: [bs, max_per_img, 4] + size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :, 2:] size_target.stop_gradient = True size_loss = F.l1_loss( -- GitLab