未验证 提交 9a1ea9b4 编写于 作者: Y Yang Zhang 提交者: GitHub

Add support for tensor min/max in dygraph (#26764)

上级 4a57880d
......@@ -166,12 +166,16 @@ class TestClipAPI(unittest.TestCase):
data_shape = [1, 9, 9, 4]
data = np.random.random(data_shape).astype('float32')
images = paddle.to_variable(data, dtype='float32')
v_min = paddle.to_variable(np.array([0.2], dtype=np.float32))
v_max = paddle.to_variable(np.array([0.8], dtype=np.float32))
out_1 = paddle.clip(images, min=0.2, max=0.8)
out_2 = paddle.clip(images, min=0.2, max=0.9)
out_3 = paddle.clip(images, min=v_min, max=v_max)
self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8)))
self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9)))
self.assertTrue(np.allclose(out_3.numpy(), data.clip(0.2, 0.8)))
def test_errors(self):
paddle.enable_static()
......
......@@ -1618,6 +1618,10 @@ def clip(x, min=None, max=None, name=None):
fmax = float(np.finfo(np_dtype).max)
if in_dygraph_mode():
if isinstance(min, Variable):
min = min.numpy().item(0)
if isinstance(max, Variable):
max = max.numpy().item(0)
min = fmin if min is None else min
max = fmax if max is None else max
return core.ops.clip(x, "min", min, "max", max)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册