You need to sign in or sign up before continuing.
未验证 提交 660c1a65 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize fused_elewise_activation_grad op. (#18041)

test=develop
上级 46625415
...@@ -1005,24 +1005,24 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP, ...@@ -1005,24 +1005,24 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut> bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast { struct FusedElemwiseAndActGradNoBroadcast {
HOSTDEVICE void operator()(size_t i) { HOSTDEVICE void operator()(size_t i) {
T x_val = x_[i];
T y_val = y_[i];
T out_val = out_[i];
T dout_val = dout_[i];
T intermediate_out_val = UseIntermediateOut
? intermediate_out_[i]
: dx_op_.GetIntermediateOut(x_val, y_val);
if (dx_ != nullptr) { if (dx_ != nullptr) {
dx_[i] = UseIntermediateOut dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
? dx_op_.UseIntermediateOut( out_val, dout_val);
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
} }
if (dy_ != nullptr) { if (dy_ != nullptr) {
dy_[i] = UseIntermediateOut dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
? dy_op_.UseIntermediateOut( out_val, dout_val);
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
} }
if (dintermediate_ != nullptr) { if (dintermediate_ != nullptr) {
dintermediate_[i] = dintermediate_[i] = dintermediate_op_.UseIntermediateOut(
UseIntermediateOut x_val, intermediate_out_val, out_val, dout_val);
? dintermediate_op_.UseIntermediateOut(
x_[i], intermediate_out_[i], out_[i], dout_[i])
: dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
} }
} }
......
...@@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor { ...@@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor {
return dout * d_binary_fun_.Dx(x, intermediate_out); return dout * d_binary_fun_.Dx(x, intermediate_out);
} }
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private: private:
DBinaryFun d_binary_fun_; DBinaryFun d_binary_fun_;
UnaryFun unary_fun_; UnaryFun unary_fun_;
...@@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor { ...@@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor {
} }
} }
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private: private:
DBinaryFun d_binary_fun_; DBinaryFun d_binary_fun_;
UnaryFun unary_fun_; UnaryFun unary_fun_;
...@@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor { ...@@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor {
return base * d_binary_fun_.Dx(x, y); return base * d_binary_fun_.Dx(x, y);
} }
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private: private:
DUnaryFun d_unary_fun_; DUnaryFun d_unary_fun_;
BinaryFun binary_fun_; BinaryFun binary_fun_;
...@@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor { ...@@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor {
return base * d_binary_fun_.Dy(x, y); return base * d_binary_fun_.Dy(x, y);
} }
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private: private:
DUnaryFun d_unary_fun_; DUnaryFun d_unary_fun_;
BinaryFun binary_fun_; BinaryFun binary_fun_;
...@@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor { ...@@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
return dout * d_binary_fun_.Dy(x, intermediate_out); return dout * d_binary_fun_.Dy(x, intermediate_out);
} }
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private: private:
DBinaryFun d_binary_fun_; DBinaryFun d_binary_fun_;
UnaryFun unary_fun_; UnaryFun unary_fun_;
...@@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor { ...@@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor {
} }
} }
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private: private:
DUnaryFun d_unary_fun_; DUnaryFun d_unary_fun_;
BinaryFun binary_fun_; BinaryFun binary_fun_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册