From c2f825d776171cea7176e4c567e88afde2876ec7 Mon Sep 17 00:00:00 2001 From: Lijunhui <1578034415@qq.com> Date: Wed, 12 Jan 2022 16:51:37 +0800 Subject: [PATCH] optimize elementwise_min_grad using new reduce interface (#38236) * ini commit * multi-outputs init commit * optimize code * remove inplace --- .../elementwise/elementwise_functor.h | 26 ++++++++++++++++ .../elementwise/elementwise_min_op.cu | 30 +++++++++++++++++-- .../elementwise/elementwise_min_op.h | 26 ++++++++++++++-- 3 files changed, 77 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index e2689cefd4..438a47f5dc 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -233,6 +233,32 @@ struct FMinFunctor { } }; +template +struct MinGradXFunctor { + inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const { + return dout * static_cast(x < y); + } +}; +template +struct MinGradYFunctor { + inline HOSTDEVICE T operator()(const T& x, const T& y, const T& dout) const { + return dout * static_cast(x >= y); + } +}; + +template +struct MinGradXYFunctor { + inline HOSTDEVICE paddle::framework::Array operator()( + const InT& x, const InT& y, const InT& dout) { + paddle::framework::Array outs; + // dx = dout * (x < y) + outs[0] = static_cast(dout * static_cast(x < y)); + // dy = dout * (x >= y) + outs[1] = static_cast(dout * static_cast(x >= y)); + return outs; + } +}; + template struct MulGradFunctor { inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index b51dbcd883..a733b4a66f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -24,15 +24,41 @@ class ElementwiseMinKernel void Compute(const framework::ExecutionContext& ctx) const override { std::vector ins; std::vector outs; - const auto& cuda_ctx = + const auto& dev_ctx = ctx.template device_context(); int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, MinFunctor()); + dev_ctx, ins, &outs, axis, MinFunctor()); } }; +template +typename std::enable_if< + std::is_same::value>::type +ElementwiseMinGrad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, const framework::Tensor* dout, + framework::Tensor* dx, framework::Tensor* dy) { + int axis = ctx.Attr("axis"); + const auto& dev_ctx = + ctx.template device_context(); + const auto place = ctx.GetPlace(); + if (dx != nullptr && dy != nullptr) { + std::vector ins = {x, y, dout}; + GetGradXAndYOut( + dev_ctx, place, axis, ins, dout, dx, dy, MinGradXYFunctor()); + } else if (dx != nullptr && dy == nullptr) { + std::vector ins = {x, y, dout}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dx, MinGradXFunctor()); + } else if (dx == nullptr && dy != nullptr) { + std::vector ins = {x, y, dout}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dy, MinGradYFunctor()); + } +} + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index ffb8c96535..88fb044d42 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -86,6 +86,28 @@ struct MinGradDy { }; #endif +template +typename std::enable_if< + std::is_same::value>::type +ElementwiseMinGrad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, const framework::Tensor* dout, + framework::Tensor* dx, framework::Tensor* dy) { + int axis = ctx.Attr("axis"); + ElemwiseGradCompute, MinGradDy>( + ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx(), MinGradDy()); +} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +template +typename std::enable_if< + std::is_same::value>::type +ElementwiseMinGrad(const framework::ExecutionContext& ctx, + const framework::Tensor* x, const framework::Tensor* y, + const framework::Tensor* out, const framework::Tensor* dout, + framework::Tensor* dx, framework::Tensor* dy); +#endif + template class ElementwiseMinGradKernel : public ElemwiseGradKernel { public: @@ -99,9 +121,7 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel { auto* dx = ctx.Output(framework::GradVarName("X")); auto* dy = ctx.Output(framework::GradVarName("Y")); auto* out = dout; // Fake out, not used - int axis = ctx.Attr("axis"); - ElemwiseGradCompute, MinGradDy>( - ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx(), MinGradDy()); + ElementwiseMinGrad(ctx, x, y, out, dout, dx, dy); } }; -- GitLab