From d480d7b12018871422accff1b42e173ccde8581f Mon Sep 17 00:00:00 2001
From: 0x45f <23097963+0x45f@users.noreply.github.com>
Date: Wed, 16 Feb 2022 14:39:20 +0800
Subject: [PATCH] Move lerp OP to pten (#39524)

* move lerp to pten

* refine include

* move files

* refine code
---
 paddle/fluid/operators/lerp_op.cc             |  14 +-
 paddle/fluid/operators/lerp_op.h              | 217 ------------------
 paddle/pten/kernels/cpu/lerp_grad_kernel.cc   |  21 ++
 paddle/pten/kernels/cpu/lerp_kernel.cc        |  20 ++
 paddle/pten/kernels/funcs/common_shape.h      |  25 ++
 .../kernels/gpu/lerp_grad_kernel.cu}          |  18 +-
 paddle/pten/kernels/gpu/lerp_kernel.cu        |  20 ++
 .../pten/kernels/impl/lerp_grad_kernel_impl.h | 133 +++++++++++
 paddle/pten/kernels/impl/lerp_kernel_impl.h   |  97 ++++++++
 paddle/pten/kernels/lerp_grad_kernel.h        |  31 +++
 paddle/pten/kernels/lerp_kernel.h             |  28 +++
 paddle/pten/ops/compat/lerp_sig.cc            |  33 +++
 12 files changed, 415 insertions(+), 242 deletions(-)
 delete mode 100644 paddle/fluid/operators/lerp_op.h
 create mode 100644 paddle/pten/kernels/cpu/lerp_grad_kernel.cc
 create mode 100644 paddle/pten/kernels/cpu/lerp_kernel.cc
 rename paddle/{fluid/operators/lerp_op.cu => pten/kernels/gpu/lerp_grad_kernel.cu} (54%)
 create mode 100644 paddle/pten/kernels/gpu/lerp_kernel.cu
 create mode 100644 paddle/pten/kernels/impl/lerp_grad_kernel_impl.h
 create mode 100644 paddle/pten/kernels/impl/lerp_kernel_impl.h
 create mode 100644 paddle/pten/kernels/lerp_grad_kernel.h
 create mode 100644 paddle/pten/kernels/lerp_kernel.h
 create mode 100644 paddle/pten/ops/compat/lerp_sig.cc

diff --git a/paddle/fluid/operators/lerp_op.cc b/paddle/fluid/operators/lerp_op.cc
index b94182e9db..b5e2b0d776 100644
--- a/paddle/fluid/operators/lerp_op.cc
+++ b/paddle/fluid/operators/lerp_op.cc
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/lerp_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -132,15 +132,3 @@ REGISTER_OPERATOR(
     paddle::operators::LerpInplaceInferer);
 
 REGISTER_OPERATOR(lerp_grad, paddle::operators::LerpGradOp);
-
-REGISTER_OP_CPU_KERNEL(
-    lerp,
-    paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, float>,
-    paddle::operators::LerpKernel<paddle::platform::CPUDeviceContext, double>);
-
-REGISTER_OP_CPU_KERNEL(
-    lerp_grad,
-    paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
-                                      float>,
-    paddle::operators::LerpGradKernel<paddle::platform::CPUDeviceContext,
-                                      double>);
diff --git a/paddle/fluid/operators/lerp_op.h b/paddle/fluid/operators/lerp_op.h
deleted file mode 100644
index 380a8ccffd..0000000000
--- a/paddle/fluid/operators/lerp_op.h
+++ /dev/null
@@ -1,217 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-
-#ifdef _WIN32
-#ifndef NOMINMAX
-#define NOMINMAX  // msvc max/min macro conflict with std::min/max
-#endif
-#endif
-
-namespace paddle {
-namespace operators {
-
-static framework::DDim ExtendDims2Rank(const framework::DDim& in_dims,
-                                       int rank) {
-  if (in_dims.size() == rank) {
-    return in_dims;
-  }
-  std::vector<int64_t> shapes(rank, 1);
-  for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) {
-    shapes[j] = in_dims[i];
-  }
-  return framework::make_ddim(shapes);
-}
-
-template <size_t D>
-static void GetBroadcastDims(const framework::DDim& in_dims,
-                             const framework::DDim& out_dims,
-                             Eigen::DSizes<Eigen::DenseIndex, D>* bcast_dims) {
-  for (size_t i = 0; i < D; ++i) {
-    if (in_dims[i] == out_dims[i]) {
-      (*bcast_dims)[i] = 1;
-    } else {
-      (*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]);
-    }
-  }
-}
-
-template <typename DeviceContext, typename T, size_t D>
-static void LerpFunction(const framework::ExecutionContext& ctx) {
-  auto x = ctx.Input<framework::Tensor>("X");
-  auto y = ctx.Input<framework::Tensor>("Y");
-  auto w = ctx.Input<framework::Tensor>("Weight");
-  auto out = ctx.Output<framework::Tensor>("Out");
-  out->mutable_data<T>(ctx.GetPlace());
-
-  auto out_dims = out->dims();
-  auto x_dims = ExtendDims2Rank(x->dims(), D);
-  auto y_dims = ExtendDims2Rank(y->dims(), D);
-  auto w_dims = ExtendDims2Rank(w->dims(), D);
-  Eigen::DSizes<Eigen::DenseIndex, D> x_bcast_dims;
-  Eigen::DSizes<Eigen::DenseIndex, D> y_bcast_dims;
-  Eigen::DSizes<Eigen::DenseIndex, D> w_bcast_dims;
-  GetBroadcastDims<D>(x_dims, out_dims, &x_bcast_dims);
-  GetBroadcastDims<D>(y_dims, out_dims, &y_bcast_dims);
-  GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);
-
-  auto eigen_x = framework::EigenTensor<T, D>::From(*x, x_dims);
-  auto eigen_y = framework::EigenTensor<T, D>::From(*y, y_dims);
-  auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
-  auto eigen_out = framework::EigenTensor<T, D>::From(*out);
-
-  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-  eigen_out.device(place) =
-      eigen_x.broadcast(x_bcast_dims) +
-      eigen_w.broadcast(w_bcast_dims) *
-          (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
-}
-
-template <typename DeviceContext, typename T, size_t D>
-static void LerpGradFunction(const framework::ExecutionContext& ctx) {
-  auto w = ctx.Input<framework::Tensor>("Weight");
-  auto dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-  auto dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-  auto dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-
-  auto dout_dims = dout->dims();
-  auto dx_dims = ExtendDims2Rank(dx->dims(), D);
-  auto dy_dims = ExtendDims2Rank(dy->dims(), D);
-  auto w_dims = ExtendDims2Rank(w->dims(), D);
-  Eigen::DSizes<Eigen::DenseIndex, D> dx_bcast_dims;
-  Eigen::DSizes<Eigen::DenseIndex, D> dy_bcast_dims;
-  Eigen::DSizes<Eigen::DenseIndex, D> w_bcast_dims;
-  GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
-  GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
-  GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
-
-  auto eigen_w = framework::EigenTensor<T, D>::From(*w, w_dims);
-  auto eigen_dout = framework::EigenTensor<T, D>::From(*dout);
-
-  Eigen::DSizes<Eigen::DenseIndex, D * 2> dx_reshape_dims;
-  Eigen::DSizes<Eigen::DenseIndex, D * 2> dy_reshape_dims;
-  Eigen::DSizes<int, D> reduce_dims;
-  for (int i = 0; i < dout_dims.size(); ++i) {
-    dx_reshape_dims[2 * i] = dx_bcast_dims[i];
-    dx_reshape_dims[2 * i + 1] = dx_dims[i];
-    dy_reshape_dims[2 * i] = dy_bcast_dims[i];
-    dy_reshape_dims[2 * i + 1] = dy_dims[i];
-    reduce_dims[i] = 2 * i;
-  }
-
-  auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-
-  if (dx) {
-    dx->mutable_data<T>(ctx.GetPlace());
-    auto eigen_dx = framework::EigenTensor<T, D>::From(*dx, dx_dims);
-    auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
-    eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
-                                 .sum(reduce_dims)
-                                 .reshape(eigen_dx.dimensions());
-  }
-  if (dy) {
-    dy->mutable_data<T>(ctx.GetPlace());
-    auto eigen_dy = framework::EigenTensor<T, D>::From(*dy, dy_dims);
-    auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
-    eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
-                                 .sum(reduce_dims)
-                                 .reshape(eigen_dy.dimensions());
-  }
-}
-
-template <typename DeviceContext, typename T>
-class LerpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    int rank = ctx.Output<framework::Tensor>("Out")->dims().size();
-    PADDLE_ENFORCE_GE(
-        rank, 1,
-        platform::errors::InvalidArgument(
-            "The number of dimensions for LerpOp must be "
-            "greater than or equal to 1, but the value received is %d.",
-            rank));
-    PADDLE_ENFORCE_LE(
-        rank, 6, platform::errors::InvalidArgument(
-                     "The number of dimensions for LerpOp must be "
-                     "less than or equal to 6, but the value received is %d.",
-                     rank));
-    switch (rank) {
-      case 1:
-        LerpFunction<DeviceContext, T, 1>(ctx);
-        break;
-      case 2:
-        LerpFunction<DeviceContext, T, 2>(ctx);
-        break;
-      case 3:
-        LerpFunction<DeviceContext, T, 3>(ctx);
-        break;
-      case 4:
-        LerpFunction<DeviceContext, T, 4>(ctx);
-        break;
-      case 5:
-        LerpFunction<DeviceContext, T, 5>(ctx);
-        break;
-      case 6:
-        LerpFunction<DeviceContext, T, 6>(ctx);
-        break;
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class LerpGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    int rank = ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
-                   ->dims()
-                   .size();
-    PADDLE_ENFORCE_GE(
-        rank, 1,
-        platform::errors::InvalidArgument(
-            "The number of dimensions for LerpGradOp must be "
-            "greater than or equal to 1, but the value received is %d.",
-            rank));
-    PADDLE_ENFORCE_LE(
-        rank, 6,
-        platform::errors::InvalidArgument(
-            "The number of dimensions for LerpGradOp must be "
-            "less than or equal to 6, but the value received is %d.",
-            rank));
-    switch (rank) {
-      case 1:
-        LerpGradFunction<DeviceContext, T, 1>(ctx);
-        break;
-      case 2:
-        LerpGradFunction<DeviceContext, T, 2>(ctx);
-        break;
-      case 3:
-        LerpGradFunction<DeviceContext, T, 3>(ctx);
-        break;
-      case 4:
-        LerpGradFunction<DeviceContext, T, 4>(ctx);
-        break;
-      case 5:
-        LerpGradFunction<DeviceContext, T, 5>(ctx);
-        break;
-      case 6:
-        LerpGradFunction<DeviceContext, T, 6>(ctx);
-        break;
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/pten/kernels/cpu/lerp_grad_kernel.cc b/paddle/pten/kernels/cpu/lerp_grad_kernel.cc
new file mode 100644
index 0000000000..4aac143eb1
--- /dev/null
+++ b/paddle/pten/kernels/cpu/lerp_grad_kernel.cc
@@ -0,0 +1,21 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/lerp_grad_kernel.h"
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/impl/lerp_grad_kernel_impl.h"
+
+PT_REGISTER_KERNEL(
+    lerp_grad, CPU, ALL_LAYOUT, pten::LerpGradKernel, float, double) {}
diff --git a/paddle/pten/kernels/cpu/lerp_kernel.cc b/paddle/pten/kernels/cpu/lerp_kernel.cc
new file mode 100644
index 0000000000..9f8513065c
--- /dev/null
+++ b/paddle/pten/kernels/cpu/lerp_kernel.cc
@@ -0,0 +1,20 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/lerp_kernel.h"
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+#include "paddle/pten/kernels/impl/lerp_kernel_impl.h"
+
+PT_REGISTER_KERNEL(lerp, CPU, ALL_LAYOUT, pten::LerpKernel, float, double) {}
diff --git a/paddle/pten/kernels/funcs/common_shape.h b/paddle/pten/kernels/funcs/common_shape.h
index 9a96a5fd45..e751f85b50 100644
--- a/paddle/pten/kernels/funcs/common_shape.h
+++ b/paddle/pten/kernels/funcs/common_shape.h
@@ -102,5 +102,30 @@ inline void GetPrePostNumel(
   }
 }
 
+static framework::DDim ExtendDims2Rank(const framework::DDim &in_dims,
+                                       int rank) {
+  if (in_dims.size() == rank) {
+    return in_dims;
+  }
+  std::vector<int64_t> shapes(rank, 1);
+  for (int i = in_dims.size() - 1, j = rank - 1; i >= 0; --i, --j) {
+    shapes[j] = in_dims[i];
+  }
+  return framework::make_ddim(shapes);
+}
+
+template <size_t D>
+static void GetBroadcastDims(const framework::DDim &in_dims,
+                             const framework::DDim &out_dims,
+                             Eigen::DSizes<Eigen::DenseIndex, D> *bcast_dims) {
+  for (size_t i = 0; i < D; ++i) {
+    if (in_dims[i] == out_dims[i]) {
+      (*bcast_dims)[i] = 1;
+    } else {
+      (*bcast_dims)[i] = std::max(in_dims[i], out_dims[i]);
+    }
+  }
+}
+
 }  // namespace funcs
 }  // namespace pten
diff --git a/paddle/fluid/operators/lerp_op.cu b/paddle/pten/kernels/gpu/lerp_grad_kernel.cu
similarity index 54%
rename from paddle/fluid/operators/lerp_op.cu
rename to paddle/pten/kernels/gpu/lerp_grad_kernel.cu
index 6f7d8b744d..30fdb1206f 100644
--- a/paddle/fluid/operators/lerp_op.cu
+++ b/paddle/pten/kernels/gpu/lerp_grad_kernel.cu
@@ -12,16 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include "paddle/fluid/operators/lerp_op.h" +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/lerp_grad_kernel_impl.h" +#include "paddle/pten/kernels/lerp_grad_kernel.h" -REGISTER_OP_CUDA_KERNEL( - lerp, - paddle::operators::LerpKernel, - paddle::operators::LerpKernel); - -REGISTER_OP_CUDA_KERNEL( - lerp_grad, - paddle::operators::LerpGradKernel, - paddle::operators::LerpGradKernel); +PT_REGISTER_KERNEL( + lerp_grad, GPU, ALL_LAYOUT, pten::LerpGradKernel, float, double) {} diff --git a/paddle/pten/kernels/gpu/lerp_kernel.cu b/paddle/pten/kernels/gpu/lerp_kernel.cu new file mode 100644 index 0000000000..8743cb12e4 --- /dev/null +++ b/paddle/pten/kernels/gpu/lerp_kernel.cu @@ -0,0 +1,20 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/impl/lerp_kernel_impl.h" +#include "paddle/pten/kernels/lerp_kernel.h" + +PT_REGISTER_KERNEL(lerp, GPU, ALL_LAYOUT, pten::LerpKernel, float, double) {} diff --git a/paddle/pten/kernels/impl/lerp_grad_kernel_impl.h b/paddle/pten/kernels/impl/lerp_grad_kernel_impl.h new file mode 100644 index 0000000000..5285c69e39 --- /dev/null +++ b/paddle/pten/kernels/impl/lerp_grad_kernel_impl.h @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/pten/kernels/funcs/common_shape.h"
+#include "paddle/pten/kernels/funcs/eigen/common.h"
+
+namespace pten {
+
+template <typename Context, typename T, size_t D>
+static void LerpGradFunction(const Context& ctx,
+                             const DenseTensor& x,
+                             const DenseTensor& y,
+                             const DenseTensor& weight,
+                             const DenseTensor& out,
+                             const DenseTensor& out_grad,
+                             DenseTensor* x_grad,
+                             DenseTensor* y_grad) {
+  auto& w = weight;
+  auto& dout = out_grad;
+  auto* dx = x_grad;
+  auto* dy = y_grad;
+
+  auto dout_dims = dout.dims();
+  auto dx_dims = pten::funcs::ExtendDims2Rank(dx->dims(), D);
+  auto dy_dims = pten::funcs::ExtendDims2Rank(dy->dims(), D);
+  auto w_dims = pten::funcs::ExtendDims2Rank(w.dims(), D);
+  Eigen::DSizes<Eigen::DenseIndex, D> dx_bcast_dims;
+  Eigen::DSizes<Eigen::DenseIndex, D> dy_bcast_dims;
+  Eigen::DSizes<Eigen::DenseIndex, D> w_bcast_dims;
+  pten::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
+  pten::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
+  pten::funcs::GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
+
+  auto eigen_w = pten::EigenTensor<T, D>::From(w, w_dims);
+  auto eigen_dout = pten::EigenTensor<T, D>::From(dout);
+
+  Eigen::DSizes<Eigen::DenseIndex, D * 2> dx_reshape_dims;
+  Eigen::DSizes<Eigen::DenseIndex, D * 2> dy_reshape_dims;
+  Eigen::DSizes<int, D> reduce_dims;
+  for (int i = 0; i < dout_dims.size(); ++i) {
+    dx_reshape_dims[2 * i] = dx_bcast_dims[i];
+    dx_reshape_dims[2 * i + 1] = dx_dims[i];
+    dy_reshape_dims[2 * i] = dy_bcast_dims[i];
+    dy_reshape_dims[2 * i + 1] = dy_dims[i];
+    reduce_dims[i] = 2 * i;
+  }
+
+  auto& place = *ctx.eigen_device();
+
+  if (dx) {
+    ctx.template Alloc<T>(dx);
+    auto eigen_dx = pten::EigenTensor<T, D>::From(*dx, dx_dims);
+    auto eigen_expr = (1 - eigen_w.broadcast(w_bcast_dims)) * eigen_dout;
+    eigen_dx.device(place) = eigen_expr.reshape(dx_reshape_dims)
+                                 .sum(reduce_dims)
+                                 .reshape(eigen_dx.dimensions());
+  }
+  if (dy) {
+    ctx.template Alloc<T>(dy);
+    auto eigen_dy = pten::EigenTensor<T, D>::From(*dy, dy_dims);
+    auto eigen_expr = eigen_w.broadcast(w_bcast_dims) * eigen_dout;
+    eigen_dy.device(place) = eigen_expr.reshape(dy_reshape_dims)
+                                 .sum(reduce_dims)
+                                 .reshape(eigen_dy.dimensions());
+  }
+}
+
+template <typename T, typename Context>
+void LerpGradKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const DenseTensor& weight,
+                    const DenseTensor& out,
+                    const DenseTensor& out_grad,
+                    DenseTensor* x_grad,
+                    DenseTensor* y_grad) {
+  int rank = out.dims().size();
+  PADDLE_ENFORCE_GE(
+      rank,
+      1,
+      pten::errors::InvalidArgument(
+          "The number of dimensions for LerpGradOp must be "
+          "greater than or equal to 1, but the value received is %d.",
+          rank));
+  PADDLE_ENFORCE_LE(
+      rank,
+      6,
+      pten::errors::InvalidArgument(
+          "The number of dimensions for LerpGradOp must be "
+          "less than or equal to 6, but the value received is %d.",
+          rank));
+  switch (rank) {
+    case 1:
+      LerpGradFunction<Context, T, 1>(
+          ctx, x, y, weight, out, out_grad, x_grad, y_grad);
+      break;
+    case 2:
+      LerpGradFunction<Context, T, 2>(
+          ctx, x, y, weight, out, out_grad, x_grad, y_grad);
+      break;
+    case 3:
+      LerpGradFunction<Context, T, 3>(
+          ctx, x, y, weight, out, out_grad, x_grad, y_grad);
+      break;
+    case 4:
+      LerpGradFunction<Context, T, 4>(
+          ctx, x, y, weight, out, out_grad, x_grad, y_grad);
+      break;
+    case 5:
+      LerpGradFunction<Context, T, 5>(
+          ctx, x, y, weight, out, out_grad, x_grad, y_grad);
+      break;
+    case 6:
+      LerpGradFunction<Context, T, 6>(
+          ctx, x, y, weight, out, out_grad, x_grad, y_grad);
+      break;
+  }
+}
+
+}  // namespace pten
diff --git a/paddle/pten/kernels/impl/lerp_kernel_impl.h b/paddle/pten/kernels/impl/lerp_kernel_impl.h
new file mode 100644
index 0000000000..127e3e50a3
--- /dev/null
+++ b/paddle/pten/kernels/impl/lerp_kernel_impl.h
@@ -0,0 +1,97 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pten/backends/cpu/cpu_context.h"
+#include "paddle/pten/kernels/funcs/common_shape.h"
+#include "paddle/pten/kernels/funcs/eigen/common.h"
+
+namespace pten {
+
+template <typename Context, typename T, size_t D>
+static void LerpFunction(const Context& ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& y,
+                         const DenseTensor& weight,
+                         DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+
+  auto out_dims = out->dims();
+  auto x_dims = pten::funcs::ExtendDims2Rank(x.dims(), D);
+  auto y_dims = pten::funcs::ExtendDims2Rank(y.dims(), D);
+  auto w_dims = pten::funcs::ExtendDims2Rank(weight.dims(), D);
+  Eigen::DSizes<Eigen::DenseIndex, D> x_bcast_dims;
+  Eigen::DSizes<Eigen::DenseIndex, D> y_bcast_dims;
+  Eigen::DSizes<Eigen::DenseIndex, D> w_bcast_dims;
+  pten::funcs::GetBroadcastDims<D>(x_dims, out_dims, &x_bcast_dims);
+  pten::funcs::GetBroadcastDims<D>(y_dims, out_dims, &y_bcast_dims);
+  pten::funcs::GetBroadcastDims<D>(w_dims, out_dims, &w_bcast_dims);
+
+  auto eigen_x = pten::EigenTensor<T, D>::From(x, x_dims);
+  auto eigen_y = pten::EigenTensor<T, D>::From(y, y_dims);
+  auto eigen_w = pten::EigenTensor<T, D>::From(weight, w_dims);
+  auto eigen_out = pten::EigenTensor<T, D>::From(*out);
+
+  auto& place = *ctx.eigen_device();
+  eigen_out.device(place) =
+      eigen_x.broadcast(x_bcast_dims) +
+      eigen_w.broadcast(w_bcast_dims) *
+          (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
+}
+
+template <typename T, typename Context>
+void LerpKernel(const Context& ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                const DenseTensor& weight,
+                DenseTensor* out) {
+  int rank = out->dims().size();
+  PADDLE_ENFORCE_GE(
+      rank,
+      1,
+      pten::errors::InvalidArgument(
+          "The number of dimensions for LerpOp must be "
+          "greater than or equal to 1, but the value received is %d.",
+          rank));
+  PADDLE_ENFORCE_LE(
+      rank,
+      6,
+      pten::errors::InvalidArgument(
+          "The number of dimensions for LerpOp must be "
+          "less than or equal to 6, but the value received is %d.",
+          rank));
+  switch (rank) {
+    case 1:
+      LerpFunction<Context, T, 1>(ctx, x, y, weight, out);
+      break;
+    case 2:
+      LerpFunction<Context, T, 2>(ctx, x, y, weight, out);
+      break;
+    case 3:
+      LerpFunction<Context, T, 3>(ctx, x, y, weight, out);
+      break;
+    case 4:
+      LerpFunction<Context, T, 4>(ctx, x, y, weight, out);
+      break;
+    case 5:
+      LerpFunction<Context, T, 5>(ctx, x, y, weight, out);
+      break;
+    case 6:
+      LerpFunction<Context, T, 6>(ctx, x, y, weight, out);
+      break;
+  }
+}
+
+}  // namespace pten
diff --git a/paddle/pten/kernels/lerp_grad_kernel.h b/paddle/pten/kernels/lerp_grad_kernel.h
new file mode 100644
index 0000000000..18a38e7245
--- /dev/null
+++ b/paddle/pten/kernels/lerp_grad_kernel.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void LerpGradKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const DenseTensor& weight,
+                    const DenseTensor& out,
+                    const DenseTensor& out_grad,
+                    DenseTensor* x_grad,
+                    DenseTensor* y_grad);
+
+}  // namespace pten
diff --git a/paddle/pten/kernels/lerp_kernel.h b/paddle/pten/kernels/lerp_kernel.h
new file mode 100644
index 0000000000..8e70a53c06
--- /dev/null
+++ b/paddle/pten/kernels/lerp_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/pten/core/dense_tensor.h"
+
+namespace pten {
+
+template <typename T, typename Context>
+void LerpKernel(const Context& ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                const DenseTensor& weight,
+                DenseTensor* out);
+
+}  // namespace pten
diff --git a/paddle/pten/ops/compat/lerp_sig.cc b/paddle/pten/ops/compat/lerp_sig.cc
new file mode 100644
index 0000000000..d225ff2bfd
--- /dev/null
+++ b/paddle/pten/ops/compat/lerp_sig.cc
@@ -0,0 +1,33 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/core/compat/op_utils.h"
+
+namespace pten {
+
+KernelSignature LerpOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("lerp", {"X", "Y", "Weight"}, {}, {"Out"});
+}
+
+KernelSignature LerpGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("lerp_grad",
+                         {"X", "Y", "Weight", "Out", GradVarName("Out")},
+                         {},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+}  // namespace pten
+
+PT_REGISTER_ARG_MAPPING_FN(lerp, pten::LerpOpArgumentMapping);
+PT_REGISTER_ARG_MAPPING_FN(lerp_grad, pten::LerpGradOpArgumentMapping);
-- 
GitLab
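
For orientation, both the deleted fluid kernels and the new pten kernels implement
the same elementwise math: the forward pass computes out = x + w * (y - x) after
broadcasting x, y, and w to the shape of out, and the backward pass applies
d(out)/dx = 1 - w and d(out)/dy = w to out_grad, then reduces over the broadcast
axes. The standalone C++ sketch below restates that math for the rank-1,
no-broadcast case only; LerpRef and LerpGradRef are illustrative names, not part
of the Paddle codebase.

    #include <cstddef>
    #include <vector>

    // Forward: out[i] = x[i] + w[i] * (y[i] - x[i]).
    template <typename T>
    void LerpRef(const std::vector<T>& x, const std::vector<T>& y,
                 const std::vector<T>& w, std::vector<T>* out) {
      out->resize(x.size());
      for (std::size_t i = 0; i < x.size(); ++i) {
        (*out)[i] = x[i] + w[i] * (y[i] - x[i]);
      }
    }

    // Backward: dx[i] = (1 - w[i]) * dout[i] and dy[i] = w[i] * dout[i],
    // mirroring the (1 - eigen_w) * eigen_dout and eigen_w * eigen_dout
    // expressions in lerp_grad_kernel_impl.h. The reshape/sum reduction in
    // the real kernel is a no-op here because all shapes already match.
    template <typename T>
    void LerpGradRef(const std::vector<T>& w, const std::vector<T>& dout,
                     std::vector<T>* dx, std::vector<T>* dy) {
      dx->resize(dout.size());
      dy->resize(dout.size());
      for (std::size_t i = 0; i < dout.size(); ++i) {
        (*dx)[i] = (T(1) - w[i]) * dout[i];
        (*dy)[i] = w[i] * dout[i];
      }
    }

For example, with x = {0, 0}, y = {10, 10}, and w = {0.25, 0.5}, LerpRef produces
{2.5, 5}, which is what the registered lerp kernels compute for equal-shaped inputs.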