提交 207527d1 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/clamp): fix `F.clamp`

GitOrigin-RevId: 1efac8add61819630895cb29ca4d607c94b040d1
上级 8ac73333
......@@ -233,9 +233,11 @@ def clamp(inp: Tensor, lower=None, upper=None) -> Tensor:
[0 1 2 3 3]
"""
assert lower or upper, "At least one of 'lower' or 'upper' must not be None"
if lower:
if upper:
assert (
lower is not None or upper is not None
), "At least one of 'lower' or 'upper' must not be None"
if lower is not None:
if upper is not None:
assert lower <= upper, "clamp lower bound is bigger that upper bound"
return minimum(maximum(inp, lower), upper)
else:
......
......@@ -44,3 +44,12 @@ def test_multiply():
np.array([3.0, 4.0], dtype=np.float32),
),
)
def test_clamp():
"""Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and
`F.clamp` will fall into wrong conditions unexpectedly.
"""
x = np.linspace(-6, 6, dtype="float32")
assertTensorClose(F.clamp(tensor(x) + 3, 0, 6).numpy(), np.clip(x + 3, 0, 6))
assertTensorClose(F.clamp(tensor(x) - 3, -6, 0).numpy(), np.clip(x - 3, -6, 0))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册