From 58b4bc72f47a12929c3b3b29bd9bd1fc883f8fa5 Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Thu, 16 Dec 2021 13:59:11 +0800 Subject: [PATCH] Arg weight of lerp support float in static mode, test=develop (#38080) --- python/paddle/fluid/tests/unittests/test_lerp_op.py | 4 +--- python/paddle/tensor/math.py | 3 +++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index ed2e5273df3..0f740444123 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -94,13 +94,11 @@ class TestLerpAPI(unittest.TestCase): with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('x', [1, 4], dtype=self.dtype) y = paddle.fluid.data('y', [1, 4], dtype=self.dtype) - w = paddle.fluid.data('w', [1], dtype=self.dtype) - out = paddle.lerp(x, y, w) + out = paddle.lerp(x, y, 0.5) exe = paddle.static.Executor(place) res = exe.run(feed={ 'x': self.x.reshape([1, 4]), 'y': self.y.reshape([1, 4]), - 'w': self.w }) for r in res: self.assertEqual(np.allclose(self.res_ref, r), True) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 931e7a6787f..a322c133436 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2828,6 +2828,9 @@ def lerp(x, y, weight, name=None): weight = paddle.to_tensor(weight, dtype=x.dtype) return _C_ops.lerp(x, y, weight) + if isinstance(weight, float): + weight = paddle.full(shape=[1], fill_value=weight, dtype=x.dtype) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp') check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp') check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'lerp') -- GitLab