From 8cc09552473b842c651ead3b9848d41827a3dbab Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 11 Jan 2022 20:58:24 +0800 Subject: [PATCH] refactor reshape grad kernel (#38833) --- paddle/fluid/operators/reshape_op.cc | 64 ++++++++++++++---- paddle/pten/core/kernel_alias_name.h | 3 + paddle/pten/kernels/reshape_grad_kernel.cc | 75 ++++++++++++++++++++++ paddle/pten/kernels/reshape_grad_kernel.h | 31 +++++++++ 4 files changed, 161 insertions(+), 12 deletions(-) create mode 100644 paddle/pten/kernels/reshape_grad_kernel.cc create mode 100644 paddle/pten/kernels/reshape_grad_kernel.h diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index f2162f55636..a25e53aac5d 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/include/core.h" +#include "paddle/pten/kernels/reshape_grad_kernel.h" #include "paddle/pten/kernels/reshape_kernel.h" namespace paddle { namespace framework { @@ -467,13 +468,27 @@ class ReshapeGradKernel { void operator()(const framework::ExecutionContext &ctx) const { auto *d_out = ctx.Input(framework::GradVarName("Out")); auto *d_x = ctx.Output(framework::GradVarName("X")); - auto in_dims = d_x->dims(); - d_x->mutable_data(ctx.GetPlace(), d_out->type()); - framework::TensorCopy( - *d_out, ctx.GetPlace(), - ctx.template device_context(), d_x); - d_x->Resize(in_dims); + + auto pt_d_x = paddle::experimental::MakePtenDenseTensor(*d_x); + auto pt_d_out = paddle::experimental::MakePtenDenseTensor(*d_out); + + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); + } +#endif +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeGradKernel(dev_ctx, *pt_d_out.get(), pt_d_x.get()); + } +#endif } }; @@ -482,14 +497,27 @@ class ReshapeDoubleGradKernel { void operator()(const framework::ExecutionContext &ctx) const { auto *dd_x = ctx.Input("DDX"); auto *dd_out = ctx.Output("DDOut"); + dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); - auto out_dims = dd_out->dims(); + auto pt_dd_x = paddle::experimental::MakePtenDenseTensor(*dd_x); + auto pt_dd_out = paddle::experimental::MakePtenDenseTensor(*dd_out); - dd_out->mutable_data(ctx.GetPlace(), dd_x->type()); - framework::TensorCopy( - *dd_x, ctx.GetPlace(), - ctx.template device_context(), dd_out); - dd_out->Resize(out_dims); + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); + } +#endif +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeDoubleGradKernel(dev_ctx, *pt_dd_x.get(), pt_dd_out.get()); + } +#endif } }; @@ -624,6 +652,13 @@ class Reshape2GradOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + return framework::KernelSignature("reshape_grad", + {framework::GradVarName("Out")}, {}, + {framework::GradVarName("X")}); + } }; class Reshape2DoubleGradOp : public framework::OperatorWithKernel { @@ -660,6 +695,11 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + return framework::KernelSignature("reshape_double_grad", {"DDX"}, {}, + {"DDOut"}); + } }; DECLARE_INPLACE_OP_INFERER(ReshapeOpInplaceInferer, {"X", "Out"}); diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 46fa6dd376e..5c867879663 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -35,6 +35,8 @@ const std::unordered_map kernel_alias_name_map = { {"reduce_mean", "mean"}, {"reduce_sum", "sum"}, {"reshape2", "reshape"}, + {"reshape2_grad", "reshape_grad"}, + {"reshape2_grad_grad", "reshape_double_grad"}, // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated {"flatten", "deprecated"}, {"flatten_grad", "deprecated"}, @@ -43,6 +45,7 @@ const std::unordered_map kernel_alias_name_map = { {"matmul_grad_grad", "deprecated"}, {"mean", "deprecated"}, {"reshape", "deprecated"}, + {"reshape_grad", "deprecated"}, {"sum", "deprecated"}}; } // namespace pten diff --git a/paddle/pten/kernels/reshape_grad_kernel.cc b/paddle/pten/kernels/reshape_grad_kernel.cc new file mode 100644 index 00000000000..99f0556765e --- /dev/null +++ b/paddle/pten/kernels/reshape_grad_kernel.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/reshape_grad_kernel.h" +#include "paddle/pten/backends/all_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/copy_kernel.h" + +namespace pten { + +template +void ReshapeGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + DenseTensor* x_grad) { + auto x_dims = x_grad->dims(); + pten::Copy(dev_ctx, out_grad, false, x_grad); + x_grad->Resize(x_dims); +} + +template +void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x_grad_grad, + DenseTensor* out_grad_grad) { + ReshapeGradKernel(dev_ctx, x_grad_grad, out_grad_grad); +} + +} // namespace pten + +PT_REGISTER_GENERAL_KERNEL(reshape_grad, + CPU, + ALL_LAYOUT, + pten::ReshapeGradKernel, + ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape_double_grad, + CPU, + ALL_LAYOUT, + pten::ReshapeDoubleGradKernel, + ALL_DTYPE) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PT_REGISTER_GENERAL_KERNEL(reshape_grad, + GPU, + ALL_LAYOUT, + pten::ReshapeGradKernel, + ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape_double_grad, + GPU, + ALL_LAYOUT, + pten::ReshapeDoubleGradKernel, + ALL_DTYPE) {} +#endif + +#ifdef PADDLE_WITH_XPU +PT_REGISTER_GENERAL_KERNEL(reshape_grad, + XPU, + ALL_LAYOUT, + pten::ReshapeGradKernel, + ALL_DTYPE) {} +PT_REGISTER_GENERAL_KERNEL(reshape_double_grad, + XPU, + ALL_LAYOUT, + pten::ReshapeDoubleGradKernel, + ALL_DTYPE) {} +#endif diff --git a/paddle/pten/kernels/reshape_grad_kernel.h b/paddle/pten/kernels/reshape_grad_kernel.h new file mode 100644 index 00000000000..1492d753704 --- /dev/null +++ b/paddle/pten/kernels/reshape_grad_kernel.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { + +template +void ReshapeGradKernel(const Context& dev_ctx, + const DenseTensor& out_grad, + DenseTensor* x_grad); + +template +void ReshapeDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x_grad_grad, + DenseTensor* out_grad_grad); + +} // namespace pten -- GitLab