From 660c1a65f3089c504503499a165d8684b97fe137 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 14 Jun 2019 10:20:28 +0800 Subject: [PATCH] Optimize fused_elewise_activation_grad op. (#18041) test=develop --- .../elementwise/elementwise_op_function.h | 26 +++++++++---------- .../fluid/operators/math/compound_functors.h | 12 +++++++++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index ad9d0b2a0d2..2b108efef4a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -1005,24 +1005,24 @@ template struct FusedElemwiseAndActGradNoBroadcast { HOSTDEVICE void operator()(size_t i) { + T x_val = x_[i]; + T y_val = y_[i]; + T out_val = out_[i]; + T dout_val = dout_[i]; + T intermediate_out_val = UseIntermediateOut + ? intermediate_out_[i] + : dx_op_.GetIntermediateOut(x_val, y_val); if (dx_ != nullptr) { - dx_[i] = UseIntermediateOut - ? dx_op_.UseIntermediateOut( - x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i]) - : dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); + dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val, + out_val, dout_val); } if (dy_ != nullptr) { - dy_[i] = UseIntermediateOut - ? dy_op_.UseIntermediateOut( - x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i]) - : dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); + dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val, + out_val, dout_val); } if (dintermediate_ != nullptr) { - dintermediate_[i] = - UseIntermediateOut - ? dintermediate_op_.UseIntermediateOut( - x_[i], intermediate_out_[i], out_[i], dout_[i]) - : dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); + dintermediate_[i] = dintermediate_op_.UseIntermediateOut( + x_val, intermediate_out_val, out_val, dout_val); } } diff --git a/paddle/fluid/operators/math/compound_functors.h b/paddle/fluid/operators/math/compound_functors.h index 7aba4a917cd..6a43215bf52 100644 --- a/paddle/fluid/operators/math/compound_functors.h +++ b/paddle/fluid/operators/math/compound_functors.h @@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor { return dout * d_binary_fun_.Dx(x, intermediate_out); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); } + private: DBinaryFun d_binary_fun_; UnaryFun unary_fun_; @@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor { } } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); } + private: DBinaryFun d_binary_fun_; UnaryFun unary_fun_; @@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor { return base * d_binary_fun_.Dx(x, y); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); } + private: DUnaryFun d_unary_fun_; BinaryFun binary_fun_; @@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor { return base * d_binary_fun_.Dy(x, y); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); } + private: DUnaryFun d_unary_fun_; BinaryFun binary_fun_; @@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor { return dout * d_binary_fun_.Dy(x, intermediate_out); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); } + private: DBinaryFun d_binary_fun_; UnaryFun unary_fun_; @@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor { } } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); } + private: DUnaryFun d_unary_fun_; BinaryFun binary_fun_; -- GitLab