未验证 提交 cb89c8d0 编写于 作者: S shangliang Xu 提交者: GitHub

[dev] fix trt nms error output in ppyoloe (#6607)

上级 16a657e5
......@@ -481,8 +481,9 @@ class MultiClassNMS(object):
# TODO(wangxinxin08): tricky switch to run nms on tensorrt
kwargs.update({'nms_eta': 1.1})
bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs)
mask = paddle.slice(bbox, [-1], [0], [1]) != -1
bbox = paddle.masked_select(bbox, mask).reshape((-1, 6))
bbox = bbox.reshape([1, -1, 6])
idx = paddle.nonzero(bbox[..., 0] != -1)
bbox = paddle.gather_nd(bbox, idx)
return bbox, bbox_num, None
else:
return ops.multiclass_nms(bboxes, score, **kwargs)
......@@ -1353,7 +1354,7 @@ class ConvMixer(nn.Layer):
Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim))
Residual = type('Residual', (Seq, ),
{'forward': lambda self, x: self[0](x) + x})
return Seq(* [
return Seq(*[
Seq(Residual(
ActBn(
nn.Conv2D(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册