From 207527d10872bbd87e973efed21ea4b9113913df Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 17 Apr 2020 15:26:49 +0800 Subject: [PATCH] fix(mge/clamp): fix `F.clamp` GitOrigin-RevId: 1efac8add61819630895cb29ca4d607c94b040d1 --- python_module/megengine/functional/elemwise.py | 8 +++++--- python_module/test/unit/functional/test_elemwise.py | 9 +++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/python_module/megengine/functional/elemwise.py b/python_module/megengine/functional/elemwise.py index 3b35fb815..6bed2d3d8 100644 --- a/python_module/megengine/functional/elemwise.py +++ b/python_module/megengine/functional/elemwise.py @@ -233,9 +233,11 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor: [0 1 2 3 3] """ - assert lower or upper, "At least one of 'lower' or 'upper' must not be None" - if lower: - if upper: + assert ( + lower is not None or upper is not None + ), "At least one of 'lower' or 'upper' must not be None" + if lower is not None: + if upper is not None: assert lower <= upper, "clamp lower bound is bigger that upper bound" return minimum(maximum(inp, lower), upper) else: diff --git a/python_module/test/unit/functional/test_elemwise.py b/python_module/test/unit/functional/test_elemwise.py index 9ed9e42b0..ef9cf6fad 100644 --- a/python_module/test/unit/functional/test_elemwise.py +++ b/python_module/test/unit/functional/test_elemwise.py @@ -44,3 +44,12 @@ def test_multiply(): np.array([3.0, 4.0], dtype=np.float32), ), ) + + +def test_clamp(): + """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and + `F.clamp` will fall into wrong conditions unexpectedly. + """ + x = np.linspace(-6, 6, dtype="float32") + assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6)) + assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0)) -- GitLab