From 55cd9cb867566c58876d671f4c35039722fd1b34 Mon Sep 17 00:00:00 2001
From: crystal <62974595+Zjq9409@users.noreply.github.com>
Date: Wed, 5 Jan 2022 10:27:37 +0800
Subject: [PATCH] implementation of broadcast div backward by reduce (#38044)

* add elementwise div
* move mul and div grad functor
* Combine multiple CUDA kernels
* Update the reduce interface call
* add multi-output
* add multi-output div
* add branch judge
* Package branch
* Combine the x and y functions into one
---
 .../elementwise/elementwise_div_op.cu         | 109 +++++-------
 .../elementwise/elementwise_div_op.h          |  29 ++---
 .../elementwise/elementwise_functor.h         |  67 +++++++++++
 .../elementwise/elementwise_op_function.h     |  73 ++++++++++++
 4 files changed, 181 insertions(+), 97 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
index 80089243f25..7a25f653669 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu
@@ -13,9 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/float16.h"
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
@@ -23,83 +20,39 @@ namespace plat = paddle::platform;
 namespace paddle {
 namespace operators {
 
-template <typename T>
-static __global__ void SimpleElemwiseDivGradCUDAKernel(const T* x, const T* y,
-                                                       const T* out,
-                                                       const T* dout,
-                                                       int64_t size, T* dx,
-                                                       T* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    T o = dout[col];
-    dx[col] = o / y[col];
-    dy[col] = -o * out[col] / y[col];
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <>
-__global__ void
-SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<float>>(
-    const paddle::platform::complex<float>* x,
-    const paddle::platform::complex<float>* y,
-    const paddle::platform::complex<float>* out,
-    const paddle::platform::complex<float>* dout, int64_t size,
-    paddle::platform::complex<float>* dx,
-    paddle::platform::complex<float>* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    paddle::platform::complex<float> o = dout[col];
-    paddle::platform::complex<float> y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex<float> out_div_y_conj((out[col] / y[col]).real,
-                                                    -(out[col] / y[col]).imag);
-    dx[col] = o / y_conj;
-    dy[col] = -o * out_div_y_conj;
-    col += blockDim.x * gridDim.x;
-  }
-}
-
-template <>
-__global__ void
-SimpleElemwiseDivGradCUDAKernel<paddle::platform::complex<double>>(
-    const paddle::platform::complex<double>* x,
-    const paddle::platform::complex<double>* y,
-    const paddle::platform::complex<double>* out,
-    const paddle::platform::complex<double>* dout, int64_t size,
-    paddle::platform::complex<double>* dx,
-    paddle::platform::complex<double>* dy) {
-  int col = blockIdx.x * blockDim.x + threadIdx.x;
-
-  while (col < size) {
-    paddle::platform::complex<double> o = dout[col];
-    paddle::platform::complex<double> y_conj(y[col].real, -y[col].imag);
-    paddle::platform::complex<double> out_div_y_conj((out[col] / y[col]).real,
-                                                     -(out[col] / y[col]).imag);
-    dx[col] = o / y_conj;
-    dy[col] = -o * out_div_y_conj;
-    col += blockDim.x * gridDim.x;
-  }
-}
-
 template <typename DeviceContext, typename T>
 typename std::enable_if<
-    std::is_same<DeviceContext, plat::CUDADeviceContext>::value>::type
-elementwise_div_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
-  dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1);
-  auto size = x->numel();
-  dim3 grid_size =
-      dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1);
-  SimpleElemwiseDivGradCUDAKernel<
-      T><<<grid_size, block_size, 0,
-           ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
-      x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
-      dx->mutable_data<T>(ctx.GetPlace()), dy->mutable_data<T>(ctx.GetPlace()));
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseDivGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  const auto place = ctx.GetPlace();
+  if (dx != nullptr && dy != nullptr) {
+    dx->mutable_data<T>(place);
+    if (dx->IsSharedBufferWith(*dout)) {
+      dx->clear();
+      dx->mutable_data<T>(x->dims(), place);
+    }
+    std::vector<const framework::Tensor*> ins = {dout, out, y};
+    GetGradXAndYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dx, dy, DivGradXYFunctor<T, T>());
+  } else if (dx != nullptr && dy == nullptr) {
+    dx->mutable_data<T>(place);
+    if (dx->IsSharedBufferWith(*dout)) {
+      dx->clear();
+      dx->mutable_data<T>(x->dims(), place);
+    }
+    std::vector<const framework::Tensor*> ins = {dout, y};
+    GetGradXOrYOut<ElementwiseType::kBinary, T>(dev_ctx, place, axis, ins,
+                                                dout, dx, DivGradXFunctor<T>());
+  } else if (dy != nullptr && dx == nullptr) {
+    std::vector<const framework::Tensor*> ins = {dout, out, y};
+    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dy, DivGradYFunctor<T>());
+  }
 }
 
 }  // namespace operators
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
index c886644bbdd..b13a0539ec6 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -111,26 +111,24 @@ struct DivDoubleDY {
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
-elementwise_div_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy) {
+ElementwiseDivGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
   int axis = ctx.Attr<int>("axis");
+
   ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
       ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(), DivGradDY<T>());
 }
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-// cuda definition
 template <typename DeviceContext, typename T>
 typename std::enable_if<
     std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
-elementwise_div_grad(const framework::ExecutionContext& ctx,
-                     const framework::Tensor* x, const framework::Tensor* y,
-                     const framework::Tensor* out,
-                     const framework::Tensor* dout, framework::Tensor* dx,
-                     framework::Tensor* dy);
+ElementwiseDivGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy);
 #endif
 
 template <typename DeviceContext, typename T>
@@ -146,15 +144,8 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    int axis = ctx.Attr<int>("axis");
 
-    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
-      elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
-    } else {
-      ElemwiseGradCompute<DeviceContext, T, DivGradDX<T>, DivGradDY<T>>(
-          ctx, *x, *y, *out, *dout, axis, dx, dy, DivGradDX<T>(),
-          DivGradDY<T>());
-    }
+    ElementwiseDivGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
   }
 };
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h
index 7ff8e6a1543..b7bebcaa386 100644
--- a/paddle/fluid/operators/elementwise/elementwise_functor.h
+++ b/paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/framework/array.h"
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/hostdevice.h"
@@ -87,6 +89,71 @@ struct MinFunctor {
   }
 };
 
+template <typename T>
+using Complex = paddle::platform::complex<T>;
+
+template <typename InT, typename OutT>
+struct DivGradXYFunctor {
+  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(const InT a,
+                                                                 const InT b,
+                                                                 const InT c) {
+    // dx = dout / y
+    // dy = - dout * out / y
+    paddle::framework::Array<OutT, 2> outs;
+    outs[0] = a / c;
+    outs[1] = -a * b / c;
+    return outs;
+  }
+};
+
+template <typename InT, typename OutT>
+struct DivGradXYFunctor<Complex<InT>, Complex<OutT>> {
+  inline HOSTDEVICE paddle::framework::Array<Complex<OutT>, 2> operator()(
+      const Complex<InT> a, const Complex<InT> b, const Complex<InT> c) {
+    paddle::framework::Array<Complex<OutT>, 2> outs;
+    Complex<InT> c_conj(c.real, -c.imag);
+    Complex<InT> out_div_c_conj((b / c).real, -(b / c).imag);
+    outs[0] = a / c_conj;
+    outs[1] = -a * out_div_c_conj;
+    return outs;
+  }
+};
+
+// Float div grad
+template <typename T>
+struct DivGradXFunctor {
+  inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; }
+};
+
+// Complex div grad
+template <typename T>
+struct DivGradXFunctor<Complex<T>> {
+  inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a,
+                                          const Complex<T>& b) const {
+    Complex<T> b_conj(b.real, -b.imag);
+    return a / b_conj;
+  }
+};
+
+// Float mul and div
+template <typename T>
+struct DivGradYFunctor {
+  inline HOSTDEVICE T operator()(const T& a, const T& b, const T& c) const {
+    return -a * b / c;
+  }
+};
+
+// Complex mul and div
+template <typename T>
+struct DivGradYFunctor<Complex<T>> {
+  inline HOSTDEVICE Complex<T> operator()(const Complex<T>& a,
+                                          const Complex<T>& b,
+                                          const Complex<T>& c) const {
+    Complex<T> out_div_c_conj((b / c).real, -(b / c).imag);
+    return -a * out_div_c_conj;
+  }
+};
+
 // Fmax
 template <typename T>
 struct FMaxFunctor {
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index 6f3e17ea4d4..a145848bad9 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -42,6 +42,7 @@ limitations under the License. */
 #include <thrust/iterator/iterator_adaptor.h>
 
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 
@@ -2556,5 +2557,77 @@ static inline std::vector<int> GetReduceDim(const framework::DDim &in,
   }
   return dims;
 }
+
+#if defined(__NVCC__) || defined(__HIPCC__)
+template <typename T>
+void ReduceWrapper(const platform::CUDADeviceContext &dev_ctx, int axis,
+                   framework::Tensor *src, framework::Tensor *dst) {
+  std::vector<int> reduce_dims = GetReduceDim(dst->dims(), src->dims(), axis);
+  TensorReduceFunctorImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+      *src, dst, kps::IdentityFunctor<T>(), reduce_dims, dev_ctx.stream());
+}
+
+template <ElementwiseType ET, typename T, typename Functor>
+void GetGradXAndYOut(const platform::CUDADeviceContext &dev_ctx,
+                     const platform::Place &place, int axis,
+                     std::vector<const framework::Tensor *> ins,
+                     const framework::Tensor *dout, framework::Tensor *dx,
+                     framework::Tensor *dy, Functor func) {
+  framework::Tensor tmp_dx;
+  framework::Tensor tmp_dy;
+  dy->mutable_data<T>(place);
+  std::vector<framework::Tensor *> outs;
+  if (dx->dims() == dout->dims() && dy->dims() == dout->dims()) {
+    outs = {dx, dy};
+  } else if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
+    tmp_dx.mutable_data<T>(dout->dims(), place);
+    outs = {&tmp_dx, dy};
+  } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
+    tmp_dy.mutable_data<T>(dout->dims(), place);
+    outs = {dx, &tmp_dy};
+  } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
+    tmp_dy.mutable_data<T>(dout->dims(), place);
+    tmp_dx.mutable_data<T>(dout->dims(), place);
+    outs = {&tmp_dx, &tmp_dy};
+  }
+
+  LaunchElementwiseCudaKernel<ET, T, T, decltype(func), 2>(dev_ctx, ins, &outs,
+                                                           axis, func);
+
+  if (dx->dims() != dout->dims() && dy->dims() == dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
+  } else if (dx->dims() == dout->dims() && dy->dims() != dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
+  } else if (dx->dims() != dout->dims() && dy->dims() != dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dx, dx);
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dy, dy);
+  }
+}
+
+template <ElementwiseType ET, typename T, typename Functor>
+void GetGradXOrYOut(const platform::CUDADeviceContext &dev_ctx,
+                    const platform::Place &place, int axis,
+                    std::vector<const framework::Tensor *> ins,
+                    const framework::Tensor *dout, framework::Tensor *dxy,
+                    Functor func) {
+  framework::Tensor tmp_dxy;
+  dxy->mutable_data<T>(place);
+
+  std::vector<framework::Tensor *> outs;
+  if (dxy->dims() != dout->dims()) {
+    tmp_dxy.mutable_data<T>(dout->dims(), place);
+    outs = {&tmp_dxy};
+  } else {
+    outs = {dxy};
+  }
+
+  LaunchElementwiseCudaKernel<ET, T, T>(dev_ctx, ins, &outs, axis, func);
+  if (dxy->dims() != dout->dims()) {
+    ReduceWrapper<T>(dev_ctx, axis, &tmp_dxy, dxy);
+  }
+}
+
+#endif
+
 }  // namespace operators
 }  // namespace paddle
-- 
GitLab
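
The core idea of this patch: instead of a dedicated broadcast-gradient kernel, the elementwise gradient functors (DivGradXFunctor, DivGradYFunctor, DivGradXYFunctor) are evaluated at dout's full shape into temporaries, and any output whose shape differs from dout's is then sum-reduced over the broadcast axes by ReduceWrapper. The standalone host-side C++ sketch below illustrates that two-step scheme for Out = X / Y with Y broadcast across rows. It is illustrative only, not Paddle code: the shapes, values, and names (rows, cols, dy_full) are invented for the example, and none of the patch's APIs are used.

// A minimal, standalone host-side sketch (illustrative only, not Paddle
// code): compute div-grad at dOut's full shape, then sum-reduce over the
// broadcast axis -- the same two steps GetGradXAndYOut + ReduceWrapper
// perform on the GPU.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // X, Out, dOut have shape [rows, cols]; Y has shape [cols] and is
  // broadcast across rows, so dY must be sum-reduced over the row axis.
  const std::size_t rows = 2, cols = 3;
  const std::vector<double> x = {1, 2, 3, 4, 5, 6};
  const std::vector<double> y = {2, 4, 8};
  const std::vector<double> dout(rows * cols, 1.0);  // upstream gradient

  std::vector<double> out(rows * cols), dx(rows * cols), dy(cols, 0.0);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      const std::size_t i = r * cols + c;
      out[i] = x[i] / y[c];
      // Elementwise gradients at dOut's shape (the functor step):
      // dx = dout / y, dy_full = -dout * out / y.
      dx[i] = dout[i] / y[c];
      const double dy_full = -dout[i] * out[i] / y[c];
      // Sum-reduce over the broadcast (row) axis (the ReduceWrapper step).
      dy[c] += dy_full;
    }
  }

  for (double v : dy) std::cout << v << ' ';  // prints: -1.25 -0.4375 -0.140625
  std::cout << '\n';
  return 0;
}

The reduction step is why dy accumulates with += here: each broadcast position of Y contributes to every row it was reused in, so dY = sum over rows of (-dOut * Out / Y), which matches what ReduceWrapper computes with kps::AddFunctor after the functor pass.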