diff --git a/ppdet/modeling/heads/centernet_head.py b/ppdet/modeling/heads/centernet_head.py index 349661affa938aad10d9a157f4621e56dfb4c472..ce8b5c15ddd92c4da0aa217c98e7388cf9b6a3b5 100755 --- a/ppdet/modeling/heads/centernet_head.py +++ b/ppdet/modeling/heads/centernet_head.py @@ -201,9 +201,14 @@ class CenterNetHead(nn.Layer): size_target = inputs['size'] # shape: [bs, max_per_img, 4] else: - size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :, - 2:] - # shape: [bs, max_per_img, 2] + if inputs['size'].shape[-1] == 2: + # inputs['size'] is wh, and regress as wh + # shape: [bs, max_per_img, 2] + size_target = inputs['size'] + else: + # inputs['size'] is ltrb, but regress as wh + # shape: [bs, max_per_img, 4] + size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :, 2:] size_target.stop_gradient = True size_loss = F.l1_loss(