From 1716324c78e86ff22b588666460698d1845c05cb Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Wed, 8 Dec 2021 16:20:04 +0800 Subject: [PATCH] Add paddle.lerp API to do a linear interpolation (#37253) * save temp * add unittest, test=develop * fix ci error, test=develop * fix grad accuracy error, test=develop * fix unused error, test=develop * fix compilation error on Windows, test=develop * add unittest, test=develop * modify by review comment and add lerp_ * fix inplace api, test=develop * fix inplace api, test=develop * fix coverage error, test=develop --- paddle/fluid/operators/lerp_op.cc | 146 ++++++++++++ paddle/fluid/operators/lerp_op.cu | 27 +++ paddle/fluid/operators/lerp_op.h | 217 ++++++++++++++++++ python/paddle/__init__.py | 2 + .../fluid/tests/unittests/test_lerp_op.py | 171 ++++++++++++++ python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/math.py | 62 +++++ 7 files changed, 629 insertions(+) create mode 100644 paddle/fluid/operators/lerp_op.cc create mode 100644 paddle/fluid/operators/lerp_op.cu create mode 100644 paddle/fluid/operators/lerp_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_lerp_op.py diff --git a/paddle/fluid/operators/lerp_op.cc b/paddle/fluid/operators/lerp_op.cc new file mode 100644 index 0000000000..b94182e9db --- /dev/null +++ b/paddle/fluid/operators/lerp_op.cc @@ -0,0 +1,146 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/lerp_op.h" + +namespace paddle { +namespace operators { + +class LerpOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "lerp"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "lerp"); + OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "lerp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "lerp"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto w_dims = ctx->GetInputDim("Weight"); + framework::DDim out_dims; + out_dims = GetOutputDims(x_dims, y_dims); + if (w_dims.size() > 1 || w_dims[0] != 1) { + out_dims = GetOutputDims(out_dims, w_dims); + } + + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + private: + framework::DDim GetOutputDims(const framework::DDim& s_dims, + const framework::DDim& l_dims) const { + if (s_dims.size() > l_dims.size()) { + return GetOutputDims(l_dims, s_dims); + } + std::vector shapes = framework::vectorize(l_dims); + for (int i = s_dims.size() - 1, j = l_dims.size() - 1; i >= 0; --i, --j) { + int64_t s = s_dims[i]; + int64_t l = l_dims[j]; + if (s != l) { + if (l == 1) { + shapes[j] = s; + } else if (s != 1) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The shape of tensor a %s:%d must match shape of tensor b " + "%s:%d.", + s_dims.to_str(), i, l_dims.to_str(), j)); + } + } + } + return framework::make_ddim(shapes); + } +}; + +class LerpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of lerp op."); + AddInput("Y", "(Tensor), The input tensor of lerp op."); + AddInput("Weight", "(Tensor, optional), The input tensor of lerp op."); + AddOutput("Out", "(Tensor), The output tensor of lerp op."); + AddComment(R"DOC( +Lerp Operator. + +This operator is used to do a linear interpolation of input $X$ and $Y$ with $Weight$. + +The equation is: + +$$Out = X + Weight * (Y - X)$$ + +Both the input $X$ and $Y$ can carry the LoD (Level of Details) information, +or not. But the output only shares the LoD information with input $X$. + +)DOC"); + } +}; + +class LerpGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + if (ctx->HasOutput(framework::GradVarName("X"))) { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + if (ctx->HasOutput(framework::GradVarName("Y"))) { + ctx->SetOutputDim(framework::GradVarName("Y"), ctx->GetInputDim("Y")); + } + } +}; + +template +class LerpOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("lerp_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput("Weight", this->Input("Weight")); + op->SetInput("Out", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_INPLACE_OP_INFERER(LerpInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR( + lerp, paddle::operators::LerpOp, paddle::operators::LerpOpMaker, + paddle::operators::LerpOpGradMaker, + paddle::operators::LerpOpGradMaker, + paddle::operators::LerpInplaceInferer); + +REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp); + +REGISTER_OP_CPU_KERNEL( + lerp, + paddle::operators::LerpKernel, + paddle::operators::LerpKernel); + +REGISTER_OP_CPU_KERNEL( + lerp_grad, + paddle::operators::LerpGradKernel, + paddle::operators::LerpGradKernel); diff --git a/paddle/fluid/operators/lerp_op.cu b/paddle/fluid/operators/lerp_op.cu new file mode 100644 index 0000000000..6f7d8b744d --- /dev/null +++ b/paddle/fluid/operators/lerp_op.cu @@ -0,0 +1,27 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/lerp_op.h" + +REGISTER_OP_CUDA_KERNEL( + lerp, + paddle::operators::LerpKernel, + paddle::operators::LerpKernel); + +REGISTER_OP_CUDA_KERNEL( + lerp_grad, + paddle::operators::LerpGradKernel, + paddle::operators::LerpGradKernel); diff --git a/paddle/fluid/operators/lerp_op.h b/paddle/fluid/operators/lerp_op.h new file mode 100644 index 0000000000..380a8ccffd --- /dev/null +++ b/paddle/fluid/operators/lerp_op.h @@ -0,0 +1,217 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX // msvc max/min macro conflict with std::min/max +#endif +#endif + +namespace paddle { +namespace operators { + +static framework::DDim ExtendDims2Rank(const framework::DDim& in_dims, + int rank) { + if (in_dims.size() == rank) { + return in_dims; + } + std::vector shapes(rank, 1); + for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) { + shapes[j] = in_dims[i]; + } + return framework::make_ddim(shapes); +} + +template +static void GetBroadcastDims(const framework::DDim& in_dims, + const framework::DDim& out_dims, + Eigen::DSizes* bcast_dims) { + for (size_t i = 0; i < D; ++i) { + if (in_dims[i] == out_dims[i]) { + (*bcast_dims)[i] = 1; + } else { + (*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]); + } + } +} + +template +static void LerpFunction(const framework::ExecutionContext& ctx) { + auto x = ctx.Input("X"); + auto y = ctx.Input("Y"); + auto w = ctx.Input("Weight"); + auto out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto out_dims = out->dims(); + auto x_dims = ExtendDims2Rank(x->dims(), D); + auto y_dims = ExtendDims2Rank(y->dims(), D); + auto w_dims = ExtendDims2Rank(w->dims(), D); + Eigen::DSizes x_bcast_dims; + Eigen::DSizes y_bcast_dims; + Eigen::DSizes w_bcast_dims; + GetBroadcastDims(x_dims, out_dims, &x_bcast_dims); + GetBroadcastDims(y_dims, out_dims, &y_bcast_dims); + GetBroadcastDims(w_dims, out_dims, &w_bcast_dims); + + auto eigen_x = framework::EigenTensor::From(*x, x_dims); + auto eigen_y = framework::EigenTensor::From(*y, y_dims); + auto eigen_w = framework::EigenTensor::From(*w, w_dims); + auto eigen_out = framework::EigenTensor::From(*out); + + auto& place = *ctx.template device_context().eigen_device(); + eigen_out.device(place) = + eigen_x.broadcast(x_bcast_dims) + + eigen_w.broadcast(w_bcast_dims) * + (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims)); +} + +template +static void LerpGradFunction(const framework::ExecutionContext& ctx) { + auto w = ctx.Input("Weight"); + auto dout = ctx.Input(framework::GradVarName("Out")); + auto dx = ctx.Output(framework::GradVarName("X")); + auto dy = ctx.Output(framework::GradVarName("Y")); + + auto dout_dims = dout->dims(); + auto dx_dims = ExtendDims2Rank(dx->dims(), D); + auto dy_dims = ExtendDims2Rank(dy->dims(), D); + auto w_dims = ExtendDims2Rank(w->dims(), D); + Eigen::DSizes dx_bcast_dims; + Eigen::DSizes dy_bcast_dims; + Eigen::DSizes w_bcast_dims; + GetBroadcastDims(dx_dims, dout_dims, &dx_bcast_dims); + GetBroadcastDims(dy_dims, dout_dims, &dy_bcast_dims); + GetBroadcastDims(w_dims, dout_dims, &w_bcast_dims); + + auto eigen_w = framework::EigenTensor::From(*w, w_dims); + auto eigen_dout = framework::EigenTensor::From(*dout); + + Eigen::DSizes dx_reshape_dims; + Eigen::DSizes dy_reshape_dims; + Eigen::DSizes reduce_dims; + for (int i = 0; i < dout_dims.size(); ++i) { + dx_reshape_dims[2 * i] = dx_bcast_dims[i]; + dx_reshape_dims[2 * i + 1] = dx_dims[i]; + dy_reshape_dims[2 * i] = dy_bcast_dims[i]; + dy_reshape_dims[2 * i + 1] = dy_dims[i]; + reduce_dims[i] = 2 * i; + } + + auto& place = *ctx.template device_context().eigen_device(); + + if (dx) { + dx->mutable_data(ctx.GetPlace()); + auto eigen_dx = framework::EigenTensor::From(*dx, dx_dims); + auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout; + eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims) + .sum(reduce_dims) + .reshape(eigen_dx.dimensions()); + } + if (dy) { + dy->mutable_data(ctx.GetPlace()); + auto eigen_dy = framework::EigenTensor::From(*dy, dy_dims); + auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout; + eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims) + .sum(reduce_dims) + .reshape(eigen_dy.dimensions()); + } +} + +template +class LerpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int rank = ctx.Output("Out")->dims().size(); + PADDLE_ENFORCE_GE( + rank, 1, + platform::errors::InvalidArgument( + "The number of dimensions for LerpOp must be " + "greater than or equal to 1, but the value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, 6, platform::errors::InvalidArgument( + "The number of dimensions for LerpOp must be " + "less than or equal to 6, but the value received is %d.", + rank)); + switch (rank) { + case 1: + LerpFunction(ctx); + break; + case 2: + LerpFunction(ctx); + break; + case 3: + LerpFunction(ctx); + break; + case 4: + LerpFunction(ctx); + break; + case 5: + LerpFunction(ctx); + break; + case 6: + LerpFunction(ctx); + break; + } + } +}; + +template +class LerpGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int rank = ctx.Input(framework::GradVarName("Out")) + ->dims() + .size(); + PADDLE_ENFORCE_GE( + rank, 1, + platform::errors::InvalidArgument( + "The number of dimensions for LerpGradOp must be " + "greater than or equal to 1, but the value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, 6, platform::errors::InvalidArgument( + "The number of dimensions for LerpGradOp must be " + "less than or equal to 6, but the value received is %d.", + rank)); + switch (rank) { + case 1: + LerpGradFunction(ctx); + break; + case 2: + LerpGradFunction(ctx); + break; + case 3: + LerpGradFunction(ctx); + break; + case 4: + LerpGradFunction(ctx); + break; + case 5: + LerpGradFunction(ctx); + break; + case 6: + LerpGradFunction(ctx); + break; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index a70bd3f81b..44afeecec3 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -224,6 +224,7 @@ from .tensor.math import trunc # noqa: F401 from .tensor.math import digamma # noqa: F401 from .tensor.math import neg # noqa: F401 from .tensor.math import lgamma # noqa: F401 +from .tensor.math import lerp # noqa: F401 from .tensor.math import rad2deg # noqa: F401 from .tensor.math import deg2rad # noqa: F401 from .tensor.math import diff # noqa: F401 @@ -469,6 +470,7 @@ __all__ = [ # noqa 'conj', 'neg', 'lgamma', + 'lerp', 'square', 'divide', 'ceil', diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py new file mode 100644 index 0000000000..ed2e5273df --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py @@ -0,0 +1,171 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid.core as core + +paddle.enable_static() +np.random.seed(0) + + +class TestLerp(OpTest): + def setUp(self): + self.op_type = "lerp" + self.init_dtype() + self.init_shape() + x = np.arange(1., 101.).astype(self.dtype).reshape(self.shape) + y = np.full(100, 10.).astype(self.dtype).reshape(self.shape) + w = np.asarray([0.5]).astype(self.dtype) + self.inputs = {'X': x, 'Y': y, 'Weight': w} + self.outputs = {'Out': x + w * (y - x)} + + def init_dtype(self): + self.dtype = np.float64 + + def init_shape(self): + self.shape = [100] + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'Y'], 'Out') + + +class TestLerpWithDim2(TestLerp): + def init_shape(self): + self.shape = [2, 50] + + +class TestLerpWithDim3(TestLerp): + def init_shape(self): + self.shape = [2, 2, 25] + + +class TestLerpWithDim4(TestLerp): + def init_shape(self): + self.shape = [2, 2, 5, 5] + + +class TestLerpWithDim5(TestLerp): + def init_shape(self): + self.shape = [2, 1, 2, 5, 5] + + +class TestLerpWithDim6(TestLerp): + def init_shape(self): + self.shape = [2, 1, 2, 5, 1, 5] + + +class TestLerpAPI(unittest.TestCase): + def init_dtype(self): + self.dtype = 'float32' + + def setUp(self): + self.init_dtype() + self.x = np.arange(1., 5.).astype(self.dtype) + self.y = np.full(4, 10.).astype(self.dtype) + self.w = np.asarray([0.75]).astype(self.dtype) + self.res_ref = self.x + self.w * (self.y - self.x) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('x', [1, 4], dtype=self.dtype) + y = paddle.fluid.data('y', [1, 4], dtype=self.dtype) + w = paddle.fluid.data('w', [1], dtype=self.dtype) + out = paddle.lerp(x, y, w) + exe = paddle.static.Executor(place) + res = exe.run(feed={ + 'x': self.x.reshape([1, 4]), + 'y': self.y.reshape([1, 4]), + 'w': self.w + }) + for r in res: + self.assertEqual(np.allclose(self.res_ref, r), True) + + for place in self.place: + run(place) + + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + w = paddle.to_tensor(np.full(4, 0.75).astype(self.dtype)) + out = paddle.lerp(x, y, w) + self.assertEqual(np.allclose(self.res_ref, out.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_inplace_api(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + x.lerp_(y, 0.75) + self.assertEqual(np.allclose(self.res_ref, x.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_inplace_api_exception(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + y = paddle.to_tensor(self.y) + w = paddle.to_tensor([0.75, 0.75], dtype=self.dtype) + with self.assertRaises(ValueError): + x.lerp_(y, w) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_x_broadcast_y(self): + paddle.disable_static() + x = np.arange(1., 21.).astype(self.dtype).reshape([2, 2, 5]) + y = np.full(30, 10.).astype(self.dtype).reshape([3, 2, 1, 5]) + out = paddle.lerp(paddle.to_tensor(x), paddle.to_tensor(y), 0.5) + res_ref = x + 0.5 * (y - x) + self.assertEqual(np.allclose(res_ref, out.numpy()), True) + paddle.enable_static() + + def test_x_y_broadcast_w(self): + paddle.disable_static() + x = np.arange(11., 21.).astype(self.dtype).reshape([2, 5]) + y = np.full(20, 7.5).astype(self.dtype).reshape([2, 2, 5]) + w = np.full(40, 0.225).astype(self.dtype).reshape([2, 2, 2, 5]) + out = paddle.lerp( + paddle.to_tensor(x), paddle.to_tensor(y), paddle.to_tensor(w)) + res_ref = x + w * (y - x) + self.assertEqual(np.allclose(res_ref, out.numpy()), True) + paddle.enable_static() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 793fdb89d0..53001c0715 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -190,6 +190,8 @@ from .math import digamma # noqa: F401 from .math import neg # noqa: F401 from .math import lgamma # noqa: F401 from .math import diagonal # noqa: F401 +from .math import lerp # noqa: F401 +from .math import lerp_ # noqa: F401 from .math import rad2deg # noqa: F401 from .math import deg2rad # noqa: F401 from .math import diff # noqa: F401 @@ -408,6 +410,8 @@ tensor_method_func = [ #noqa 'solve', 'triangular_solve', 'diff', + 'lerp', + 'lerp_', 'angle', ] diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index df0116c4c2..f705510f84 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2614,6 +2614,68 @@ def atan2(x, y, name=None): type='atan2', inputs=inputs, outputs={'Out': out}) return out +def lerp(x, y, weight, name=None): + r""" + Does a linear interpolation between x and y based on weight. + + Equation: + .. math:: + + lerp(x, y, weight) = x + weight * (y - x). + + Args: + x (Tensor): An N-D Tensor, the data type is float32, float64. + y (Tensor): An N-D Tensor, the data type is float32, float64. + weight (float|Tensor): the weight for the interpolation formula. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor, the shape and data type is the same with input. + + Example: + .. code-block:: python + + import paddle + + x = paddle.arange(1., 5., dtype='float32') + y = paddle.empty([4], dtype='float32') + y.fill_(10.) + out = paddle.lerp(start, end, 0.5) + # out: [5.5., 6., 6.5, 7.] + + """ + if in_dygraph_mode(): + check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp') + if isinstance(weight, float): + weight = paddle.to_tensor(weight, dtype=x.dtype) + return _C_ops.lerp(x, y, weight) + + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp') + check_variable_and_dtype(weight, 'weight', ['float32', 'float64'], 'lerp') + + helper = LayerHelper('lerp', **locals()) + inputs = {'X': x, 'Y': y, 'Weight': weight} + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type='lerp', inputs=inputs, outputs={'Out': out}) + return out + +@inplace_apis_in_dygraph_only +def lerp_(x, y, weight, name=None): + r""" + Inplace version of ``lerp`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_tensor_lerp`. + """ + out_shape = broadcast_shape(x.shape, y.shape) + check_type(weight, 'weight', (float, paddle.Tensor, Variable), 'lerp') + if isinstance(weight, float): + weight = paddle.to_tensor([weight], dtype=x.dtype) + elif isinstance(weight, (paddle.Tensor, Variable)): + out_shape = broadcast_shape(out_shape, weight.shape) + if out_shape != x.shape: + raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape)) + return _C_ops.lerp_(x, y, weight) + def rad2deg(x, name=None): """ Convert each of the elements of input x from angles in radians to degrees. -- GitLab