diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py index ed2e5273df38c6a4553db4a9342840da524921e4..0f740444123cbef87af0d2714d96d283123fa5b5 100644 --- a/python/paddle/fluid/tests/unittests/test_lerp_op.py +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -94,13 +94,11 @@ class TestLerpAPI(unittest.TestCase): with paddle.static.program_guard(paddle.static.Program()): x = paddle.fluid.data('x', [1, 4], dtype=self.dtype) y = paddle.fluid.data('y', [1, 4], dtype=self.dtype) - w = paddle.fluid.data('w', [1], dtype=self.dtype) - out = paddle.lerp(x, y, w) + out = paddle.lerp(x, y, 0.5) exe = paddle.static.Executor(place) res = exe.run(feed={ 'x': self.x.reshape([1, 4]), 'y': self.y.reshape([1, 4]), - 'w': self.w }) for r in res: self.assertEqual(np.allclose(self.res_ref, r), True) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 931e7a6787fff1f3905d326ff09fefe1b24cd665..a322c133436d19b570a1bcf3ae2478e27aa72048 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2828,6 +2828,9 @@ def lerp(x, y, weight, name=None): weight = paddle.to_tensor(weight, dtype=x.dtype) return _C_ops.lerp(x, y, weight) + if isinstance(weight, float): + weight = paddle.full(shape=[1], fill_value=weight, dtype=x.dtype) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp') check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp') check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'lerp')