From 3bf3a6ee9ff91f3f6e054bb7161e7dea27db5055 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Mon, 24 Jan 2022 20:16:20 +0800 Subject: [PATCH] [Pten]Refactor elementwise_add grad / double grad / triple grad Kernel and move them to pten (#39048) * refactor elementwise add grad * fix compile bugs * fix unit test bugs * fix file conflicts * fix bugs when buildPtenContext --- paddle/fluid/imperative/prepared_operator.cc | 4 + .../elementwise/elementwise_add_op.cc | 28 --- .../elementwise/elementwise_add_op.cu | 130 +----------- .../elementwise/elementwise_add_op.h | 195 +++--------------- .../operators/elementwise/elementwise_op.h | 20 +- .../elementwise/elementwise_op_function.h | 74 +------ .../elementwise/elementwise_sub_op.h | 8 +- paddle/fluid/operators/math/math_function.cc | 13 ++ paddle/pten/core/kernel_alias_name.h | 1 + paddle/pten/kernels/CMakeLists.txt | 2 +- paddle/pten/kernels/cpu/elementwise.h | 90 ++++++++ .../kernels/cpu/elementwise_grad_kernel.cc | 128 ++++++++++++ paddle/pten/kernels/elementwise_grad_kernel.h | 49 +++++ paddle/pten/kernels/funcs/elementwise_base.h | 46 +++++ paddle/pten/kernels/gpu/elementwise.h | 142 +++++++++++++ .../kernels/gpu/elementwise_grad_kernel.cu | 121 +++++++++++ .../impl/elementwise_grad_kernel_impl.h | 88 ++++++++ 17 files changed, 743 insertions(+), 396 deletions(-) create mode 100644 paddle/pten/kernels/cpu/elementwise_grad_kernel.cc create mode 100644 paddle/pten/kernels/elementwise_grad_kernel.h create mode 100644 paddle/pten/kernels/gpu/elementwise_grad_kernel.cu create mode 100644 paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index fe60f05e1d..5f144c08d2 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -369,6 +369,10 @@ static void BuildDygraphPtenKernelContext( size_t end_idx = start_idx + outs_vector.size(); for (size_t offset = 0; offset < outs_vector.size(); ++offset) { + if (outs_vector[offset] == nullptr) { + kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr}); + continue; + } auto* var = outs_vector[offset]->MutableVar(); framework::Tensor* tensor_out = nullptr; if (var->template IsType()) { diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index 0c2476fde0..f462c2ea07 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -33,34 +33,6 @@ class CPUDeviceContext; namespace paddle { namespace operators { -template -struct SameDimsElemwiseAdd< - platform::CPUDeviceContext, T, - typename std::enable_if::value>::type> { - void operator()(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - framework::Tensor *z) { - auto blas = math::GetBlas(ctx); - blas.VADD(x->numel(), x->data(), y->data(), z->data()); - } -}; - -template -struct SameDimsElemwiseAdd< - platform::CPUDeviceContext, T, - typename std::enable_if::value>::type> { - void operator()(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - framework::Tensor *z) { - auto eigen_x = framework::EigenVector::Flatten(*x); - auto eigen_y = framework::EigenVector::Flatten(*y); - auto eigen_z = framework::EigenVector::Flatten(*z); - auto &place = *ctx.template device_context() - .eigen_device(); - eigen_z.device(place) = eigen_x + eigen_y; - } -}; - class ElementwiseAddOpMaker : public ElementwiseOpMaker { protected: std::string GetName() const override { return "Add"; } diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index 779779b44d..2326aa561e 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -13,139 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/pten/kernels/gpu/elementwise.h" namespace ops = paddle::operators; namespace plat = paddle::platform; namespace paddle { -namespace operators { - -template -static __global__ void SimpleElemwiseAddGradCUDAKernel( - const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) { - int tid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = gridDim.x * blockDim.x; - int loop = size / vec_size; - int remainder = size % vec_size; - const float4* dout_vec = reinterpret_cast(dout); - float4* dx_vec = reinterpret_cast(dx); - float4* dy_vec = reinterpret_cast(dy); - float4 tmp_loop; - - for (int i = tid; i < loop; i += stride) { - tmp_loop = dout_vec[i]; - dx_vec[i] = tmp_loop; - dy_vec[i] = tmp_loop; - } - - if (tid == loop && remainder != 0) { - T tmp_rem; - while (remainder) { - int idx = size - remainder; - remainder--; - tmp_rem = dout[idx]; - dx[idx] = tmp_rem; - dy[idx] = tmp_rem; - } - } -} - -template -typename std::enable_if< - std::is_same::value>::type -default_elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, - const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { - int axis = ctx.Attr("axis"); - auto* dout_data = dout->data(); - - // dx - if (dx != nullptr) { - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - if (dx->dims() == dout->dims()) { - if (dx_data != dout_data) { - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dx); - } - } else { - // For inplace strategy, dx will be stored in addr of dout, which makes - // the result of dy wrong. - if (dx->IsSharedBufferWith(*dout)) { - dx->clear(); - dx->mutable_data(x->dims(), ctx.GetPlace()); - } - std::vector reduce_dims = GetReduceDim(x->dims(), out->dims(), axis); - gpuStream_t stream = ctx.cuda_device_context().stream(); - TensorReduceFunctorImpl>( - *dout, dx, kps::IdentityFunctor(), reduce_dims, stream); - } - } - // dy - if (dy != nullptr) { - auto* dy_data = dy->mutable_data(ctx.GetPlace()); - if (dy->dims() == dout->dims()) { - if (dy_data != dout_data) { - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dy); - } - } else { - std::vector reduce_dims = GetReduceDim(y->dims(), out->dims(), axis); - gpuStream_t stream = ctx.cuda_device_context().stream(); - TensorReduceFunctorImpl>( - *dout, dy, kps::IdentityFunctor(), reduce_dims, stream); - } - } -} - -template -typename std::enable_if< - std::is_same::value>::type -elementwise_add_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, framework::Tensor* dx, - framework::Tensor* dy) { - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dy_data = dy->mutable_data(ctx.GetPlace()); - auto* dout_data = dout->data(); - if (dx_data == dout_data && dy_data != dout_data) { - VLOG(4) << "Special case when dx_data is the same as dout_data, " - "only need copy dout to dy"; - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dy); - } else if (dx_data != dout_data && dy_data == dout_data) { - VLOG(4) << "Special case when dy_data is the same as dout_data, " - "only need copy dout to dx"; - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dx); - } else if (dx_data != dout_data && dy_data != dout_data) { - auto size = x->numel(); - int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); - dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); - dim3 grid_size = - dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) / - PREDEFINED_BLOCK_SIZE, - 1); - SimpleElemwiseAddGradCUDAKernel< - T><<().stream()>>>( - dout->data(), size, vec_size, dx->mutable_data(ctx.GetPlace()), - dy->mutable_data(ctx.GetPlace())); - } else { - VLOG(4) << "Special case when dy_data is the same as dout_data, " - "and dx_data is the same as dout_data, do not need " - "any operator"; - } -} - -} // namespace operators +namespace operators {} // namespace operators } // namespace paddle REGISTER_OP_CUDA_KERNEL( elementwise_add, ops::ElementwiseAddKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h index 5c4f791b22..73415d3fdb 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h @@ -18,35 +18,13 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" +// only can include the headers in paddle/pten/include dirs +#include "paddle/pten/kernels/elementwise_grad_kernel.h" #include "paddle/pten/kernels/math_kernel.h" namespace paddle { namespace operators { -template -void LaunchBroadcastElementwiseCpuKernel(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, - framework::Tensor *z) { - int axis = ctx.Attr("axis"); - auto x_dims = x->dims(); - auto y_dims = y->dims(); - if (x_dims.size() >= y_dims.size()) { - ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, - AddFunctor(), z); - } else { - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, InverseAddFunctor(), z); - } -} - -template -struct SameDimsElemwiseAdd { - void operator()(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - framework::Tensor *z); -}; - template class ElementwiseAddKernel : public framework::OpKernel { public: @@ -58,128 +36,29 @@ class ElementwiseAddKernel : public framework::OpKernel { auto &dev_ctx = ctx.device_context(); int axis = ctx.Attr("axis"); - auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); - auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); - auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); pten::AddRawKernel( static_cast::TYPE &>(dev_ctx), - *pt_x.get(), *pt_y.get(), axis, pt_z.get()); + *x, *y, axis, z); } }; -template -struct IdentityGrad { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } -}; - -template -typename std::enable_if< - std::is_same::value>::type -default_elementwise_add_grad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, - const framework::Tensor *out, - const framework::Tensor *dout, - framework::Tensor *dx, framework::Tensor *dy) { - int axis = ctx.Attr("axis"); - - ElemwiseExplicitGradCompute, - IdentityGrad>(ctx, *x, *y, *out, *dout, axis, - dx, dy, IdentityGrad(), - IdentityGrad()); -} - -template -typename std::enable_if< - std::is_floating_point::value && - std::is_same::value>::type -elementwise_add_grad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - const framework::Tensor *out, - const framework::Tensor *dout, framework::Tensor *dx, - framework::Tensor *dy) { - auto blas = math::GetBlas(ctx); - if (dx) { - blas.VCOPY(dout->numel(), dout->data(), - dx->mutable_data(ctx.GetPlace())); - } - - if (dy) { - blas.VCOPY(dout->numel(), dout->data(), - dy->mutable_data(ctx.GetPlace())); - } -} - -template -typename std::enable_if< - !std::is_floating_point::value && - std::is_same::value>::type -elementwise_add_grad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - const framework::Tensor *out, - const framework::Tensor *dout, framework::Tensor *dx, - framework::Tensor *dy) { - default_elementwise_add_grad(ctx, x, y, out, dout, dx, dy); -} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -// cuda definition -template -typename std::enable_if< - std::is_same::value>::type -elementwise_add_grad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, const framework::Tensor *y, - const framework::Tensor *out, - const framework::Tensor *dout, framework::Tensor *dx, - framework::Tensor *dy); - -template -typename std::enable_if< - std::is_same::value>::type -default_elementwise_add_grad(const framework::ExecutionContext &ctx, - const framework::Tensor *x, - const framework::Tensor *y, - const framework::Tensor *out, - const framework::Tensor *dout, - framework::Tensor *dx, framework::Tensor *dy); -#endif - template class ElementwiseAddGradKernel : public ElemwiseGradKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - auto *x = ctx.Input("X"); auto *y = ctx.Input("Y"); auto *dout = ctx.Input(framework::GradVarName("Out")); auto *dx = ctx.Output(framework::GradVarName("X")); auto *dy = ctx.Output(framework::GradVarName("Y")); - // skip out - auto *out = dout; - - // Special case when dy is not needed and dx doesn't reduce - if (dx != nullptr && dy == nullptr && dx->dims() == dout->dims()) { - VLOG(4) << "Special case when dy is not needed and dx doesn't " - "reduce"; - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dx); - } else if (dx == nullptr && dy != nullptr && dy->dims() == dout->dims()) { - VLOG(4) << "Special case when dx is not needed and dy doesn't " - "reduce"; - framework::TensorCopy( - *dout, ctx.GetPlace(), - ctx.template device_context(), dy); - } else if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { - elementwise_add_grad(ctx, x, y, out, dout, dx, dy); - } else { - default_elementwise_add_grad(ctx, x, y, out, dout, dx, - dy); - } + const auto &dev_ctx = ctx.template device_context(); + int axis = ctx.Attr("axis"); + pten::AddGradKernel( + static_cast::TYPE &>(dev_ctx), + *x, *y, *dout, axis, dx, dy); } }; @@ -195,17 +74,20 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel { auto *ddy = ctx.Input("DDY"); auto *ddout = ctx.Output("DDOut"); - - // ddOut = ddx + ddy - if (ddout) { - Tensor ddx_safe, ddy_safe; - GetDoubleGradSafeTensor(ctx, dout, ddx, &ddx_safe); - GetDoubleGradSafeTensor(ctx, y, ddy, &ddy_safe); - - ddout->mutable_data(ctx.GetPlace()); - LaunchBroadcastElementwiseCpuKernel(ctx, &ddx_safe, - &ddy_safe, ddout); + const auto &dev_ctx = ctx.template device_context(); + int axis = ctx.Attr("axis"); + paddle::optional ddx_optional = paddle::none; + paddle::optional ddy_optional = paddle::none; + if (ddx != nullptr) { + ddx_optional = *ddx; + } + if (ddy != nullptr) { + ddy_optional = *ddy; } + pten::AddDoubleGradKernel( + static_cast::TYPE &>(dev_ctx), + *y, ddx_optional, ddy_optional, *dout, axis, ddout); } }; @@ -219,32 +101,13 @@ class ElementwiseAddTripleGradKernel : public framework::OpKernel { auto *d_ddout = ctx.Input("D_DDOut"); auto *d_ddx = ctx.Output("D_DDX"); auto *d_ddy = ctx.Output("D_DDY"); - // skip out - auto *out = d_ddout; - - // Special case when d_ddy is not needed and d_ddx doesn't reduce - if (d_ddx != nullptr && d_ddy == nullptr && - d_ddx->dims() == d_ddout->dims()) { - VLOG(4) << "Special case when d_ddy is not needed and d_ddx doesn't " - "reduce"; - framework::TensorCopy( - *d_ddout, ctx.GetPlace(), - ctx.template device_context(), d_ddx); - } else if (d_ddx == nullptr && d_ddy != nullptr && - d_ddy->dims() == d_ddout->dims()) { - VLOG(4) << "Special case when d_ddx is not needed and d_ddy doesn't " - "reduce"; - framework::TensorCopy( - *d_ddout, ctx.GetPlace(), - ctx.template device_context(), d_ddy); - } else if (d_ddx != nullptr && d_ddy != nullptr && - (d_ddx->dims() == d_ddy->dims())) { - elementwise_add_grad(ctx, ddx, ddy, out, d_ddout, d_ddx, - d_ddy); - } else { - default_elementwise_add_grad(ctx, ddx, ddy, out, - d_ddout, d_ddx, d_ddy); - } + + const auto &dev_ctx = ctx.template device_context(); + int axis = ctx.Attr("axis"); + pten::AddTripleGradKernel( + static_cast::TYPE &>(dev_ctx), + *ddx, *ddy, *d_ddout, axis, d_ddx, d_ddy); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index aaf33ca674..64beac0804 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -354,6 +354,18 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { tensor.place(), tensor.layout()); } } + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + if (Type() == "elementwise_add_grad") { + if (ctx.InputVar("X")->IsType()) { + return framework::KernelSignature( + "add_grad", {"X", "Y", framework::GradVarName("Out")}, {"axis"}, + {framework::GradVarName("X"), framework::GradVarName("Y")}); + } + } + + return framework::KernelSignature("None", {"X"}, {}, {"Out"}); + } }; class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { @@ -522,11 +534,9 @@ class ElemwiseGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &context) const override { auto *dx = context.Output(framework::GradVarName("X")); - if (dx != nullptr) { - auto &dout = - *context.Input(framework::GradVarName("Out")); - dx->set_lod(dout.lod()); - } + auto &dout = + *context.Input(framework::GradVarName("Out")); + pten::funcs::ElementwiseGradPreProcess(dout, dx); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index f0641dd97d..5f4212e556 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -158,32 +158,6 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx, } } -// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub. -// explicit gradient can cut off X, Y, Out from gradient op -// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse -// elementwise code. -template -void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx, - const framework::Tensor &x, - const framework::Tensor &y, - const framework::Tensor &out, - const framework::Tensor &dout, int axis, - framework::Tensor *dx, framework::Tensor *dy, - DX_OP dx_op, DY_OP dy_op) { - const framework::DDim &x_dim = x.dims(); - const framework::DDim &y_dim = y.dims(); - const auto &dev_ctx = ctx.template device_context(); - if (x.dims() == y.dims()) { - pten::funcs::ElemwiseGradComputeNoBroadcast( - dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, - dy_op); - } else { - pten::ElemwiseGradComputeWithBroadcast( - dev_ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, - dy_op); - } -} - // It is a common implementation to compute binary calculation with the support // of broadcast, supporting both CPU and GPU. // - CPU implementation cannot support the case when x needs broadcast, thus @@ -199,30 +173,20 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, int axis, Functor func, framework::Tensor *z) { + z->mutable_data(ctx.GetPlace()); if (platform::is_gpu_place(ctx.GetPlace())) { #if defined(__NVCC__) || defined(__HIPCC__) - std::vector ins = {x, y}; - std::vector outs = {z}; - z->mutable_data(ctx.GetPlace()); - const auto &dev_ctx = ctx.template device_context(); - paddle::operators::LaunchElementwiseCudaKernel(dev_ctx, ins, &outs, - axis, func); + pten::ElementwiseCompute(dev_ctx, *x, *y, axis, func, + z); + #endif return; } - - z->mutable_data(ctx.GetPlace()); - auto pt_x = paddle::experimental::MakePtenDenseTensor(*x); - auto pt_y = paddle::experimental::MakePtenDenseTensor(*y); - auto pt_z = paddle::experimental::MakePtenDenseTensor(*z); - const auto &dev_ctx = ctx.template device_context(); - pten::ElementwiseCompute( - dev_ctx, *pt_x.get(), *pt_y.get(), axis, func, pt_z.get()); + pten::ElementwiseCompute(dev_ctx, *x, *y, axis, func, z); } // FusedElemwiseAndAct @@ -1207,36 +1171,16 @@ template static inline void GetDoubleGradSafeTensor( const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *ddx, framework::Tensor *ddx_safe) { - if (ddx) { - *ddx_safe = *ddx; - } else { - auto &dev_ctx = ctx.template device_context(); - *ddx_safe = ctx.AllocateTmpTensor(x->dims(), dev_ctx); - math::SetConstant set_zero; - set_zero(ctx.template device_context(), ddx_safe, - static_cast(0)); - } + const auto &dev_ctx = ctx.template device_context(); + pten::funcs::GetDoubleGradSafeTensor(dev_ctx, *x, ddx, + ddx_safe); } // for broadcast backwards static inline std::vector GetReduceDim(const framework::DDim &in, const framework::DDim &out, int axis) { - axis = - (axis == -1 ? std::abs(static_cast(out.size() - in.size())) : axis); - std::vector dims; - for (int i = 0; i < axis; ++i) { - dims.push_back(i); - } - for (int i = 0; i < in.size(); ++i) { - if (out[i + axis] != in[i]) { - dims.push_back(i + axis); - } - } - for (int i = axis + in.size(); i < out.size(); ++i) { - dims.push_back(i); - } - return dims; + return pten::funcs::GetReduceDim(in, out, axis); } #if defined(__NVCC__) || defined(__HIPCC__) diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 7d1749f20a..8fc6038ab6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -78,9 +78,11 @@ default_elementwise_sub_grad(const framework::ExecutionContext& ctx, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { int axis = ctx.Attr("axis"); - - ElemwiseExplicitGradCompute, SubGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX(), SubGradDY()); + const auto& dev_ctx = + ctx.template device_context(); + pten::ElemwiseExplicitGradCompute, SubGradDY>( + dev_ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX(), + SubGradDY()); } template diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index f2d1e79f03..2672d02db0 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -29,6 +29,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/math_function_impl.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/pten/backends/cpu/cpu_context.h" #include "paddle/pten/kernels/funcs/eigen/common.h" #include "unsupported/Eigen/CXX11/Tensor" @@ -52,6 +53,18 @@ template struct SetConstant>; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant; +template struct SetConstant>; +template struct SetConstant>; + #ifdef PADDLE_WITH_XPU template struct SetConstant; template struct SetConstant; diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 8e089970f9..e473861dcf 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -21,6 +21,7 @@ namespace pten { // the key is sorted by key's alphabet const std::unordered_map kernel_alias_name_map = { {"elementwise_add", "add_raw"}, + {"elementwise_add_grad", "add_grad"}, {"elementwise_div", "divide_raw"}, {"elementwise_mul", "muliply_raw"}, {"elementwise_sub", "subtract_raw"}, diff --git a/paddle/pten/kernels/CMakeLists.txt b/paddle/pten/kernels/CMakeLists.txt index 999f72a7e6..615b80be59 100644 --- a/paddle/pten/kernels/CMakeLists.txt +++ b/paddle/pten/kernels/CMakeLists.txt @@ -9,7 +9,7 @@ add_subdirectory(funcs) set_property(GLOBAL PROPERTY PTEN_KERNELS "") set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils) -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function) # remove this dep after removing fluid deps on tensor creation set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) diff --git a/paddle/pten/kernels/cpu/elementwise.h b/paddle/pten/kernels/cpu/elementwise.h index 6bfde977ce..179a188118 100644 --- a/paddle/pten/kernels/cpu/elementwise.h +++ b/paddle/pten/kernels/cpu/elementwise.h @@ -706,4 +706,94 @@ void ElemwiseGradComputeWithBroadcast(const CPUContext& ctx, } } +// NOTE(dzhwinter): Only used in elementwise_add, elementwise_sub. +// explicit gradient can cut off X, Y, Out from gradient op +// In elementwise_add, elementwise_sub, we use dout as fake X, Y, Out to reuse +// elementwise code. +template +void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy, + DX_OP dx_op, + DY_OP dy_op) { + const DDim& x_dim = x.dims(); + const DDim& y_dim = y.dims(); + if (x.dims() == y.dims()) { + pten::funcs::ElemwiseGradComputeNoBroadcast( + dev_ctx, + x_dim, + y_dim, + dout, + dout, + out, + dout, + axis, + dx, + dy, + dx_op, + dy_op); + } else { + ElemwiseGradComputeWithBroadcast(dev_ctx, + x_dim, + y_dim, + dout, + dout, + out, + dout, + axis, + dx, + dy, + dx_op, + dy_op); + } +} + +// Add Grad + +template +struct IdentityGrad { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; } +}; + +template +typename std::enable_if::value>::type +elementwise_add_grad(const CPUContext& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int axis = -1) { + auto blas = paddle::operators::math::GetBlas(ctx); + if (dx) { + blas.VCOPY( + dout.numel(), dout.data(), dx->mutable_data(ctx.GetPlace())); + } + + if (dy) { + blas.VCOPY( + dout.numel(), dout.data(), dy->mutable_data(ctx.GetPlace())); + } +} + +template +typename std::enable_if::value>::type +elementwise_add_grad(const CPUContext& ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int axis = -1) { + ElemwiseExplicitGradCompute, IdentityGrad>( + ctx, x, y, out, dout, axis, dx, dy, IdentityGrad(), IdentityGrad()); +} + } // namespace pten diff --git a/paddle/pten/kernels/cpu/elementwise_grad_kernel.cc b/paddle/pten/kernels/cpu/elementwise_grad_kernel.cc new file mode 100644 index 0000000000..4a940c2be1 --- /dev/null +++ b/paddle/pten/kernels/cpu/elementwise_grad_kernel.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/elementwise_grad_kernel.h" + +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/copy_kernel.h" +#include "paddle/pten/kernels/cpu/elementwise.h" +#include "paddle/pten/kernels/funcs/elementwise_functor.h" +#include "paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h" + +namespace pten { + +template +void AddGradFunc(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int axis = -1) { + if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { + elementwise_add_grad(dev_ctx, x, y, out, dout, dx, dy); + } else { + ElemwiseExplicitGradCompute, IdentityGrad>( + dev_ctx, + x, + y, + out, + dout, + axis, + dx, + dy, + IdentityGrad(), + IdentityGrad()); + } +} + +template +void AddGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + pten::AddGradImpl(dev_ctx, x, y, dout, axis, dx, dy, AddGradFunc); +} + +template +void AddDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& y, + paddle::optional ddx, + paddle::optional ddy, + const DenseTensor& dout, + int axis, + DenseTensor* ddout) { + pten::AddDoubleGradImpl( + dev_ctx, + y, + ddx, + ddy, + dout, + axis, + ddout, + ElementwiseCompute, T>, + ElementwiseCompute, T>); +} + +template +void AddTripleGradKernel(const Context& dev_ctx, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_ddout, + int axis, + DenseTensor* d_ddx, + DenseTensor* d_ddy) { + pten::AddGradImpl( + dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc); +} + +} // namespace pten + +PT_REGISTER_KERNEL(add_grad, + CPU, + ALL_LAYOUT, + pten::AddGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(add_double_grad, + CPU, + ALL_LAYOUT, + pten::AddDoubleGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(add_triple_grad, + CPU, + ALL_LAYOUT, + pten::AddTripleGradKernel, + float, + double, + int, + int64_t, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/elementwise_grad_kernel.h b/paddle/pten/kernels/elementwise_grad_kernel.h new file mode 100644 index 0000000000..067eebc9e1 --- /dev/null +++ b/paddle/pten/kernels/elementwise_grad_kernel.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/utils/optional.h" + +namespace pten { + +template +void AddGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy); + +template +void AddDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& y, + paddle::optional ddx, + paddle::optional ddy, + const DenseTensor& dout, + int axis, + DenseTensor* ddout); + +template +void AddTripleGradKernel(const Context& dev_ctx, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_ddout, + int axis, + DenseTensor* d_ddx, + DenseTensor* d_ddy); + +} // namespace pten diff --git a/paddle/pten/kernels/funcs/elementwise_base.h b/paddle/pten/kernels/funcs/elementwise_base.h index 1c18e9f799..206ad151c5 100644 --- a/paddle/pten/kernels/funcs/elementwise_base.h +++ b/paddle/pten/kernels/funcs/elementwise_base.h @@ -14,10 +14,12 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/transform.h" #include "paddle/pten/backends/all_context.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/empty_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" @@ -360,6 +362,43 @@ inline void get_mid_dims(const DDim &x_dims, } } +// for broadcast backwards +static inline std::vector GetReduceDim(const paddle::framework::DDim &in, + const paddle::framework::DDim &out, + int axis) { + axis = + (axis == -1 ? std::abs(static_cast(out.size() - in.size())) : axis); + std::vector dims; + for (int i = 0; i < axis; ++i) { + dims.push_back(i); + } + for (int i = 0; i < in.size(); ++i) { + if (out[i + axis] != in[i]) { + dims.push_back(i + axis); + } + } + for (int i = axis + in.size(); i < out.size(); ++i) { + dims.push_back(i); + } + return dims; +} + +template +static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx, + const DenseTensor &x, + const DenseTensor *ddx, + DenseTensor *ddx_safe) { + if (ddx) { + *ddx_safe = *ddx; + } else { + auto meta = pten::DenseTensorMeta(x.dtype(), x.dims(), x.layout()); + *ddx_safe = pten::Empty(dev_ctx, std::move(meta)); + ddx_safe->mutable_data(dev_ctx.GetPlace()); + paddle::operators::math::SetConstant set_zero; + set_zero(dev_ctx, ddx_safe, static_cast(0)); + } +} + template mutable_data(dev_ctx.GetPlace())}); } +inline void ElementwiseGradPreProcess(const DenseTensor &dout, + DenseTensor *dx) { + if (dx != nullptr) { + dx->set_lod(dout.lod()); + } +} + #if defined(__NVCC__) || defined(__HIPCC__) template diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h index f4d8e442fc..9a3ae7f12d 100644 --- a/paddle/pten/kernels/gpu/elementwise.h +++ b/paddle/pten/kernels/gpu/elementwise.h @@ -14,9 +14,11 @@ limitations under the License. */ #pragma once +#include "paddle/pten/kernels/copy_kernel.h" #include "paddle/pten/kernels/funcs/common_shape.h" #include "paddle/pten/kernels/funcs/cuda_kernel_config.h" #include "paddle/pten/kernels/funcs/elementwise_base.h" +#include "paddle/pten/kernels/gpu/reduce.h" #ifdef __HIPCC__ constexpr int ELEMWISE_MAX_BLOCK_DIM = 256; @@ -578,6 +580,20 @@ void LaunchElementwiseCudaKernel(const KPDevice &ctx, } } +template +void ElementwiseCompute(const GPUContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + int axis, + Functor func, + DenseTensor *z) { + std::vector ins = {&x, &y}; + std::vector outs = {z}; + z->mutable_data(dev_ctx.GetPlace()); + pten::LaunchElementwiseCudaKernel( + dev_ctx, ins, &outs, axis, func); +} + // BACKWARD CODE // Suppose only has contiguous dims @@ -1938,4 +1954,130 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx, } } +template +static __global__ void SimpleElemwiseAddGradCUDAKernel( + const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + int loop = size / vec_size; + int remainder = size % vec_size; + const float4 *dout_vec = reinterpret_cast(dout); + float4 *dx_vec = reinterpret_cast(dx); + float4 *dy_vec = reinterpret_cast(dy); + float4 tmp_loop; + + for (int i = tid; i < loop; i += stride) { + tmp_loop = dout_vec[i]; + dx_vec[i] = tmp_loop; + dy_vec[i] = tmp_loop; + } + + if (tid == loop && remainder != 0) { + T tmp_rem; + while (remainder) { + int idx = size - remainder; + remainder--; + tmp_rem = dout[idx]; + dx[idx] = tmp_rem; + dy[idx] = tmp_rem; + } + } +} + +template +void default_elementwise_add_grad(const GPUContext &ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int axis = -1) { + auto *dout_data = dout.data(); + + // dx + if (dx != nullptr) { + auto *dx_data = dx->mutable_data(ctx.GetPlace()); + if (dx->dims() == dout.dims()) { + if (dx_data != dout_data) { + pten::Copy(ctx, dout, false, dx); + } + } else { + // For inplace strategy, dx will be stored in addr of dout, which makes + // the result of dy wrong. + if (dx->IsSharedBufferWith(dout)) { + dx->clear(); + dx->mutable_data(x.dims(), ctx.GetPlace()); + } + std::vector reduce_dims = + funcs::GetReduceDim(x.dims(), out.dims(), axis); + gpuStream_t stream = ctx.stream(); + kernels::TensorReduceFunctorImpl>( + dout, dx, kps::IdentityFunctor(), reduce_dims, stream); + } + } + // dy + if (dy != nullptr) { + auto *dy_data = dy->mutable_data(ctx.GetPlace()); + if (dy->dims() == dout.dims()) { + if (dy_data != dout_data) { + pten::Copy(ctx, dout, false, dy); + } + } else { + std::vector reduce_dims = + funcs::GetReduceDim(y.dims(), out.dims(), axis); + gpuStream_t stream = ctx.stream(); + kernels::TensorReduceFunctorImpl>( + dout, dy, kps::IdentityFunctor(), reduce_dims, stream); + } + } +} + +template +void elementwise_add_grad(const GPUContext &ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &out, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy) { + auto *dx_data = dx->mutable_data(ctx.GetPlace()); + auto *dy_data = dy->mutable_data(ctx.GetPlace()); + auto *dout_data = dout.data(); + if (dx_data == dout_data && dy_data != dout_data) { + VLOG(4) << "Special case when dx_data is the same as dout_data, " + "only need copy dout to dy"; + pten::Copy(ctx, dout, false, dy); + } else if (dx_data != dout_data && dy_data == dout_data) { + VLOG(4) << "Special case when dy_data is the same as dout_data, " + "only need copy dout to dx"; + pten::Copy(ctx, dout, false, dx); + } else if (dx_data != dout_data && dy_data != dout_data) { + auto size = x.numel(); + int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); + dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1); + dim3 grid_size = + dim3(((size + vec_size - 1) / vec_size + PREDEFINED_BLOCK_SIZE - 1) / + PREDEFINED_BLOCK_SIZE, + 1); + SimpleElemwiseAddGradCUDAKernel< + T><<>>( + dout.data(), + size, + vec_size, + dx->mutable_data(ctx.GetPlace()), + dy->mutable_data(ctx.GetPlace())); + } else { + VLOG(4) << "Special case when dy_data is the same as dout_data, " + "and dx_data is the same as dout_data, do not need " + "any operator"; + } +} + } // namespace pten diff --git a/paddle/pten/kernels/gpu/elementwise_grad_kernel.cu b/paddle/pten/kernels/gpu/elementwise_grad_kernel.cu new file mode 100644 index 0000000000..76af94f42f --- /dev/null +++ b/paddle/pten/kernels/gpu/elementwise_grad_kernel.cu @@ -0,0 +1,121 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/elementwise_grad_kernel.h" + +#include "paddle/pten/backends/gpu/gpu_context.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/copy_kernel.h" +#include "paddle/pten/kernels/funcs/elementwise_functor.h" +#include "paddle/pten/kernels/gpu/elementwise.h" +#include "paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h" + +namespace pten { + +template +void AddGradFunc(const GPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + DenseTensor* dx, + DenseTensor* dy, + int axis = -1) { + if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { + elementwise_add_grad(dev_ctx, x, y, out, dout, dx, dy); + } else { + default_elementwise_add_grad(dev_ctx, x, y, out, dout, dx, dy, axis); + } +} + +template +void AddGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + pten::AddGradImpl(dev_ctx, x, y, dout, axis, dx, dy, AddGradFunc); +} + +template +void AddDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& y, + paddle::optional ddx, + paddle::optional ddy, + const DenseTensor& dout, + int axis, + DenseTensor* ddout) { + pten::AddDoubleGradImpl( + dev_ctx, + y, + ddx, + ddy, + dout, + axis, + ddout, + ElementwiseCompute, T>, + ElementwiseCompute, T>); +} + +template +void AddTripleGradKernel(const Context& dev_ctx, + const DenseTensor& ddx, + const DenseTensor& ddy, + const DenseTensor& d_ddout, + int axis, + DenseTensor* d_ddx, + DenseTensor* d_ddy) { + pten::AddGradImpl( + dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc); +} + +} // namespace pten + +PT_REGISTER_KERNEL(add_grad, + GPU, + ALL_LAYOUT, + pten::AddGradKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(add_double_grad, + GPU, + ALL_LAYOUT, + pten::AddDoubleGradKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} + +PT_REGISTER_KERNEL(add_triple_grad, + GPU, + ALL_LAYOUT, + pten::AddTripleGradKernel, + float, + double, + int, + int64_t, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h new file mode 100644 index 0000000000..a74c9c0b6b --- /dev/null +++ b/paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/kernels/funcs/elementwise_base.h" +#include "paddle/pten/kernels/funcs/elementwise_functor.h" + +namespace pten { + +template +void AddGradImpl(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out_grad, + int axis, + DenseTensor* x_grad, + DenseTensor* y_grad, + GradFunc grad_func) { + pten::funcs::ElementwiseGradPreProcess(out_grad, x_grad); + auto* out = &out_grad; + // Special case when y_grad is not needed and x_grad doesn't reduce + if (x_grad != nullptr && y_grad == nullptr && + x_grad->dims() == out_grad.dims()) { + VLOG(4) << "Special case when y_grad is not needed and x_grad doesn't " + "reduce"; + pten::Copy(dev_ctx, out_grad, false, x_grad); + } else if (x_grad == nullptr && y_grad != nullptr && + y_grad->dims() == out_grad.dims()) { + VLOG(4) << "Special case when x_grad is not needed and y_grad doesn't " + "reduce"; + pten::Copy(dev_ctx, out_grad, false, y_grad); + } else { + grad_func(dev_ctx, x, y, *out, out_grad, x_grad, y_grad, axis); + } +} + +template +void AddDoubleGradImpl(const Context& dev_ctx, + const DenseTensor& y, + const paddle::optional& ddx, + const paddle::optional& ddy, + const DenseTensor& dout, + int axis, + DenseTensor* ddout, + GradFunc grad_func, + GradInverseFunc grad_inverse_func) { + // ddOut = ddx + ddy + if (ddout) { + DenseTensor ddx_safe, ddy_safe; + funcs::GetDoubleGradSafeTensor( + dev_ctx, dout, ddx.get_ptr(), &ddx_safe); + funcs::GetDoubleGradSafeTensor( + dev_ctx, y, ddy.get_ptr(), &ddy_safe); + + ddout->mutable_data(dev_ctx.GetPlace()); + auto ddx_dims = ddx_safe.dims(); + auto ddy_dims = ddy_safe.dims(); + if (ddx_dims.size() >= ddy_dims.size()) { + grad_func( + dev_ctx, ddx_safe, ddy_safe, axis, funcs::AddFunctor(), ddout); + } else { + grad_inverse_func(dev_ctx, + ddx_safe, + ddy_safe, + axis, + funcs::InverseAddFunctor(), + ddout); + } + } +} + +} // namespace pten -- GitLab