提交 af349d61 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix op mismatch when tracing NMSKeep

GitOrigin-RevId: e8f2cbb7557b7482df936faca80f4fcc15eef22b
上级 d502e79f
...@@ -17,6 +17,7 @@ from ..core.tensor import megbrain_graph, utils ...@@ -17,6 +17,7 @@ from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import TensorBase, TensorWrapperBase, apply from ..core.tensor.core import TensorBase, TensorWrapperBase, apply
from ..core.tensor.utils import astensor1d from ..core.tensor.utils import astensor1d
from ..distributed import WORLD, is_distributed from ..distributed import WORLD, is_distributed
from ..jit.tracing import is_tracing
from ..random import uniform from ..random import uniform
from ..tensor import Tensor from ..tensor import Tensor
from .debug_param import get_conv_execution_strategy from .debug_param import get_conv_execution_strategy
...@@ -1470,13 +1471,17 @@ def indexing_one_hot( ...@@ -1470,13 +1471,17 @@ def indexing_one_hot(
return result return result
def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: def nms(
boxes: Tensor, scores: Tensor, iou_thresh: float, max_output: Optional[int] = None
) -> Tensor:
r""" r"""
Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU). Performs non-maximum suppression (NMS) on the boxes according to their intersection-over-union(IoU).
:param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format. :param boxes: tensor of shape `(N, 4)`; the boxes to perform nms on; each box is expected to be in `(x1, y1, x2, y2)` format.
:param iou_thresh: IoU threshold for overlapping. :param iou_thresh: IoU threshold for overlapping.
:param scores: tensor of shape `(N,)`, the score of boxes. :param scores: tensor of shape `(N,)`, the score of boxes.
:param max_output: the maximum number of boxes to keep; it is optional if this operator is not traced
otherwise it required to be specified; if it is not specified, all boxes are kept.
:return: indices of the elements that have been kept by NMS. :return: indices of the elements that have been kept by NMS.
Examples: Examples:
...@@ -1515,12 +1520,19 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor: ...@@ -1515,12 +1520,19 @@ def nms(boxes: Tensor, scores: Tensor, iou_thresh: float) -> Tensor:
scores = scores.detach() scores = scores.detach()
sorted_idx = argsort(scores, descending=True) sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx] boxes = boxes[sorted_idx]
if is_tracing():
assert (
max_output is not None and max_output > 0
), "max_output should be specified under tracing"
if max_output is None:
max_output = boxes.shape[0] max_output = boxes.shape[0]
op = builtin.NMSKeep(iou_thresh, max_output) op = builtin.NMSKeep(iou_thresh, max_output)
inp = utils.convert_inputs(boxes.reshape(1, -1, 4)) inp = utils.convert_inputs(boxes.reshape(1, -1, 4))
indices, count = apply(op, *inp) indices, count = apply(op, *inp)
indices = indices[0][: count.item()] indices = indices[0][: count[0]]
keep_inds = sorted_idx[indices] keep_inds = sorted_idx[indices]
return keep_inds return keep_inds
......
...@@ -36,6 +36,13 @@ active_trace = None ...@@ -36,6 +36,13 @@ active_trace = None
skip_tracing = False skip_tracing = False
def is_tracing():
if active_trace is None:
return False
else:
return not skip_tracing
@contextlib.contextmanager @contextlib.contextmanager
def exclude_from_trace(): def exclude_from_trace():
global skip_tracing global skip_tracing
......
...@@ -357,3 +357,25 @@ def test_trace_broadcast(): ...@@ -357,3 +357,25 @@ def test_trace_broadcast():
f(x1) f(x1)
f(x2) f(x2)
f(x3) f(x3)
def test_trace_nms():
def make_inputs(n):
boxes = np.zeros((n, 4))
boxes[:, :2] = np.random.rand(n, 2) * 100
boxes[:, 2:] = np.random.rand(n, 2) * 100 + 100
scores = np.random.rand(n)
return tensor(boxes), tensor(scores)
@trace(symbolic=False)
def f(boxes, scores):
results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20)
with exclude_from_trace():
_ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5)
return results
f(*make_inputs(10))
f(*make_inputs(20))
f(*make_inputs(30))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册