提交 c6b552cf 编写于 作者: M Megvii Engine Team

refactor(mge/functional): remove dependence to trace in functional implementations

GitOrigin-RevId: 0b18479fccd551a9ab2902ae5f086176e6c58d0a
上级 46d96478
......@@ -16,7 +16,6 @@ from ..core.tensor import utils
from ..core.tensor.array_method import _elwise_apply
from ..core.tensor.utils import astype
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
from ..utils.deprecation import deprecated_func
......@@ -560,8 +559,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
), "At least one of 'lower' or 'upper' must not be None"
if lower is not None:
if upper is not None:
if not is_tracing():
assert lower <= upper, "clip lower bound is bigger that upper bound"
# FIXME: following assertion won't work during trace if upper and lower are Tensors
# assert lower <= upper, "clip lower bound is bigger that upper bound"
return minimum(maximum(x, lower), upper)
else:
return maximum(x, lower)
......
......@@ -12,7 +12,6 @@ from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.utils import astensor1d
from ..jit.tracing import is_tracing
from ..tensor import Tensor
from .elemwise import floor
from .math import argsort
......@@ -226,6 +225,10 @@ def nms(
otherwise it required to be specified; if it is not specified, all boxes are kept.
:return: indices of the elements that have been kept by NMS, sorted by scores.
.. note::
max_output should be specified and should have valid positive value under tracing
Examples:
.. testcode::
......@@ -263,11 +266,6 @@ def nms(
sorted_idx = argsort(scores, descending=True)
boxes = boxes[sorted_idx]
if is_tracing():
assert (
max_output is not None and max_output > 0
), "max_output should be specified under tracing"
if max_output is None:
max_output = boxes.shape[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册