未验证 提交 660c1a65 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize fused_elewise_activation_grad op. (#18041)

test=develop
上级 46625415
......@@ -1005,24 +1005,24 @@ template <typename T, typename DX_OP, typename DY_OP, typename DIntermediate_OP,
bool UseIntermediateOut>
struct FusedElemwiseAndActGradNoBroadcast {
HOSTDEVICE void operator()(size_t i) {
T x_val = x_[i];
T y_val = y_[i];
T out_val = out_[i];
T dout_val = dout_[i];
T intermediate_out_val = UseIntermediateOut
? intermediate_out_[i]
: dx_op_.GetIntermediateOut(x_val, y_val);
if (dx_ != nullptr) {
dx_[i] = UseIntermediateOut
? dx_op_.UseIntermediateOut(
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dx_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
dx_[i] = dx_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
out_val, dout_val);
}
if (dy_ != nullptr) {
dy_[i] = UseIntermediateOut
? dy_op_.UseIntermediateOut(
x_[i], y_[i], intermediate_out_[i], out_[i], dout_[i])
: dy_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
dy_[i] = dy_op_.UseIntermediateOut(x_val, y_val, intermediate_out_val,
out_val, dout_val);
}
if (dintermediate_ != nullptr) {
dintermediate_[i] =
UseIntermediateOut
? dintermediate_op_.UseIntermediateOut(
x_[i], intermediate_out_[i], out_[i], dout_[i])
: dintermediate_op_.Recompute(x_[i], y_[i], out_[i], dout_[i]);
dintermediate_[i] = dintermediate_op_.UseIntermediateOut(
x_val, intermediate_out_val, out_val, dout_val);
}
}
......
......@@ -74,6 +74,8 @@ struct BinaryCompoundGradDxFunctor {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
......@@ -105,6 +107,8 @@ struct BinaryCompoundGradDyFunctor {
}
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
......@@ -143,6 +147,8 @@ struct UnaryCompoundGradDxFunctor {
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
......@@ -181,6 +187,8 @@ struct UnaryCompoundGradDyFunctor {
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
......@@ -203,6 +211,8 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
return dout * d_binary_fun_.Dy(x, intermediate_out);
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return unary_fun_(y); }
private:
DBinaryFun d_binary_fun_;
UnaryFun unary_fun_;
......@@ -232,6 +242,8 @@ struct UnaryCompoundGradDIntermediateFunctor {
}
}
inline HOSTDEVICE T GetIntermediateOut(T x, T y) { return binary_fun_(x, y); }
private:
DUnaryFun d_unary_fun_;
BinaryFun binary_fun_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册