diff --git a/imperative/python/megengine/functional/elemwise.py b/imperative/python/megengine/functional/elemwise.py index 3b71291c83d77bec501797cbcb16e04a640a8b20..686ddf4cbb29f7d220b65985122d5a8d04d85100 100644 --- a/imperative/python/megengine/functional/elemwise.py +++ b/imperative/python/megengine/functional/elemwise.py @@ -13,6 +13,7 @@ from ..core.ops import builtin from ..core.tensor import megbrain_graph, utils from ..core.tensor.core import apply from ..device import get_default_device +from ..jit.tracing import is_tracing from ..tensor import Tensor __all__ = [ @@ -580,7 +581,8 @@ def clip(x: Tensor, lower=None, upper=None) -> Tensor: ), "At least one of 'lower' or 'upper' must not be None" if lower is not None: if upper is not None: - assert lower <= upper, "clip lower bound is bigger that upper bound" + if not is_tracing(): + assert lower <= upper, "clip lower bound is bigger that upper bound" return minimum(maximum(x, lower), upper) else: return maximum(x, lower) diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py index bca796a320f1ec6c34fdb3a095dad1b479a33699..abc2463f6b43960758c7941149c06550c8f72d25 100644 --- a/imperative/python/test/unit/test_tracing.py +++ b/imperative/python/test/unit/test_tracing.py @@ -394,3 +394,15 @@ def test_trace_valid_broadcast(): f(x1, shape) f(x2, shape) + + +def test_clip(): + x = tensor(np.random.randn(10, 10)) + + @trace(symbolic=True) + def f(x, lower, upper): + y = F.clip(x, lower, upper) + return y + + for i in range(3): + f(x, tensor([0]), tensor([1]))