diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index ad9d0b2a0d233ef0ab0c5ab3bfc5d935cfdf0895..2b108efef4a34b5e03bd55cd59adfbfb0df67e22 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -1005,24 +1005,24 @@ template struct FusedElemwiseAndActGradNoBroadcast { HOSTDEVICE void operator()(size_t i) { + T x_val = x_[i]; + T y_val = y_[i]; + T out_val = out_[i]; + T dout_val = dout_[i]; + T intermediate_out_val = UseIntermediateOut + ? intermediate_out_[i] + : dx_op_.GetIntermediateOut(x_val, y_val); if (dx_ != nullptr) { - dx_[i] = UseIntermediateOut - ? dx_op_.UseIntermediateOut( - x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i]) - : dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); + dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val, + out_val, dout_val); } if (dy_ != nullptr) { - dy_[i] = UseIntermediateOut - ? dy_op_.UseIntermediateOut( - x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i]) - : dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); + dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val, + out_val, dout_val); } if (dintermediate_ != nullptr) { - dintermediate_[i] = - UseIntermediateOut - ? dintermediate_op_.UseIntermediateOut( - x_[i], intermediate_out_[i], out_[i], dout_[i]) - : dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]); + dintermediate_[i] = dintermediate_op_.UseIntermediateOut( + x_val, intermediate_out_val, out_val, dout_val); } } diff --git a/paddle/fluid/operators/math/compound_functors.h b/paddle/fluid/operators/math/compound_functors.h index 7aba4a917cdea50f95bcc7627f707257606fc927..6a43215bf52a9b231a47241d1bb27695da031957 100644 --- a/paddle/fluid/operators/math/compound_functors.h +++ b/paddle/fluid/operators/math/compound_functors.h @@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor { return dout * d_binary_fun_.Dx(x, intermediate_out); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); } + private: DBinaryFun d_binary_fun_; UnaryFun unary_fun_; @@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor { } } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); } + private: DBinaryFun d_binary_fun_; UnaryFun unary_fun_; @@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor { return base * d_binary_fun_.Dx(x, y); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); } + private: DUnaryFun d_unary_fun_; BinaryFun binary_fun_; @@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor { return base * d_binary_fun_.Dy(x, y); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); } + private: DUnaryFun d_unary_fun_; BinaryFun binary_fun_; @@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor { return dout * d_binary_fun_.Dy(x, intermediate_out); } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); } + private: DBinaryFun d_binary_fun_; UnaryFun unary_fun_; @@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor { } } + inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); } + private: DUnaryFun d_unary_fun_; BinaryFun binary_fun_;