From 7a1e1193f5fcc91b904b0c775f01a64b50586f22 Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Thu, 27 Jan 2022 16:43:26 +0800
Subject: [PATCH] refactor elementwise sub grad (#39225)

---
 .../elementwise/elementwise_div_op.h          |  17 ++-
 .../elementwise/elementwise_sub_op.cu         |  97 --------------
 .../elementwise/elementwise_sub_op.h          | 120 ++++--------------
 paddle/pten/core/kernel_alias_name.h          |   1 +
 paddle/pten/kernels/cpu/elementwise.h         |  36 +++++-
 .../kernels/cpu/elementwise_grad_kernel.cc    |  54 ++++++++
 paddle/pten/kernels/elementwise_grad_kernel.h |  18 +++
 paddle/pten/kernels/gpu/elementwise.h         | 108 ++++++++++++++++
 .../kernels/gpu/elementwise_grad_kernel.cu    |  60 +++++++++
 .../impl/elementwise_grad_kernel_impl.h       |  23 ++++
 10 files changed, 336 insertions(+), 198 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
index a45f09b63e9..791e548372f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -16,11 +16,26 @@ limitations under the License. */
 
 #include
 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename DeviceContext, typename T>
+void default_elementwise_sub(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y, framework::Tensor* z) {
+  int axis = ctx.Attr<int>("axis");
+  auto x_dims = x->dims();
+  auto y_dims = y->dims();
+  if (x_dims.size() >= y_dims.size()) {
+    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          SubFunctor<T>(), z);
+  } else {
+    ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
+        ctx, x, y, axis, InverseSubFunctor<T>(), z);
+  }
+}
+
 template <typename DeviceContext, typename T>
 void default_elementwise_div(const framework::ExecutionContext& ctx,
                              const framework::Tensor* x,
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
index 8f094767877..038cefe372f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
@@ -17,103 +17,6 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-namespace paddle {
-namespace operators {
-
-template <typename T>
-static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
-                                                       int64_t size, T* dx,
-                                                       T* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    if (dx != nullptr) {
-      dx[col] = dout[col];
-    }
-    dy[col] = -dout[col];
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y,
-                             const framework::Tensor* out,
-                             const framework::Tensor* dout,
-                             framework::Tensor* dx, framework::Tensor* dy) {
-  int axis = ctx.Attr<int>("axis");
-  auto* dout_data = dout->data<T>();
-  // dx
-  if (dx != nullptr) {
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-    if (dx->dims() == dout->dims()) {
-      if (dx_data != dout_data) {
-        framework::TensorCopy(
-            *dout, ctx.GetPlace(),
-            ctx.template device_context<platform::DeviceContext>(), dx);
-      }
-    } else {
-      // For inplace strategy, dx will be stored in addr of dout, which makes
-      // the result of dy wrong.
-      if (dx->IsSharedBufferWith(*dout)) {
-        dx->clear();
-        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
-      }
-      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
-      gpuStream_t stream = ctx.cuda_device_context().stream();
-      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          *dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
-    }
-  }
-  // dy
-  if (dy != nullptr) {
-    auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
-    if (dy->dims() == dout->dims()) {
-      if (dy_data != dout_data) {
-        dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
-        auto size = dy->numel();
-        dim3 grid_size =
-            dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
-        SimpleElemwiseSubGradCUDAKernel<T><<<
-            grid_size, block_size, 0,
-            ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
-            dout->data<T>(), size, nullptr,
-            dy->mutable_data<T>(ctx.GetPlace()));
-      }
-    } else {
-      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
-      gpuStream_t stream = ctx.cuda_device_context().stream();
-      TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::InverseFunctor<T>>(
-          *dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
-    }
-  }
-}
-
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-elementwise_sub_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
-  dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
-  auto size = x->numel();
-  dim3 grid_size =
-      dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
-  SimpleElemwiseSubGradCUDAKernel<
-      T><<<grid_size, block_size, 0,
-           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
-      dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
-      dy->mutable_data<T>(ctx.GetPlace()));
-}
-
-}  // namespace operators
-}  // namespace paddle
-
 REGISTER_OP_CUDA_KERNEL(
     elementwise_sub,
     ops::ElementwiseSubKernel<plat::CUDADeviceContext, float>,
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h
index 8fc6038ab65..fce630512d4 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h
@@ -17,26 +17,11 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
 #include "paddle/fluid/platform/place.h"
+#include "paddle/pten/kernels/elementwise_grad_kernel.h"
 #include "paddle/pten/kernels/math_kernel.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T>
-void default_elementwise_sub(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y, framework::Tensor* z) {
-  int axis = ctx.Attr<int>("axis");
-  auto x_dims = x->dims();
-  auto y_dims = y->dims();
-  if (x_dims.size() >= y_dims.size()) {
-    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                          SubFunctor<T>(), z);
-  } else {
-    ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
-        ctx, x, y, axis, InverseSubFunctor<T>(), z);
-  }
-}
-
 template <typename DeviceContext, typename T>
 class ElementwiseSubKernel : public framework::OpKernel<T> {
  public:
@@ -48,76 +33,13 @@ class ElementwiseSubKernel : public framework::OpKernel<T> {
 
     auto& dev_ctx = ctx.device_context<DeviceContext>();
     int axis = ctx.Attr<int>("axis");
-    auto pt_x = paddle::experimental::MakePtenDenseTensor(*x);
-    auto pt_y = paddle::experimental::MakePtenDenseTensor(*y);
-    auto pt_z = paddle::experimental::MakePtenDenseTensor(*z);
     pten::SubtractRawKernel<T>(
         static_cast<const typename framework::ConvertToPtenContext<
             DeviceContext>::TYPE&>(dev_ctx),
-        *pt_x.get(), *pt_y.get(), axis, pt_z.get());
+        *x, *y, axis, z);
   }
 };
 
-template <typename T>
-struct SubGradDX {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
-};
-
-template <typename T>
-struct SubGradDY {
-  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
-};
-
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y,
-                             const framework::Tensor* out,
-                             const framework::Tensor* dout,
-                             framework::Tensor* dx, framework::Tensor* dy) {
-  int axis = ctx.Attr<int>("axis");
-  const auto& dev_ctx =
-      ctx.template device_context<platform::CPUDeviceContext>();
-  pten::ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
-      dev_ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(),
-      SubGradDY<T>());
-}
-
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_sub_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
-  default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
-}
-
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-// cuda definition
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
-                             const framework::Tensor* x,
-                             const framework::Tensor* y,
-                             const framework::Tensor* out,
-                             const framework::Tensor* dout,
-                             framework::Tensor* dx, framework::Tensor* dy);
-
-template <typename DeviceContext, typename T>
-typename std::enable_if<
-    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-elementwise_sub_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy);
-#endif
-
 template <typename DeviceContext, typename T>
 class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
  public:
@@ -130,14 +52,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    // skip out
-    auto* out = dout;
-    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
-      elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
-    } else {
-      default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
-                                                     dy);
-    }
+    int axis = ctx.Attr<int>("axis");
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
+
+    pten::SubtractGradKernel<T>(
+        static_cast<const typename framework::ConvertToPtenContext<
+            DeviceContext>::TYPE&>(dev_ctx),
+        *x, *y, *dout, axis, dx, dy);
   }
 };
 
@@ -153,18 +74,21 @@ class ElementwiseSubDoubleGradKernel : public framework::OpKernel<T> {
     auto* ddx = ctx.Input<Tensor>("DDX");
     auto* ddy = ctx.Input<Tensor>("DDY");
 
     auto* ddout = ctx.Output<Tensor>("DDOut");
+    int axis = ctx.Attr<int>("axis");
+    auto& dev_ctx = ctx.device_context<DeviceContext>();
 
-    // DDOut = ddx - ddy
-    if (ddout) {
-      Tensor ddx_safe, ddy_safe;
-      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dout, ddx, &ddx_safe);
-      GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
-
-      ddout->mutable_data<T>(ctx.GetPlace());
-      int axis = ctx.Attr<int>("axis");
-      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, &ddx_safe, &ddy_safe, axis, SubFunctor<T>(), ddout);
+    paddle::optional<const pten::DenseTensor&> ddx_optional = paddle::none;
+    paddle::optional<const pten::DenseTensor&> ddy_optional = paddle::none;
+    if (ddx != nullptr) {
+      ddx_optional = *ddx;
     }
+    if (ddy != nullptr) {
+      ddy_optional = *ddy;
+    }
+    pten::SubtractDoubleGradKernel<T>(
+        static_cast<const typename framework::ConvertToPtenContext<
+            DeviceContext>::TYPE&>(dev_ctx),
+        *y, ddx_optional, ddy_optional, *dout, axis, ddout);
   }
 };
diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h
index e473861dcf0..cfe3f757974 100644
--- a/paddle/pten/core/kernel_alias_name.h
+++ b/paddle/pten/core/kernel_alias_name.h
@@ -25,6 +25,7 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
     {"elementwise_div", "divide_raw"},
     {"elementwise_mul", "muliply_raw"},
     {"elementwise_sub", "subtract_raw"},
+    {"elementwise_sub_grad", "subtract_grad"},
     {"fill_any_like", "full_like"},
    {"fill_constant", "full"},
     {"flatten_contiguous_range", "flatten"},
diff --git a/paddle/pten/kernels/cpu/elementwise.h b/paddle/pten/kernels/cpu/elementwise.h
index 2d717414d70..0cd50be511a 100644
--- a/paddle/pten/kernels/cpu/elementwise.h
+++ b/paddle/pten/kernels/cpu/elementwise.h
@@ -743,8 +743,11 @@ void ElemwiseExplicitGradCompute(const CPUContext& dev_ctx,
   }
 }
 
-// Add Grad
-
+/*
+******************************
+    Add Grad
+******************************
+*/
 template <typename T>
 struct IdentityGrad {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
@@ -786,4 +789,33 @@ elementwise_add_grad(const CPUContext& ctx,
       ctx, x, y, out, dout, axis, dx, dy, IdentityGrad<T>(), IdentityGrad<T>());
 }
 
+/*
+******************************
+    Sub Grad
+******************************
+*/
+
+template <typename T>
+struct SubGradDX {
+  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout; }
+};
+
+template <typename T>
+struct SubGradDY {
+  HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
+};
+
+template <typename T>
+void elementwise_sub_grad(const CPUContext& ctx,
+                          const DenseTensor& x,
+                          const DenseTensor& y,
+                          const DenseTensor& out,
+                          const DenseTensor& dout,
+                          DenseTensor* dx,
+                          DenseTensor* dy,
+                          int axis = -1) {
+  ElemwiseExplicitGradCompute<T, SubGradDX<T>, SubGradDY<T>>(
+      ctx, x, y, out, dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
+}
+
 }  // namespace pten
diff --git a/paddle/pten/kernels/cpu/elementwise_grad_kernel.cc b/paddle/pten/kernels/cpu/elementwise_grad_kernel.cc
index 4a940c2be15..d3d3aa79edb 100644
--- a/paddle/pten/kernels/cpu/elementwise_grad_kernel.cc
+++ b/paddle/pten/kernels/cpu/elementwise_grad_kernel.cc
@@ -92,6 +92,38 @@ void AddTripleGradKernel(const Context& dev_ctx,
       dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc);
 }
 
+template <typename T, typename Context>
+void SubtractGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
+                        const DenseTensor& dout,
+                        int axis,
+                        DenseTensor* dx,
+                        DenseTensor* dy) {
+  // skip out
+  auto* out = &dout;
+  elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
+}
+
+template <typename T, typename Context>
+void SubtractDoubleGradKernel(const Context& dev_ctx,
+                              const DenseTensor& y,
+                              paddle::optional<const DenseTensor&> ddx,
+                              paddle::optional<const DenseTensor&> ddy,
+                              const DenseTensor& dout,
+                              int axis,
+                              DenseTensor* ddout) {
+  pten::SubtractDoubleGradImpl<T>(
+      dev_ctx,
+      y,
+      ddx,
+      ddy,
+      dout,
+      axis,
+      ddout,
+      ElementwiseCompute<funcs::SubtractFunctor<T>, T>);
+}
+
 }  // namespace pten
 
 PT_REGISTER_KERNEL(add_grad,
@@ -126,3 +158,25 @@ PT_REGISTER_KERNEL(add_triple_grad,
                    int64_t,
                    paddle::platform::complex<float>,
                    paddle::platform::complex<double>) {}
+
+PT_REGISTER_KERNEL(subtract_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SubtractGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+
+PT_REGISTER_KERNEL(subtract_double_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   pten::SubtractDoubleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
diff --git a/paddle/pten/kernels/elementwise_grad_kernel.h b/paddle/pten/kernels/elementwise_grad_kernel.h
index 067eebc9e15..a37c405d068 100644
--- a/paddle/pten/kernels/elementwise_grad_kernel.h
+++ b/paddle/pten/kernels/elementwise_grad_kernel.h
@@ -46,4 +46,22 @@ void AddTripleGradKernel(const Context& dev_ctx,
                          DenseTensor* d_ddx,
                          DenseTensor* d_ddy);
 
+template <typename T, typename Context>
+void SubtractGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
+                        const DenseTensor& dout,
+                        int axis,
+                        DenseTensor* dx,
+                        DenseTensor* dy);
+
+template <typename T, typename Context>
+void SubtractDoubleGradKernel(const Context& dev_ctx,
+                              const DenseTensor& y,
+                              paddle::optional<const DenseTensor&> ddx,
+                              paddle::optional<const DenseTensor&> ddy,
+                              const DenseTensor& dout,
+                              int axis,
+                              DenseTensor* ddout);
+
 }  // namespace pten
diff --git a/paddle/pten/kernels/gpu/elementwise.h b/paddle/pten/kernels/gpu/elementwise.h
index 6f744212cd5..f988f5abdb1 100644
--- a/paddle/pten/kernels/gpu/elementwise.h
+++ b/paddle/pten/kernels/gpu/elementwise.h
@@ -1952,6 +1952,12 @@ void ElemwiseGradComputeWithBroadcast(const GPUContext &ctx,
   }
 }
 
+/*
+******************************
+    Add Grad
+******************************
+*/
+
 template <typename T>
 static __global__ void SimpleElemwiseAddGradCUDAKernel(
     const T *__restrict__ dout, int size, int vec_size, T *dx, T *dy) {
@@ -2078,4 +2084,106 @@ void elementwise_add_grad(const GPUContext &ctx,
   }
 }
 
+/*
+******************************
+    Sub Grad
+******************************
+*/
+
+template <typename T>
+static __global__ void SimpleElemwiseSubGradCUDAKernel(const T *dout,
+                                                       int64_t size,
+                                                       T *dx,
+                                                       T *dy) {
+  int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  while (col < size) {
+    if (dx != nullptr) {
+      dx[col] = dout[col];
+    }
+    dy[col] = -dout[col];
+    col += blockDim.x * gridDim.x;
+  }
+}
+
+template <typename T>
+void default_elementwise_sub_grad(const GPUContext &ctx,
+                                  const DenseTensor &x,
+                                  const DenseTensor &y,
+                                  const DenseTensor &out,
+                                  const DenseTensor &dout,
+                                  DenseTensor *dx,
+                                  DenseTensor *dy,
+                                  int axis = -1) {
+  auto *dout_data = dout.data<T>();
+  // dx
+  if (dx != nullptr) {
+    auto *dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    if (dx->dims() == dout.dims()) {
+      if (dx_data != dout_data) {
+        pten::Copy(ctx, dout, false, dx);
+      }
+    } else {
+      // For inplace strategy, dx will be stored in addr of dout, which makes
+      // the result of dy wrong.
+      if (dx->IsSharedBufferWith(dout)) {
+        dx->clear();
+        dx->mutable_data<T>(x.dims(), ctx.GetPlace());
+      }
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(x.dims(), out.dims(), axis);
+      gpuStream_t stream = ctx.stream();
+      kernels::TensorReduceFunctorImpl<T, T, kps::AddFunctor,
+                                       kps::IdentityFunctor<T>>(
+          dout, dx, kps::IdentityFunctor<T>(), reduce_dims, stream);
+    }
+  }
+  // dy
+  if (dy != nullptr) {
+    auto *dy_data = dy->mutable_data<T>(ctx.GetPlace());
+    if (dy->dims() == dout.dims()) {
+      if (dy_data != dout_data) {
+        dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
+        auto size = dy->numel();
+        dim3 grid_size =
+            dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
+        SimpleElemwiseSubGradCUDAKernel<
+            T><<<grid_size, block_size, 0, ctx.stream()>>>(
+            dout.data<T>(), size, nullptr, dy->mutable_data<T>(ctx.GetPlace()));
+      }
+    } else {
+      std::vector<int> reduce_dims =
+          funcs::GetReduceDim(y.dims(), out.dims(), axis);
+      gpuStream_t stream = ctx.stream();
+      kernels::TensorReduceFunctorImpl<T, T, kps::AddFunctor,
+                                       kps::InverseFunctor<T>>(
+          dout, dy, kps::InverseFunctor<T>(), reduce_dims, stream);
+    }
+  }
+}
+
+template <typename T>
+void elementwise_sub_grad(const GPUContext &ctx,
+                          const DenseTensor &x,
+                          const DenseTensor &y,
+                          const DenseTensor &out,
+                          const DenseTensor &dout,
+                          DenseTensor *dx,
+                          DenseTensor *dy) {
+  dim3 block_size = dim3(PREDEFINED_BLOCK_SIZE, 1);
+  auto size = x.numel();
+  dim3 grid_size =
+      dim3((size + PREDEFINED_BLOCK_SIZE - 1) / PREDEFINED_BLOCK_SIZE, 1);
+  SimpleElemwiseSubGradCUDAKernel<
+      T><<<grid_size, block_size, 0, ctx.stream()>>>(
+      dout.data<T>(),
+      size,
+      dx->mutable_data<T>(ctx.GetPlace()),
+      dy->mutable_data<T>(ctx.GetPlace()));
+}
+
 }  // namespace pten
diff --git a/paddle/pten/kernels/gpu/elementwise_grad_kernel.cu b/paddle/pten/kernels/gpu/elementwise_grad_kernel.cu
index 76af94f42fd..f1b3f53b809 100644
--- a/paddle/pten/kernels/gpu/elementwise_grad_kernel.cu
+++ b/paddle/pten/kernels/gpu/elementwise_grad_kernel.cu
@@ -82,6 +82,42 @@ void AddTripleGradKernel(const Context& dev_ctx,
       dev_ctx, ddx, ddy, d_ddout, axis, d_ddx, d_ddy, AddGradFunc);
 }
 
+template <typename T, typename Context>
+void SubtractGradKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& y,
+                        const DenseTensor& dout,
+                        int axis,
+                        DenseTensor* dx,
+                        DenseTensor* dy) {
+  // skip out
+  auto* out = &dout;
+  if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
+    elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy);
+  } else {
+    default_elementwise_sub_grad<T>(dev_ctx, x, y, *out, dout, dx, dy, axis);
+  }
+}
+
+template <typename T, typename Context>
+void SubtractDoubleGradKernel(const Context& dev_ctx,
+                              const DenseTensor& y,
+                              paddle::optional<const DenseTensor&> ddx,
+                              paddle::optional<const DenseTensor&> ddy,
+                              const DenseTensor& dout,
+                              int axis,
+                              DenseTensor* ddout) {
+  pten::SubtractDoubleGradImpl<T>(
+      dev_ctx,
+      y,
+      ddx,
+      ddy,
+      dout,
+      axis,
+      ddout,
+      ElementwiseCompute<funcs::SubtractFunctor<T>, T>);
+}
+
 }  // namespace pten
 
 PT_REGISTER_KERNEL(add_grad,
@@ -119,3 +155,27 @@ PT_REGISTER_KERNEL(add_triple_grad,
                    paddle::platform::float16,
                    paddle::platform::complex<float>,
                    paddle::platform::complex<double>) {}
+
+PT_REGISTER_KERNEL(subtract_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::SubtractGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
+
+PT_REGISTER_KERNEL(subtract_double_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   pten::SubtractDoubleGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   paddle::platform::float16,
+                   paddle::platform::complex<float>,
+                   paddle::platform::complex<double>) {}
diff --git a/paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h
index a74c9c0b6be..e35f0891af6 100644
--- a/paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h
+++ b/paddle/pten/kernels/impl/elementwise_grad_kernel_impl.h
@@ -85,4 +85,27 @@ void AddDoubleGradImpl(const Context& dev_ctx,
   }
 }
 
+template <typename T, typename Context, typename GradFunc>
+void SubtractDoubleGradImpl(const Context& dev_ctx,
+                            const DenseTensor& y,
+                            const paddle::optional<const DenseTensor&>& ddx,
+                            const paddle::optional<const DenseTensor&>& ddy,
+                            const DenseTensor& dout,
+                            int axis,
+                            DenseTensor* ddout,
+                            GradFunc grad_func) {
+  // DDOut = ddx - ddy
+  if (ddout) {
+    DenseTensor ddx_safe, ddy_safe;
+    funcs::GetDoubleGradSafeTensor<Context, T>(
+        dev_ctx, dout, ddx.get_ptr(), &ddx_safe);
+    funcs::GetDoubleGradSafeTensor<Context, T>(
+        dev_ctx, y, ddy.get_ptr(), &ddy_safe);
+
+    ddout->mutable_data<T>(dev_ctx.GetPlace());
+    grad_func(
+        dev_ctx, ddx_safe, ddy_safe, axis, funcs::SubtractFunctor<T>(), ddout);
+  }
+}
+
 }  // namespace pten
-- 
GitLab
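
Reviewer note, not part of the patch: every grad kernel touched above computes the same arithmetic. For z = x - y, the backward pass is dx = dout and dy = -dout, with an extra sum-reduction over the broadcast axes (the TensorReduceFunctorImpl calls) whenever x or y was broadcast in the forward pass. A minimal standalone C++ sketch of the same-shape case follows; SubtractGradReference is a hypothetical name used only for illustration and is not part of the Paddle/pten API.

#include <cassert>
#include <cstddef>
#include <vector>

// z = x - y  =>  dL/dx = dout, dL/dy = -dout.
// This mirrors SubGradDX (returns dout) and SubGradDY (returns -dout) from
// the patch for same-shape operands; the broadcast case would additionally
// sum-reduce each result along the axes that were broadcast in the forward.
template <typename T>
void SubtractGradReference(const std::vector<T>& dout,
                           std::vector<T>* dx,
                           std::vector<T>* dy) {
  dx->assign(dout.begin(), dout.end());  // dx = dout
  dy->resize(dout.size());
  for (std::size_t i = 0; i < dout.size(); ++i) {
    (*dy)[i] = -dout[i];                 // dy = -dout
  }
}

int main() {
  std::vector<float> dout = {1.0f, -2.0f, 3.0f};
  std::vector<float> dx, dy;
  SubtractGradReference(dout, &dx, &dy);
  assert(dx[1] == -2.0f && dy[1] == 2.0f);
  return 0;
}

This is also why SubtractGradKernel keeps a fast path: when dx and dy have the same dims as dout, no reduction is needed, and a single elementwise CUDA kernel (SimpleElemwiseSubGradCUDAKernel) fills both outputs in one pass.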