From 36a102f85de37c4d6ade394eb3f552a1dc1a30b7 Mon Sep 17 00:00:00 2001 From: Lijunhui <1578034415@qq.com> Date: Wed, 5 Jan 2022 21:49:41 +0800 Subject: [PATCH] optimize elementwise_mul_grad using new interfaces (#37728) * init commit: new elem_mul_grad * add template specialization for complex in multiply * reply review comments * correct dx and dy computation when T is complex * reply review comments * update to new ReduceFunctor * mul-output broadcast * call functions * call functions with comments * remove comments --- .../elementwise/elementwise_functor.h | 42 ++++++++ .../elementwise/elementwise_mul_op.cu | 95 +++++++------------ .../elementwise/elementwise_mul_op.h | 29 ++---- 3 files changed, 86 insertions(+), 80 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index b7bebcaa386..a62c531ff07 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -194,5 +194,47 @@ struct FMinFunctor { } }; +template +struct MulGradFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } +}; +template +struct MulGradFunctor> { + inline HOSTDEVICE Complex operator()(const Complex& a, + const Complex& b) const { + Complex b_conj(b.real, -b.imag); + return a * b_conj; + } +}; + +template +struct MulGradXYFunctor { + inline HOSTDEVICE paddle::framework::Array operator()(const InT& a, + const InT& b, + const InT& c) { + paddle::framework::Array outs; + // dx = dout * y + outs[0] = a * b; + // dy = dout * x + outs[1] = a * c; + return outs; + } +}; + +template +struct MulGradXYFunctor, Complex> { + inline HOSTDEVICE paddle::framework::Array, 2> operator()( + const Complex& a, const Complex& b, const Complex& c) { + paddle::framework::Array, 2> outs; + // dx = dout * y + Complex b_conj(b.real, -b.imag); + outs[0] = a * b_conj; + // dy = dout * x + Complex c_conj(c.real,
-c.imag); + outs[1] = a * c_conj; + return outs; + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 12e0062a698..cdf376fd6a8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -68,69 +69,41 @@ class ElementwiseMulKernel } }; -template -static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y, - const T* out, - const T* dout, - int64_t size, T* dx, - T* dy) { - int col = blockIdx.x * blockDim.x + threadIdx.x; - - while (col < size) { - T o = dout[col]; - dx[col] = y[col] * o; - dy[col] = x[col] * o; - col += blockDim.x * gridDim.x; - } -} - -template <> -__global__ void SimpleElemwiseMulGradCUDAKernel>( - const plat::complex* x, const plat::complex* y, - const plat::complex* out, const plat::complex* dout, - int64_t size, plat::complex* dx, plat::complex* dy) { - int col = blockIdx.x * blockDim.x + threadIdx.x; - - while (col < size) { - plat::complex o = dout[col]; - dx[col] = plat::complex(y[col].real, -y[col].imag) * o; - dy[col] = plat::complex(x[col].real, -x[col].imag) * o; - col += blockDim.x * gridDim.x; - } -} - -template <> -__global__ void SimpleElemwiseMulGradCUDAKernel>( - const plat::complex* x, const plat::complex* y, - const plat::complex* out, const plat::complex* dout, - int64_t size, plat::complex* dx, plat::complex* dy) { - int col = blockIdx.x * blockDim.x + threadIdx.x; - - while (col < size) { - plat::complex o = dout[col]; - dx[col] = plat::complex(y[col].real, 
-y[col].imag) * o; - dy[col] = plat::complex(x[col].real, -x[col].imag) * o; - col += blockDim.x * gridDim.x; - } -} - template typename std::enable_if< - std::is_same::value>::type -elementwise_mul_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, framework::Tensor* dx, - framework::Tensor* dy) { - dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); - auto size = x->numel(); - dim3 grid_size = - dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); - SimpleElemwiseMulGradCUDAKernel< - T><<().stream()>>>( - x->data(), y->data(), out->data(), dout->data(), size, - dx->mutable_data(ctx.GetPlace()), dy->mutable_data(ctx.GetPlace())); + std::is_same::value>::type +ElementwiseMulGrad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, const framework::Tensor* dout, + framework::Tensor* dx, framework::Tensor* dy) { + int axis = ctx.Attr("axis"); + const auto& dev_ctx = + ctx.template device_context(); + const auto place = ctx.GetPlace(); + + if (dx != nullptr && dy != nullptr) { + dx->mutable_data(place); + if (dx->IsSharedBufferWith(*dout)) { + dx->clear(); + dx->mutable_data(x->dims(), place); + } + std::vector ins = {dout, y, x}; + GetGradXAndYOut( + dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor()); + } else if (dx != nullptr && dy == nullptr) { + dx->mutable_data(place); + if (dx->IsSharedBufferWith(*dout)) { + dx->clear(); + dx->mutable_data(x->dims(), place); + } + std::vector ins = {dout, y}; + GetGradXOrYOut(dev_ctx, place, axis, ins, dout, + dx, MulGradFunctor()); + } else if (dx == nullptr && dy != nullptr) { + std::vector ins = {dout, x}; + GetGradXOrYOut(dev_ctx, place, axis, ins, dout, + dy, MulGradFunctor()); + } } } // namespace operators diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h 
b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 3b0f0725722..5cff3173e81 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -174,26 +174,23 @@ struct MulGradDY> { template typename std::enable_if< std::is_same::value>::type -elementwise_mul_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, framework::Tensor* dx, - framework::Tensor* dy) { +ElementwiseMulGrad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, const framework::Tensor* dout, + framework::Tensor* dx, framework::Tensor* dy) { int axis = ctx.Attr("axis"); ElemwiseGradCompute, MulGradDY>( ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX(), MulGradDY()); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -// cuda definition template typename std::enable_if< std::is_same::value>::type -elementwise_mul_grad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, - const framework::Tensor* dout, framework::Tensor* dx, - framework::Tensor* dy); +ElementwiseMulGrad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, const framework::Tensor* dout, + framework::Tensor* dx, framework::Tensor* dy); #endif template @@ -209,14 +206,8 @@ class ElementwiseMulGradKernel : public ElemwiseGradKernel { auto* out = dout; // out is not necessary auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); - int axis = ctx.Attr("axis"); - if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) { - elementwise_mul_grad(ctx, x, y, out, dout, dx, dy); - } else { - ElemwiseGradCompute, MulGradDY>( - ctx, *x, *y, *out, *dout, 
axis, dx, dy, MulGradDX(), - MulGradDY()); - } + + ElementwiseMulGrad(ctx, x, y, out, dout, dx, dy); } }; -- GitLab