未验证 提交 cb89c8d0 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] fix trt nms error output in ppyoloe (#6607)

上级 16a657e5
...@@ -481,8 +481,9 @@ class MultiClassNMS(object): ...@@ -481,8 +481,9 @@ class MultiClassNMS(object):
# TODO(wangxinxin08): tricky switch to run nms on tensorrt # TODO(wangxinxin08): tricky switch to run nms on tensorrt
kwargs.update({'nms_eta': 1.1}) kwargs.update({'nms_eta': 1.1})
bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs) bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs)
mask = paddle.slice(bbox, [-1], [0], [1]) != -1 bbox = bbox.reshape([1, -1, 6])
bbox = paddle.masked_select(bbox, mask).reshape((-1, 6)) idx = paddle.nonzero(bbox[..., 0] != -1)
bbox = paddle.gather_nd(bbox, idx)
return bbox, bbox_num, None return bbox, bbox_num, None
else: else:
return ops.multiclass_nms(bboxes, score, **kwargs) return ops.multiclass_nms(bboxes, score, **kwargs)
...@@ -1353,7 +1354,7 @@ class ConvMixer(nn.Layer): ...@@ -1353,7 +1354,7 @@ class ConvMixer(nn.Layer):
Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim)) Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim))
Residual = type('Residual', (Seq, ), Residual = type('Residual', (Seq, ),
{'forward': lambda self, x: self[0](x) + x}) {'forward': lambda self, x: self[0](x) + x})
return Seq(* [ return Seq(*[
Seq(Residual( Seq(Residual(
ActBn( ActBn(
nn.Conv2D( nn.Conv2D(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册