From 4a64ca1e9df499d3d8822c00304b41b5215f2a93 Mon Sep 17 00:00:00 2001
From: Lijunhui <1578034415@qq.com>
Date: Wed, 12 Jan 2022 17:08:04 +0800
Subject: [PATCH] optimize elementwise_max_grad using new interfaces (#37906)

* init elem_max_grad op
* optimize code and reply to review comments
* ternary functors
* apply new reduce func
* move functor to .h
* multi-outputs init
* rearrange code
* modified functors
* optimize code
* pass nullptr
* revert the last change as seg fault occurs
* optimize code
* remove inplace
* remove comments
---
 .../elementwise/elementwise_functor.h | 27 +++++++++++++++++
 .../elementwise/elementwise_max_op.cu | 30 +++++++++++++++++--
 .../elementwise/elementwise_max_op.h  | 29 +++++++++++++++---
 3 files changed, 80 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h
index 438a47f5dc5..a8c9640d479 100644
--- a/paddle/fluid/operators/elementwise/elementwise_functor.h
+++ b/paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -301,5 +301,32 @@ struct MulGradXYFunctor<Complex<T>, Complex<T>> {
   }
 };
 
+// Ternary compare
+template <typename T>
+struct MaxGradXFunctor {
+  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
+    return dout * static_cast<T>(x > y);
+  }
+};
+template <typename T>
+struct MaxGradYFunctor {
+  inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const {
+    return dout * static_cast<T>(x <= y);
+  }
+};
+
+template <typename InT, typename OutT>
+struct MaxGradXYFunctor {
+  inline HOSTDEVICE paddle::framework::Array<OutT, 2> operator()(
+      const InT& x, const InT& y, const InT& dout) {
+    paddle::framework::Array<OutT, 2> outs;
+    // dx = dout * (x > y)
+    outs[0] = static_cast<OutT>(dout * static_cast<InT>(x > y));
+    // dy = dout * (x <= y)
+    outs[1] = static_cast<OutT>(dout * static_cast<InT>(x <= y));
+    return outs;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu
index 76042920088..eaf77744285 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu
@@ -24,15 +24,41 @@ class ElementwiseMaxKernel<platform::CUDADeviceContext, T>
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     std::vector<const framework::Tensor*> ins;
     std::vector<framework::Tensor*> outs;
-    const auto& cuda_ctx =
+    const auto& dev_ctx =
         ctx.template device_context<platform::CUDADeviceContext>();
 
     int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
     LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
-        cuda_ctx, ins, &outs, axis, MaxFunctor<T>());
+        dev_ctx, ins, &outs, axis, MaxFunctor<T>());
   }
 };
 
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  const auto& dev_ctx =
+      ctx.template device_context<platform::CUDADeviceContext>();
+  const auto place = ctx.GetPlace();
+  if (dx != nullptr && dy != nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXAndYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dx, dy, MaxGradXYFunctor<T, T>());
+  } else if (dx != nullptr && dy == nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dx, MaxGradXFunctor<T>());
+  } else if (dx == nullptr && dy != nullptr) {
+    std::vector<const framework::Tensor*> ins = {x, y, dout};
+    GetGradXOrYOut<ElementwiseType::kTernary, T>(
+        dev_ctx, place, axis, ins, dout, dy, MaxGradYFunctor<T>());
+  }
+}
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h
index a7a49fed871..cff30be50a3 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h
@@ -64,6 +64,28 @@ struct MaxGradDy {
   }
 };
 
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy) {
+  int axis = ctx.Attr<int>("axis");
+  ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
+      ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
+}
+
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+template <typename DeviceContext, typename T>
+typename std::enable_if<
+    std::is_same<DeviceContext, platform::CUDADeviceContext>::value>::type
+ElementwiseMaxGrad(const framework::ExecutionContext& ctx,
+                   const framework::Tensor* x, const framework::Tensor* y,
+                   const framework::Tensor* out, const framework::Tensor* dout,
+                   framework::Tensor* dx, framework::Tensor* dy);
+#endif
+
 template <typename DeviceContext, typename T>
 class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
  public:
@@ -74,12 +96,11 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
     auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out = dout;  // out is not necessary
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    auto* out = dout;  // Fake out, not used
-    int axis = ctx.Attr<int>("axis");
-    ElemwiseGradCompute<DeviceContext, T, MaxGradDx<T>, MaxGradDy<T>>(
-        ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
+
+    ElementwiseMaxGrad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
   }
 };
--
GitLab
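
Note on the gradient convention: the functors added above encode the usual subgradient of max(x, y), dx = dout * 1[x > y] and dy = dout * 1[x <= y], so an exact tie (x == y) routes the whole gradient to y and dx + dy == dout holds elementwise. Below is a minimal, dependency-free C++ sketch of that rule for reference only; it is not Paddle code, and MaxGradReference plus every other name in it is illustrative, not part of the patch or the Paddle API.

// Standalone sketch of the max-grad rule used by MaxGradXFunctor /
// MaxGradYFunctor above. Plain C++, no Paddle headers; names are
// illustrative only.
#include <cstdio>
#include <vector>

template <typename T>
void MaxGradReference(const std::vector<T>& x, const std::vector<T>& y,
                      const std::vector<T>& dout, std::vector<T>* dx,
                      std::vector<T>* dy) {
  for (size_t i = 0; i < x.size(); ++i) {
    // dx = dout * (x > y); dy = dout * (x <= y).
    // The tie x == y sends the entire gradient to y.
    (*dx)[i] = dout[i] * static_cast<T>(x[i] > y[i]);
    (*dy)[i] = dout[i] * static_cast<T>(x[i] <= y[i]);
  }
}

int main() {
  std::vector<float> x = {1.f, 3.f, 2.f};
  std::vector<float> y = {2.f, 3.f, 1.f};
  std::vector<float> dout = {10.f, 10.f, 10.f};
  std::vector<float> dx(3), dy(3);
  MaxGradReference(x, y, dout, &dx, &dy);
  for (int i = 0; i < 3; ++i)
    std::printf("dx[%d]=%g  dy[%d]=%g\n", i, dx[i], i, dy[i]);
  // Prints dx = {0, 0, 10}, dy = {10, 10, 0}: the tie at index 1 goes
  // to dy, so dx[i] + dy[i] == dout[i] for every element.
  return 0;
}

The fused MaxGradXYFunctor exists so that the common case, where both dx and dy are requested, can run through the GetGradXAndYOut path as a single ternary kernel producing two outputs, rather than two separate launches; the one-sided functors cover the cases where only dx or only dy is needed.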