From 567e6bbc8063428b28ef0a3804178bcb7c3e9fb7 Mon Sep 17 00:00:00 2001
From: crystal <62974595+Zjq9409@users.noreply.github.com>
Date: Wed, 8 Dec 2021 14:07:59 +0800
Subject: [PATCH] implementation of broadcast sub backward by reduce (#37754)

* add boardcast_sub

* add boardcast_sub
---
 .../elementwise/elementwise_sub_op.cu         | 61 ++++++++++++++++++-
 .../elementwise/elementwise_sub_op.h          | 35 ++++++++---
 .../kernel_primitives/functor_primitives.h    | 14 +++++
 .../operators/reduce_ops/reduce_functor_op.h  | 11 ++++
 4 files changed, 113 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
index 00562767c97..2b44c81a455 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/float16.h"
 
@@ -30,12 +32,69 @@ static __global__ void SimpleElemwiseSubGradCUDAKernel(const T* dout,
   int col = blockIdx.x * blockDim.x + threadIdx.x;
 
   while (col < size) {
-    dx[col] = dout[col];
+    if (dx != nullptr) {
+      dx[col] = dout[col];
+    }
     dy[col] = -dout[col];
     col += blockDim.x * gridDim.x;
   }
 }
 
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y,
+                             const framework::Tensor* out,
+                             const framework::Tensor* dout,
+                             framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  auto* dout_data = dout->data<T>();
+  // dx
+  if (dx != nullptr) {
+    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
+    if (dx->dims() == dout->dims()) {
+      if (dx_data != dout_data) {
+        framework::TensorCopy(
+            *dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dx);
+      }
+    } else {
+      // For inplace strategy, dx will be stored in addr of dout, which makes
+      // the result of dy wrong.
+      if (dx->IsSharedBufferWith(*dout)) {
+        dx->clear();
+        dx->mutable_data<T>(x->dims(), ctx.GetPlace());
+      }
+      std::vector<int> reduce_dims = GetReduceDim(x->dims(), out->dims(), axis);
+      gpuStream_t stream = ctx.cuda_device_context().stream();
+      TensorReduceFunctorImpl<T, T, CustomSum>(*dout, dx, reduce_dims, stream);
+    }
+  }
+  // dy
+  if (dy != nullptr) {
+    auto* dy_data = dy->mutable_data<T>(ctx.GetPlace());
+    if (dy->dims() == dout->dims()) {
+      if (dy_data != dout_data) {
+        dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
+        auto size = dy->numel();
+        dim3 grid_size = dim3(
+            (size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
+        SimpleElemwiseSubGradCUDAKernel<T><<<
+            grid_size, block_size, 0,
+            ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
+            dout->data<T>(), size, nullptr,
+            dy->mutable_data<T>(ctx.GetPlace()));
+      }
+    } else {
+      std::vector<int> reduce_dims = GetReduceDim(y->dims(), out->dims(), axis);
+      gpuStream_t stream = ctx.cuda_device_context().stream();
+      TensorReduceFunctorImpl<T, T, CustomSub>(*dout, dy, reduce_dims, stream);
+    }
+  }
+}
+
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h
index 94c8edf24a1..08a4e709a37 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h
@@ -71,6 +71,21 @@ struct SubGradDY {
   HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return -dout; }
 };
 
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y,
+                             const framework::Tensor* out,
+                             const framework::Tensor* dout,
+                             framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+
+  ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
+}
+
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
@@ -79,13 +94,21 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx,
                      const framework::Tensor* out,
                      const framework::Tensor* dout,
                      framework::Tensor* dx, framework::Tensor* dy) {
-  int axis = ctx.Attr<int>("axis");
-  ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
-      ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(), SubGradDY<T>());
+  default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
 }
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 // cuda definition
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+default_elementwise_sub_grad(const framework::ExecutionContext& ctx,
+                             const framework::Tensor* x,
+                             const framework::Tensor* y,
+                             const framework::Tensor* out,
+                             const framework::Tensor* dout,
+                             framework::Tensor* dx, framework::Tensor* dy);
+
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
@@ -108,15 +131,13 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
     // skip out
     auto* out = dout;
     if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
       elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
     } else {
-      ElemwiseExplicitGradCompute<DeviceContext, T, SubGradDX<T>, SubGradDY<T>>(
-          ctx, *x, *y, *out, *dout, axis, dx, dy, SubGradDX<T>(),
-          SubGradDY<T>());
+      default_elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx,
+                                                     dy);
     }
   }
 };
diff --git a/paddle/fluid/operators/kernel_primitives/functor_primitives.h b/paddle/fluid/operators/kernel_primitives/functor_primitives.h
index 3fce3b1c092..d7aed8595ba 100644
--- a/paddle/fluid/operators/kernel_primitives/functor_primitives.h
+++ b/paddle/fluid/operators/kernel_primitives/functor_primitives.h
@@ -86,6 +86,20 @@ struct DivideFunctor {
   Tx n_inv;
 };
 
+/**
+ * @brief Default inverse functor
+ */
+template <typename Tx, typename Ty = Tx>
+struct InverseFunctor {
+  HOSTDEVICE inline InverseFunctor() {}
+
+  HOSTDEVICE explicit inline InverseFunctor(int n) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& x) const {
+    return static_cast<Ty>(-x);
+  }
+};
+
 /**
  * @brief Default unary square functor
  */
diff --git a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
index 90adea60927..dc79666b72f 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
+++ b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h
@@ -64,6 +64,17 @@ struct CustomSum {
   }
 };
 
+template <typename Tx, typename Ty = Tx>
+struct CustomSub {
+  using Transformer = kps::InverseFunctor<Tx, Ty>;
+
+  inline Ty initial() { return static_cast<Ty>(0.0f); }
+
+  __device__ __forceinline__ Ty operator()(const Ty &a, const Ty &b) const {
+    return b + a;
+  }
+};
+
 template <typename Tx, typename Ty = Tx>
 struct CustomMean {
   using Transformer = kps::DivideFunctor<Tx, Ty>;
-- 
GitLab
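Reviewer note on the semantics the patch implements (not part of the patch): for out = x - y with y broadcast against x, the gradients are dx = dout (reduced over any dimensions broadcast on the x side) and dy = -reduce_sum(dout) over the dimensions broadcast on the y side; CustomSub together with kps::InverseFunctor carries out that negated reduction on the GPU. The standalone C++ sketch below illustrates this reference behaviour for a [rows, cols] output with y of shape [cols]; it uses no Paddle API and every name in it is illustrative only.

#include <cstdio>
#include <vector>

// Reference semantics of the reduce-based backward: dx keeps the shape of
// dout, while dy[j] = -sum_i dout[i][j], i.e. a negated reduction over the
// broadcast (row) dimension.
void sub_grad_broadcast_ref(const std::vector<float>& dout, int rows, int cols,
                            std::vector<float>* dx, std::vector<float>* dy) {
  dx->assign(dout.begin(), dout.end());  // dx has the same shape as dout
  dy->assign(cols, 0.0f);                // dy collapses the broadcast axis
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      (*dy)[j] -= dout[i * cols + j];    // accumulate -dout over the rows
    }
  }
}

int main() {
  // dout for a 2x3 output; the expected dy is {-5, -7, -9}.
  std::vector<float> dout = {1, 2, 3, 4, 5, 6};
  std::vector<float> dx, dy;
  sub_grad_broadcast_ref(dout, 2, 3, &dx, &dy);
  for (float v : dy) printf("%g ", v);
  printf("\n");
  return 0;
}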