提交 37e56f4b 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix clip under trace(symbolic=True)

GitOrigin-RevId: 5b6f5373270bf4699574feacc4b391b08ecdf6e9
上级 8230ea2d
......@@ -13,6 +13,7 @@ from ..core.ops import builtin
from ..core.tensor import megbrain_graph, utils
from ..core.tensor.core import apply
from ..device import get_default_device
from ..jit.tracing import is_tracing
from ..tensor import Tensor
__all__ = [
......@@ -580,7 +581,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor:
), "At least one of 'lower' or 'upper' must not be None"
if lower is not None:
if upper is not None:
assert lower <= upper, "clip lower bound is bigger that upper bound"
if not is_tracing():
assert lower <= upper, "clip lower bound is bigger that upper bound"
return minimum(maximum(x, lower), upper)
else:
return maximum(x, lower)
......
......@@ -394,3 +394,15 @@ def test_trace_valid_broadcast():
f(x1, shape)
f(x2, shape)
def test_clip():
x = tensor(np.random.randn(10, 10))
@trace(symbolic=True)
def f(x, lower, upper):
y = F.clip(x, lower, upper)
return y
for i in range(3):
f(x, tensor([0]), tensor([1]))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册