未验证 提交 58b4bc72 编写于 作者: W wuhuanzhou 提交者: GitHub

Arg weight of lerp support float in static mode, test=develop (#38080)

上级 8305c2be
...@@ -94,13 +94,11 @@ class TestLerpAPI(unittest.TestCase): ...@@ -94,13 +94,11 @@ class TestLerpAPI(unittest.TestCase):
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('x', [1, 4], dtype=self.dtype) x = paddle.fluid.data('x', [1, 4], dtype=self.dtype)
y = paddle.fluid.data('y', [1, 4], dtype=self.dtype) y = paddle.fluid.data('y', [1, 4], dtype=self.dtype)
w = paddle.fluid.data('w', [1], dtype=self.dtype) out = paddle.lerp(x, y, 0.5)
out = paddle.lerp(x, y, w)
exe = paddle.static.Executor(place) exe = paddle.static.Executor(place)
res = exe.run(feed={ res = exe.run(feed={
'x': self.x.reshape([1, 4]), 'x': self.x.reshape([1, 4]),
'y': self.y.reshape([1, 4]), 'y': self.y.reshape([1, 4]),
'w': self.w
}) })
for r in res: for r in res:
self.assertEqual(np.allclose(self.res_ref, r), True) self.assertEqual(np.allclose(self.res_ref, r), True)
......
...@@ -2828,6 +2828,9 @@ def lerp(x, y, weight, name=None): ...@@ -2828,6 +2828,9 @@ def lerp(x, y, weight, name=None):
weight = paddle.to_tensor(weight, dtype=x.dtype) weight = paddle.to_tensor(weight, dtype=x.dtype)
return _C_ops.lerp(x, y, weight) return _C_ops.lerp(x, y, weight)
if isinstance(weight, float):
weight = paddle.full(shape=[1], fill_value=weight, dtype=x.dtype)
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp') check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp')
check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp') check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp')
check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'lerp') check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'lerp')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册