未验证 提交 49e1e963 编写于 作者: F Feng Ni 提交者: GitHub

fix centernet_head size_target (#4627)

上级 242c9803
...@@ -201,9 +201,14 @@ class CenterNetHead(nn.Layer): ...@@ -201,9 +201,14 @@ class CenterNetHead(nn.Layer):
size_target = inputs['size'] size_target = inputs['size']
# shape: [bs, max_per_img, 4] # shape: [bs, max_per_img, 4]
else: else:
size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :, if inputs['size'].shape[-1] == 2:
2:] # inputs['size'] is wh, and regress as wh
# shape: [bs, max_per_img, 2] # shape: [bs, max_per_img, 2]
size_target = inputs['size']
else:
# inputs['size'] is ltrb, but regress as wh
# shape: [bs, max_per_img, 4]
size_target = inputs['size'][:, :, 0:2] + inputs['size'][:, :, 2:]
size_target.stop_gradient = True size_target.stop_gradient = True
size_loss = F.l1_loss( size_loss = F.l1_loss(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册