From 9a1ea9b45df4e8d90e71d3c7d50309aef8b4d801 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Mon, 31 Aug 2020 11:14:22 +0800 Subject: [PATCH] Add support for tensor min/max in dygraph (#26764) --- python/paddle/fluid/tests/unittests/test_clip_op.py | 4 ++++ python/paddle/tensor/math.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_clip_op.py b/python/paddle/fluid/tests/unittests/test_clip_op.py index 74c01e1424..2e1f9d4174 100644 --- a/python/paddle/fluid/tests/unittests/test_clip_op.py +++ b/python/paddle/fluid/tests/unittests/test_clip_op.py @@ -166,12 +166,16 @@ class TestClipAPI(unittest.TestCase): data_shape = [1, 9, 9, 4] data = np.random.random(data_shape).astype('float32') images = paddle.to_variable(data, dtype='float32') + v_min = paddle.to_variable(np.array([0.2], dtype=np.float32)) + v_max = paddle.to_variable(np.array([0.8], dtype=np.float32)) out_1 = paddle.clip(images, min=0.2, max=0.8) out_2 = paddle.clip(images, min=0.2, max=0.9) + out_3 = paddle.clip(images, min=v_min, max=v_max) self.assertTrue(np.allclose(out_1.numpy(), data.clip(0.2, 0.8))) self.assertTrue(np.allclose(out_2.numpy(), data.clip(0.2, 0.9))) + self.assertTrue(np.allclose(out_3.numpy(), data.clip(0.2, 0.8))) def test_errors(self): paddle.enable_static() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 0d87c1c2cf..d2db2a7cb7 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -1618,6 +1618,10 @@ def clip(x, min=None, max=None, name=None): fmax = float(np.finfo(np_dtype).max) if in_dygraph_mode(): + if isinstance(min, Variable): + min = min.numpy().item(0) + if isinstance(max, Variable): + max = max.numpy().item(0) min = fmin if min is None else min max = fmax if max is None else max return core.ops.clip(x, "min", min, "max", max) -- GitLab