未验证 提交 57f54d3b 编写于 作者: Y YuanRisheng 提交者: GitHub

move activation kernel (#40565)

上级 603f8425
...@@ -1485,6 +1485,13 @@ REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor); ...@@ -1485,6 +1485,13 @@ REGISTER_ACTIVATION_OP(atanh, Atanh, AtanhFunctor, AtanhGradFunctor);
REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor); REGISTER_ACTIVATION_OP(brelu, BRelu, BReluFunctor, BReluGradFunctor);
REGISTER_ACTIVATION_OP(thresholded_relu, ThresholdedRelu, REGISTER_ACTIVATION_OP(thresholded_relu, ThresholdedRelu,
ThresholdedReluFunctor, ThresholdedReluGradFunctor); ThresholdedReluFunctor, ThresholdedReluGradFunctor);
// Register the hard_shrink, softshrink, tanh_shrink and silu activation ops.
// Each invocation passes the op name, its CamelCase name, and the names of
// the functor / grad-functor pair used for it.
REGISTER_ACTIVATION_OP(hard_shrink, HardShrink, HardShrinkFunctor,
HardShrinkGradFunctor);
REGISTER_ACTIVATION_OP(softshrink, SoftShrink, SoftShrinkFunctor,
SoftShrinkGradFunctor);
REGISTER_ACTIVATION_OP(tanh_shrink, TanhShrink, TanhShrinkFunctor,
TanhShrinkGradFunctor);
REGISTER_ACTIVATION_OP(silu, Silu, SiluFunctor, SiluGradFunctor);
/* ========================== sigmoid register ============================= /* ========================== sigmoid register =============================
*/ */
...@@ -1626,22 +1633,6 @@ REGISTER_OPERATOR( ...@@ -1626,22 +1633,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>, ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer); ops::ActivationDoubleGradOpInplaceInferer);
// CPU kernel registrations for the ELU family.
// elu: forward kernel, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(elu,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ELUFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ELUFunctor<double>>);
// elu_grad: first-order gradient kernel, instantiated for float and double.
REGISTER_OP_CPU_KERNEL(
elu_grad, ops::ELUGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ELUGradKernel<paddle::platform::CPUDeviceContext, double>);
// elu_grad_grad: second-order gradient kernel, instantiated for float, double
// and plat::float16.
REGISTER_OP_CPU_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
ops::ELUGradGradFunctor<float>>,
ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
ops::ELUGradGradFunctor<double>>,
ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */ /* ========================================================================== */
/* ======================== logit register ============================ /* ======================== logit register ============================
......
...@@ -279,6 +279,15 @@ USE_PHI_FUNCTOR(BRelu) ...@@ -279,6 +279,15 @@ USE_PHI_FUNCTOR(BRelu)
USE_PHI_FUNCTOR(ThresholdedRelu) USE_PHI_FUNCTOR(ThresholdedRelu)
USE_PHI_FUNCTOR(LeakyRelu) USE_PHI_FUNCTOR(LeakyRelu)
USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu) USE_PHI_DOUBLE_GRAD_FUNCTOR(LeakyRelu)
// Functors for HardShrink, SoftShrink, TanhShrink, Silu and ELU are taken from
// phi via USE_PHI_FUNCTOR / USE_PHI_DOUBLE_GRAD_FUNCTOR.
// NOTE(review): the macros are defined outside this view; presumably they
// alias the phi::funcs functors into this namespace — confirm.
USE_PHI_FUNCTOR(HardShrink)
USE_PHI_FUNCTOR(SoftShrink)
USE_PHI_FUNCTOR(TanhShrink)
USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
// Explicit alias for the negative-alpha ELU gradient functor from phi::funcs.
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
template <typename T> template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> { struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
...@@ -392,31 +401,6 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> { ...@@ -392,31 +401,6 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
} }
}; };
// SiLU forward: out = x * (1 / (1 + exp(-x))), i.e. x / (1 + exp(-x)).
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T one = static_cast<T>(1);
    // Multiply x by the logistic term in a single expression.
    out.device(d) = x * (one / (one + (-x).exp()));
  }
};
// SiLU backward:
//   dx = dout * (1 / (1 + e^{-x})) * (1 + x * e^{-x} / (1 + e^{-x}))
// Only the forward input x is read; `out` is unused.
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T one = static_cast<T>(1);
    auto denom = one + (-x).exp();      // 1 + e^{-x}
    auto x_exp_neg = x * (-x).exp();    // x * e^{-x}
    dx.device(d) = dout * ((one / denom) * (one + (x_exp_neg / denom)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Originally: logsigmoid(x) = -log (1 + exp(-x)) // Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick: // For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
...@@ -512,99 +496,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>; ...@@ -512,99 +496,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
template <typename T> template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>; using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;
// Tanh-shrink forward: out = x - tanh(x).
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto tanh_x = x.tanh();
    out.device(d) = x - tanh_x;
  }
};
// Tanh-shrink backward: dx = dout * tanh(x)^2.
// Only the forward input x is read; `out` is unused.
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto tanh_x = x.tanh();
    dx.device(d) = dout * (tanh_x * tanh_x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// hardshrink(x) = x, if x < -threshold or x > threshold; 0 otherwise.
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto below = x < static_cast<T>(threshold * -1.f);
    auto above = x > static_cast<T>(threshold);
    // The boolean mask is cast to T so it can scale x by 0 or 1.
    out.device(d) = x * (below || above).template cast<T>();
  }
};
// Hard-shrink backward: dx = dout, if x < -threshold or x > threshold;
// 0 otherwise. Only the forward input x is read; `out` is unused.
template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto below = x < static_cast<T>(threshold * -1.f);
    auto above = x > static_cast<T>(threshold);
    dx.device(d) = dout * (below || above).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// softshrink(x) = x - lambda, if x > lambda;
//                 x + lambda, if x < -lambda;
//                 0, otherwise.
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Exposes `lambda` under the attribute name "lambda".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto bound = static_cast<T>(lambda);
    auto above_mask = (x > bound).template cast<T>();
    auto below_mask = (x < -bound).template cast<T>();
    // The two masks are 0/1 selectors for the shifted branches.
    out.device(d) = above_mask * (x - bound) + below_mask * (x + bound);
  }
};
// Soft-shrink backward: dx = dout, if x > lambda or x < -lambda; 0 otherwise.
// Only the forward input x is read; `out` is unused.
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Exposes `lambda` under the attribute name "lambda".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto bound = static_cast<T>(lambda);
    auto above_mask = (x > bound).template cast<T>();
    auto below_mask = (x < -bound).template cast<T>();
    dx.device(d) = dout * (above_mask + below_mask).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// sqrt(x) = x^(1/2) // sqrt(x) = x^(1/2)
template <typename T> template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> { struct SqrtFunctor : public BaseActivationFunctor<T> {
...@@ -1036,59 +927,6 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1036,59 +927,6 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
} }
}; };
// ELU forward: out = alpha * (exp(x) - 1), if x < 0; x otherwise.
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` under the attribute name "alpha".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    auto is_negative = x < static_cast<T>(0);
    auto negative_branch =
        static_cast<T>(alpha) * (x.exp() - static_cast<T>(1));
    out.device(d) = is_negative.select(negative_branch, x);
  }
};
// ELU backward for the alpha >= 0 case, expressed through the forward output:
//   dx = dout,                 if out > 0
//   dx = dout * (out + alpha), if out <= 0
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` under the attribute name "alpha".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto is_positive = out > static_cast<T>(0);
    auto scaled_grad = dout * (out + static_cast<T>(alpha));
    dx.device(d) = is_positive.select(dout, scaled_grad);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// ELU backward for the alpha < 0 case, expressed through the forward input:
//   dx = dout,                  if x > 0
//   dx = dout * alpha * exp(x), if x <= 0
template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` under the attribute name "alpha".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out, typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    auto is_positive = x > static_cast<T>(0);
    auto scaled_grad = dout * static_cast<T>(alpha) * x.exp();
    dx.device(d) = is_positive.select(dout, scaled_grad);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> { class ELUGradKernel : public framework::OpKernel<T> {
public: public:
...@@ -1354,44 +1192,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -1354,44 +1192,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// Second-order gradient functor for ELU.
//
// Reads X, DDX and (when dX is requested) DOut, and optionally writes:
//   dX    = ddx * dout * alpha * exp(x) * [x <= 0]
//   ddOut = ddx * ([x > 0] + alpha * exp(x) * [x <= 0])
// where [cond] is the condition cast to T (0 or 1). An output is skipped when
// its pointer is null.
template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` under the attribute name "alpha".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device>
  void operator()(const Device& dev, const framework::Tensor* X,
                  const framework::Tensor* ddX, framework::Tensor* ddOut,
                  const framework::Tensor* dOut, framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
    auto x = framework::EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));

    if (dX) {
      auto dx = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
      // dOut is a read-only input of this functor, so it is reported with the
      // "Input" role if the pointer is missing.
      auto dout = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Input", "DOut", "ELUGradGrad"));
      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                      (x <= static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      auto ddout = framework::EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
      ddout.device(*d) = ddx *
                         ((x > static_cast<T>(0)).template cast<T>() +
                          static_cast<T>(alpha) * x.exp() *
                              (x <= static_cast<T>(0)).template cast<T>())
                             .template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T> template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> { struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha; float alpha;
...@@ -2152,9 +1952,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2152,9 +1952,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
} // namespace paddle } // namespace paddle
#define FOR_EACH_ACTIVATION_OP(__macro) \ #define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(silu, Silu, SiluFunctor, SiluGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \ __macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \ __macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \ __macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \ __macro(round, Round, RoundFunctor, ZeroGradFunctor); \
...@@ -2167,8 +1965,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> { ...@@ -2167,8 +1965,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \ __macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \ __macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \
HardSigmoidGradFunctor); \ HardSigmoidGradFunctor); \
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \ __macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
......
...@@ -44,35 +44,6 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -44,35 +44,6 @@ struct CudaSigmoidGradFunctor : public BaseActivationFunctor<T> {
} }
}; };
// CUDA SiLU forward: silu(x) = x / (1 + exp(-x)), computed in MPType and
// cast back to T.
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType denom = one + exp(-x);
    return static_cast<T>(x / denom);
  }
};
// CUDA SiLU backward, computed in MPType and cast back to T:
//   s  = 1 / (1 + exp(-x))
//   dx = dout * s * (1 + x * (1 - s))
template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType sig = one / (one + exp(-x));
    return static_cast<T>(grad * (sig * (one + x * (one - sig))));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T> template <typename T>
struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> { struct CudaLogSigmoidFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
...@@ -110,43 +81,6 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> { ...@@ -110,43 +81,6 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// CUDA soft-shrink forward:
//   softshrink(x) = x - lambda, if x > lambda;
//                   x + lambda, if x < -lambda;
//                   0, otherwise.
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Exposes `lambda` under the attribute name "lambda".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  __device__ __forceinline__ T operator()(const T x) const {
    const T bound = static_cast<T>(lambda);
    // 0/1 selectors for the two shifted branches.
    const T above = static_cast<T>(x > bound);
    const T below = static_cast<T>(x < -bound);
    return above * (x - bound) + below * (x + bound);
  }
};
// CUDA soft-shrink backward: dx = 0 when -lambda <= x <= lambda, else dout.
template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  // Exposes `lambda` under the attribute name "lambda".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T bound = static_cast<T>(lambda);
    // Test the dead zone itself so that inputs failing both comparisons
    // still pass dout through.
    if (x >= -bound && x <= bound) {
      return zero;
    }
    return dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T> template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> { struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type; using MPType = typename details::MPTypeTrait<T>::Type;
...@@ -615,66 +549,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> { ...@@ -615,66 +549,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
} }
}; };
// CUDA tanh-shrink forward: tanhshrink(x) = x - tanh(x), computed in MPType
// and cast back to T.
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType shrunk = x - tanh(x);
    return static_cast<T>(shrunk);
  }
};
// CUDA tanh-shrink backward: dx = dout * tanh(x)^2, computed in MPType and
// cast back to T.
template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    // Multiplication order (dout * tanh) * tanh is kept as in the product.
    return static_cast<T>(grad * tanh(x) * tanh(x));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// CUDA hard-shrink forward:
//   hardshrink(x) = 0 when -threshold < x < threshold, else x.
template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T x) const {
    const T bound = static_cast<T>(threshold);
    if (x > -bound && x < bound) {
      return zero;
    }
    return x;
  }
};
// CUDA hard-shrink backward:
//   dx = 0 when -threshold < x < threshold, else dout.
template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Exposes `threshold` under the attribute name "threshold".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T bound = static_cast<T>(threshold);
    if (x > -bound && x < bound) {
      return zero;
    }
    return dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T> template <typename T>
struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> { struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f); T zero = static_cast<T>(0.0f);
...@@ -863,110 +737,6 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> { ...@@ -863,110 +737,6 @@ struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// CUDA ELU forward, computed in CT and cast back to T:
//   elu(x) = x,                   if x > 0
//   elu(x) = alpha * (e^x - 1),   otherwise
template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename details::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  // Exposes `alpha` under the attribute name "alpha".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT x = static_cast<CT>(arg_x);
    if (x > zero) {
      return static_cast<T>(x);
    }
    const CT negative_branch = static_cast<CT>(alpha) * (exp(x) - one);
    return static_cast<T>(negative_branch);
  }
};
// CUDA ELU backward for the alpha >= 0 case, driven by the forward output:
//   dx = dout,                 if out > 0
//   dx = dout * (out + alpha), if out <= 0
// The branch is evaluated arithmetically with 0/1 masks in MPType.
template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  // Exposes `alpha` under the attribute name "alpha".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType fwd = static_cast<MPType>(arg_out);
    const MPType a = static_cast<MPType>(alpha);
    const MPType pos_mask = static_cast<MPType>(fwd > zero);
    const MPType neg_mask = static_cast<MPType>(fwd <= zero);
    return static_cast<T>(grad * (pos_mask + neg_mask * (fwd + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
// CUDA ELU backward for the alpha < 0 case; the branch is chosen by the
// forward input x rather than the output:
//   dx = dout,                 if x > 0
//   dx = dout * (out + alpha), if x <= 0
// The branch is evaluated arithmetically with 0/1 masks in MPType.
template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  // Exposes `alpha` under the attribute name "alpha".
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T arg_dout, const T arg_out,
                                          const T arg_x) const {
    const MPType grad = static_cast<MPType>(arg_dout);
    const MPType fwd = static_cast<MPType>(arg_out);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType a = static_cast<MPType>(alpha);
    const MPType pos_mask = static_cast<MPType>(x > zero);
    const MPType neg_mask = static_cast<MPType>(x <= zero);
    return static_cast<T>(grad * (pos_mask + neg_mask * (fwd + a)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// CUDA gradient kernel for ELU.
// Reads Out@GRAD, Out and X, allocates X@GRAD, and launches an elementwise
// kernel: CudaELUGradFunctor over (dout, out) when alpha > 0, otherwise
// CudaELUGradNegativeAlphaFunctor over (dout, out, x).
template <typename DeviceContext, typename T>
class ELUGradCudaKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    auto* grad_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* fwd_out = ctx.Input<framework::Tensor>("Out");
    auto* fwd_in = ctx.Input<framework::Tensor>("X");
    auto* grad_in = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    grad_in->mutable_data<T>(ctx.GetPlace());

    const float alpha = ctx.Attr<float>("alpha");
    auto& dev_ctx = ctx.device_context<DeviceContext>();

    std::vector<const framework::Tensor*> inputs = {grad_out, fwd_out};
    std::vector<framework::Tensor*> outputs = {grad_in};

    if (alpha > 0) {
      CudaELUGradFunctor<T> grad_functor;
      grad_functor.alpha = alpha;
      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
          dev_ctx, inputs, &outputs, grad_functor);
      return;
    }

    // The non-positive-alpha functor also needs the forward input x.
    CudaELUGradNegativeAlphaFunctor<T> grad_functor;
    grad_functor.alpha = alpha;
    inputs.push_back(fwd_in);
    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
        dev_ctx, inputs, &outputs, grad_functor);
  }
};
template <typename T> template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> { struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type; using CT = typename details::MPTypeTrait<T>::Type;
...@@ -1099,6 +869,15 @@ USE_PHI_FUNCTOR(CudaTanh) ...@@ -1099,6 +869,15 @@ USE_PHI_FUNCTOR(CudaTanh)
USE_PHI_FUNCTOR(CudaBRelu) USE_PHI_FUNCTOR(CudaBRelu)
USE_PHI_FUNCTOR(CudaLeakyRelu) USE_PHI_FUNCTOR(CudaLeakyRelu)
USE_PHI_FUNCTOR(CudaThresholdedRelu) USE_PHI_FUNCTOR(CudaThresholdedRelu)
// CUDA functors for HardShrink, SoftShrink, TanhShrink, Silu and ELU are taken
// from phi via USE_PHI_FUNCTOR.
// NOTE(review): the macro is defined outside this view; presumably it aliases
// the phi::funcs functors into this namespace — confirm.
USE_PHI_FUNCTOR(CudaHardShrink)
USE_PHI_FUNCTOR(CudaSoftShrink)
USE_PHI_FUNCTOR(CudaTanhShrink)
USE_PHI_FUNCTOR(CudaSilu)
USE_PHI_FUNCTOR(CudaELU)
// Explicit alias for the negative-alpha CUDA ELU gradient functor from
// phi::funcs.
template <typename T>
using CudaELUGradNegativeAlphaFunctor =
phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -1158,26 +937,6 @@ namespace plat = paddle::platform; ...@@ -1158,26 +937,6 @@ namespace plat = paddle::platform;
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \ ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::bfloat16>>); ops::grad_functor<plat::bfloat16>>);
/* ======================== elu register ============================ */
// elu: forward CUDA kernel, instantiated for float, double and plat::float16.
REGISTER_OP_CUDA_KERNEL(
elu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaELUFunctor<float>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaELUFunctor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaELUFunctor<plat::float16>>);
// elu_grad: first-order gradient CUDA kernel for the same three types.
REGISTER_OP_CUDA_KERNEL(
elu_grad, ops::ELUGradCudaKernel<plat::CUDADeviceContext, float>,
ops::ELUGradCudaKernel<plat::CUDADeviceContext, double>,
ops::ELUGradCudaKernel<plat::CUDADeviceContext, plat::float16>);
// elu_grad_grad: second-order gradient kernel; uses ELUGradGradFunctor with
// the CUDA device context.
REGISTER_OP_CUDA_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
ops::ELUGradGradFunctor<float>>,
ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
ops::ELUGradGradFunctor<double>>,
ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */ /* ========================================================================== */
/* ======================== celu register ============================ */ /* ======================== celu register ============================ */
...@@ -1359,7 +1118,6 @@ REGISTER_OP_CUDA_KERNEL( ...@@ -1359,7 +1118,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */ /* ========================================================================== */
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \ #define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(silu, Silu, CudaSiluFunctor, CudaSiluGradFunctor); \
__macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \ __macro(logsigmoid, LogSigmoid, CudaLogSigmoidFunctor, \
CudaLogSigmoidGradFunctor); \ CudaLogSigmoidGradFunctor); \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \ __macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
......
...@@ -26,6 +26,23 @@ namespace phi { ...@@ -26,6 +26,23 @@ namespace phi {
const DenseTensor& dout, \ const DenseTensor& dout, \
DenseTensor* dx); DenseTensor* dx);
// Declares `name##GradKernel` for an activation whose gradient takes the
// forward input `x`, the upstream gradient `dout`, and one float attribute
// named `attr`; the result is written through `dx`.
// The expansion already ends with ';'.
#define DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(name, attr) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
float attr, \
DenseTensor* dx);
// Same as the one-attribute form, but with two float attributes `attr1` and
// `attr2`, in that order.
#define DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(name, attr1, attr2) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& dout, \
float attr1, \
float attr2, \
DenseTensor* dx);
#define DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(name) \ #define DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(name) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
...@@ -33,6 +50,14 @@ namespace phi { ...@@ -33,6 +50,14 @@ namespace phi {
const DenseTensor& dout, \ const DenseTensor& dout, \
DenseTensor* dx); DenseTensor* dx);
// Declares `name##GradKernel` for an activation whose gradient takes the
// forward output `out`, the upstream gradient `dout`, and one float attribute
// named `attr`; the result is written through `dx`.
// The expansion already ends with ';'.
#define DECLARE_ACTIVATION_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut(name, attr) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \
const DenseTensor& dout, \
float attr, \
DenseTensor* dx);
template <typename T, typename Context> template <typename T, typename Context>
void ReluDoubleGradKernel(const Context& dev_ctx, void ReluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out, const DenseTensor& out,
...@@ -59,34 +84,29 @@ void TanhTripleGradKernel(const Context& dev_ctx, ...@@ -59,34 +84,29 @@ void TanhTripleGradKernel(const Context& dev_ctx,
DenseTensor* d_ddx); DenseTensor* d_ddx);
template <typename T, typename Context> template <typename T, typename Context>
void BReluGradKernel(const Context& dev_ctx, void LeakyReluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& dout, const DenseTensor& ddx,
float t_min, float alpha,
float t_max, DenseTensor* ddout);
DenseTensor* dx);
template <typename T, typename Context> template <typename T, typename Context>
void LeakyReluGradKernel(const Context& dev_ctx, void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout, const DenseTensor& dout,
float alpha, float alpha,
DenseTensor* dx); DenseTensor* dx);
template <typename T, typename Context> template <typename T, typename Context>
void LeakyReluDoubleGradKernel(const Context& dev_ctx, void EluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx, const DenseTensor& ddx,
float alpha, float alpha,
DenseTensor* dx,
DenseTensor* ddout); DenseTensor* ddout);
template <typename T, typename Context>
void ThresholdedReluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float threshold,
DenseTensor* dx);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acos);
...@@ -98,7 +118,17 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh); ...@@ -98,7 +118,17 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Cosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Asinh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Acosh);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh); DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Atanh);
// Attribute-free gradient kernels declared via the DepX macro:
// TanhShrinkGradKernel and SiluGradKernel.
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(TanhShrink);
DECLARE_ACTIVATION_GRAD_KERNEL_DepX(Silu);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu); DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Relu);
DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Tanh); DECLARE_ACTIVATION_GRAD_KERNEL_DepOut(Tanh);
// Gradient kernels that take x, dout and float attribute(s); the second macro
// argument(s) name the attribute parameter(s) of the declared kernel.
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, alpha)
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(ThresholdedRelu, threshold)
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(SoftShrink, lambda)
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(HardShrink, threshold)
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu, t_min, t_max)
} // namespace phi } // namespace phi
...@@ -24,6 +24,21 @@ namespace phi { ...@@ -24,6 +24,21 @@ namespace phi {
void name##Kernel( \ void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
// Declares the forward `name##Kernel` for an activation that takes the input
// `x` and one float attribute named `attr`; the result is written through
// `out`. The expansion already ends with ';'.
#define DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(name, attr) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
float attr, \
DenseTensor* out);
// Same as the one-attribute form, but with two float attributes `attr1` and
// `attr2`, in that order.
#define DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(name, attr1, attr2) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
float attr1, \
float attr2, \
DenseTensor* out);
DECLARE_ACTIVATION_KERNEL(Cos) DECLARE_ACTIVATION_KERNEL(Cos)
DECLARE_ACTIVATION_KERNEL(Tan) DECLARE_ACTIVATION_KERNEL(Tan)
DECLARE_ACTIVATION_KERNEL(Acos) DECLARE_ACTIVATION_KERNEL(Acos)
...@@ -37,24 +52,15 @@ DECLARE_ACTIVATION_KERNEL(Acosh) ...@@ -37,24 +52,15 @@ DECLARE_ACTIVATION_KERNEL(Acosh)
DECLARE_ACTIVATION_KERNEL(Atanh) DECLARE_ACTIVATION_KERNEL(Atanh)
DECLARE_ACTIVATION_KERNEL(Relu) DECLARE_ACTIVATION_KERNEL(Relu)
DECLARE_ACTIVATION_KERNEL(Tanh) DECLARE_ACTIVATION_KERNEL(Tanh)
DECLARE_ACTIVATION_KERNEL(TanhShrink)
DECLARE_ACTIVATION_KERNEL(Silu)
template <typename T, typename Context> DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
void BReluKernel(const Context& dev_ctx, DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
const DenseTensor& x, DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
float t_min, DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
float t_max, DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
DenseTensor* out);
template <typename T, typename Context> DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)
void LeakyReluKernel(const Context& dev_ctx,
const DenseTensor& x,
float alpha,
DenseTensor* out);
template <typename T, typename Context>
void ThresholdedReluKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
DenseTensor* out);
} // namespace phi } // namespace phi
...@@ -21,18 +21,18 @@ limitations under the License. */ ...@@ -21,18 +21,18 @@ limitations under the License. */
namespace phi { namespace phi {
#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \ #define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
const DenseTensor& dout, \ const DenseTensor& dout, \
DenseTensor* dx) { \ DenseTensor* dx) { \
functor_class<T> functor; \ funcs::functor_class<T> functor; \
ActivationGradImpl<T, Context, functor_class<T>>( \ ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, &x, nullptr, &dout, dx, functor); \ dev_ctx, &x, nullptr, &dout, dx, functor); \
} }
#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( \ #define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
name, functor_class, attr) \ name, functor_class, attr) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
...@@ -40,14 +40,14 @@ namespace phi { ...@@ -40,14 +40,14 @@ namespace phi {
const DenseTensor& dout, \ const DenseTensor& dout, \
float attr, \ float attr, \
DenseTensor* dx) { \ DenseTensor* dx) { \
functor_class<T> functor; \ funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \ auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr; \ *(attrs[0].second) = attr; \
ActivationGradImpl<T, Context, functor_class<T>>( \ ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, &x, nullptr, &dout, dx, functor); \ dev_ctx, &x, nullptr, &dout, dx, functor); \
} }
#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX( \ #define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \
name, functor_class, attr1, attr2) \ name, functor_class, attr1, attr2) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
...@@ -56,26 +56,26 @@ namespace phi { ...@@ -56,26 +56,26 @@ namespace phi {
float attr1, \ float attr1, \
float attr2, \ float attr2, \
DenseTensor* dx) { \ DenseTensor* dx) { \
functor_class<T> functor; \ funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \ auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr1; \ *(attrs[0].second) = attr1; \
*(attrs[1].second) = attr2; \ *(attrs[1].second) = attr2; \
ActivationGradImpl<T, Context, functor_class<T>>( \ ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, &x, nullptr, &dout, dx, functor); \ dev_ctx, &x, nullptr, &dout, dx, functor); \
} }
#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \ #define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \ const DenseTensor& out, \
const DenseTensor& dout, \ const DenseTensor& dout, \
DenseTensor* dx) { \ DenseTensor* dx) { \
functor_class<T> functor; \ funcs::functor_class<T> functor; \
ActivationGradImpl<T, Context, functor_class<T>>( \ ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, nullptr, &out, &dout, dx, functor); \ dev_ctx, nullptr, &out, &dout, dx, functor); \
} }
#define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut( \ #define DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
name, functor_class, attr) \ name, functor_class, attr) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
...@@ -83,39 +83,78 @@ namespace phi { ...@@ -83,39 +83,78 @@ namespace phi {
const DenseTensor& dout, \ const DenseTensor& dout, \
float attr, \ float attr, \
DenseTensor* dx) { \ DenseTensor* dx) { \
functor_class<T> functor; \ funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \ auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr; \ *(attrs[0].second) = attr; \
ActivationGradImpl<T, Context, functor_class<T>>( \ ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, nullptr, &out, &dout, dx, functor); \ dev_ctx, nullptr, &out, &dout, dx, functor); \
} }
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, funcs::CosGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CosGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, funcs::TanGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, TanGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, funcs::AcosGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, AcosGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, funcs::SinGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, SinGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, funcs::AsinGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Asin, AsinGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, funcs::AtanGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atan, AtanGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, funcs::SinhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Sinh, SinhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, funcs::CoshGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cosh, CoshGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, funcs::AsinhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Asinh, AsinhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, funcs::AcoshGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, AcoshGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, funcs::AtanhGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, funcs::ReluGradFunctor); DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, funcs::TanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
funcs::LeakyReluGradFunctor,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
LeakyReluGradFunctor,
alpha); alpha);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
ThresholdedRelu, funcs::ThresholdedReluGradFunctor, threshold); ThresholdedReluGradFunctor,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu, DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
funcs::BReluGradFunctor, SoftShrinkGradFunctor,
lambda);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
HardShrinkGradFunctor,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
BReluGradFunctor,
t_min, t_min,
t_max); t_max);
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
float alpha,
DenseTensor* dx) {
dev_ctx.template Alloc<T>(dx);
auto x_flatten =
EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "elu_grad"));
auto out_flatten = EigenVector<T>::Flatten(
GET_DATA_SAFELY(&out, "Input", "Out", "elu_grad"));
auto dout_flatten = EigenVector<T>::Flatten(
GET_DATA_SAFELY(&dout, "Input", "dOut", "elu_grad"));
auto dx_flatten =
EigenVector<T>::Flatten(GET_DATA_SAFELY(dx, "Output", "dX", "elu_grad"));
auto* place = dev_ctx.eigen_device();
if (alpha > 0) {
funcs::ELUGradFunctor<T> functor;
functor.alpha = alpha;
functor(*place, x_flatten, out_flatten, dout_flatten, dx_flatten);
} else {
funcs::ELUGradNegativeAlphaFunctor<T> functor;
functor.alpha = alpha;
functor(*place, x_flatten, out_flatten, dout_flatten, dx_flatten);
}
}
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
...@@ -144,6 +183,11 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(brelu_grad, BReluGradKernel) ...@@ -144,6 +183,11 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(brelu_grad, BReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
ThresholdedReluGradKernel) ThresholdedReluGradKernel)
// Backward kernels for the shrink, ELU and SiLU activations.
// NOTE(review): PD_REGISTER_ACTIVATION_GRAD_KERNEL is defined outside this
// excerpt; presumably it registers the CPU float/double variants - confirm.
PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad, PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
ReluDoubleGradKernel) ReluDoubleGradKernel)
...@@ -151,6 +195,7 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad, ...@@ -151,6 +195,7 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
TanhDoubleGradKernel) TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad, PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel) LeakyReluDoubleGradKernel)
// Second-order gradient kernel for ELU.
// NOTE(review): EluDoubleGradKernel and this macro are defined outside this
// excerpt - confirm the registered dtypes there.
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_KERNEL(tanh_triple_grad, PD_REGISTER_KERNEL(tanh_triple_grad,
CPU, CPU,
......
...@@ -23,8 +23,9 @@ namespace phi { ...@@ -23,8 +23,9 @@ namespace phi {
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##Kernel( \ void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
functor_class functor; \ funcs::functor_class<T> functor; \
ActivationImpl<T, Context, functor_class>(dev_ctx, x, out, functor); \ ActivationImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
} }
#define DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \ #define DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
...@@ -33,10 +34,11 @@ namespace phi { ...@@ -33,10 +34,11 @@ namespace phi {
const DenseTensor& x, \ const DenseTensor& x, \
float attr, \ float attr, \
DenseTensor* out) { \ DenseTensor* out) { \
functor_class<T> functor; \ funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \ auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr; \ *(attrs[0].second) = attr; \
ActivationImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \ ActivationImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
} }
#define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \ #define DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS( \
...@@ -47,50 +49,63 @@ namespace phi { ...@@ -47,50 +49,63 @@ namespace phi {
float attr1, \ float attr1, \
float attr2, \ float attr2, \
DenseTensor* out) { \ DenseTensor* out) { \
functor_class<T> functor; \ funcs::functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \ auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr1; \ *(attrs[0].second) = attr1; \
*(attrs[1].second) = attr2; \ *(attrs[1].second) = attr2; \
ActivationImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \ ActivationImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
} }
DEFINE_CPU_ACTIVATION_KERNEL(Sin, funcs::SinFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Sin, SinFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Cos, funcs::CosFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Cos, CosFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Tan, funcs::TanFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Tan, TanFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Asin, funcs::AsinFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Asin, AsinFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Atan, funcs::AtanFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Atan, AtanFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Acos, funcs::AcosFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Acos, AcosFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Sinh, funcs::SinhFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Sinh, SinhFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Cosh, funcs::CoshFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Cosh, CoshFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Asinh, funcs::AsinhFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Asinh, AsinhFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Acosh, funcs::AcoshFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Acosh, AcoshFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Atanh, funcs::AtanhFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Atanh, AtanhFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Relu, funcs::ReluCPUFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Relu, ReluCPUFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Tanh, funcs::TanhFunctor<T>) DEFINE_CPU_ACTIVATION_KERNEL(Tanh, TanhFunctor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, funcs::LeakyReluFunctor, alpha) DEFINE_CPU_ACTIVATION_KERNEL(TanhShrink, TanhShrinkFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Silu, SiluFunctor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
funcs::ThresholdedReluFunctor, ThresholdedReluFunctor,
threshold) threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, funcs::BReluFunctor, t_min, t_max) DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {} PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ #define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func##Kernel, float, double) {} PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}
PD_REGISTER_ACTIVATION_KERNEL(sin, Sin) PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, Cos) PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL(tan, Tan) PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel)
PD_REGISTER_ACTIVATION_KERNEL(acos, Acos) PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel)
PD_REGISTER_ACTIVATION_KERNEL(asin, Asin) PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel)
PD_REGISTER_ACTIVATION_KERNEL(atan, Atan) PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel)
PD_REGISTER_ACTIVATION_KERNEL(sinh, Sinh) PD_REGISTER_ACTIVATION_KERNEL(sinh, SinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(cosh, Cosh) PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(asinh, Asinh) PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, Acosh) PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, Atanh) PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, Tanh) PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(brelu, BRelu) PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyRelu) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedRelu) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
// CPU forward kernels (float and double) for the shrink, ELU and SiLU
// activations.
PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
...@@ -29,11 +29,13 @@ ...@@ -29,11 +29,13 @@
#include <type_traits> #include <type_traits>
#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/extensions.h"
#ifdef PADDLE_WITH_XPU_KP #ifdef PADDLE_WITH_XPU_KP
#define __forceinline__ __inline__ #define __forceinline__ __inline__
...@@ -780,6 +782,236 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -780,6 +782,236 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// Forward functor for TanhShrink:
//   out = x - tanh(x),  with tanh(x) = (e^x - e^-x) / (e^x + e^-x)
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto tanh_x = x.tanh();
    out.device(d) = x - tanh_x;
  }
};
// Backward functor for TanhShrink:
//   dx = dout * tanh(x)^2
// Only the forward input x is read; `out` is unused.
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const auto tanh_x = x.tanh();
    dx.device(d) = dout * (tanh_x * tanh_x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Forward functor for HardShrink:
//   out = x, if x < -threshold or x > threshold
//   out = 0, otherwise
template <typename T>
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
  float threshold;

  // Exposes `threshold` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto below = x < static_cast<T>(threshold * -1.f);
    const auto above = x > static_cast<T>(threshold);
    // 1 where x lies strictly outside [-threshold, threshold], 0 elsewhere.
    const auto keep_mask = (below || above).template cast<T>();
    out.device(d) = x * keep_mask;
  }
};
// Backward functor for HardShrink:
//   dx = dout, if x < -threshold or x > threshold
//   dx = 0,    otherwise
template <typename T>
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
  float threshold;

  // Exposes `threshold` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const auto below = x < static_cast<T>(threshold * -1.f);
    const auto above = x > static_cast<T>(threshold);
    // 1 where the forward pass let x through, 0 where it was zeroed.
    const auto pass_mask = (below || above).template cast<T>();
    dx.device(d) = dout * pass_mask;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Forward functor for SoftShrink:
//   out = x - lambda, if x >  lambda
//   out = x + lambda, if x < -lambda
//   out = 0,          otherwise
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Exposes `lambda` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const T bound = static_cast<T>(lambda);
    // 0/1 indicator masks for the two non-zero regions.
    const auto above = (x > bound).template cast<T>();
    const auto below = (x < -bound).template cast<T>();
    out.device(d) = above * (x - bound) + below * (x + bound);
  }
};
// Backward functor for SoftShrink:
//   dx = dout, if x > lambda or x < -lambda
//   dx = 0,    otherwise
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Exposes `lambda` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const T bound = static_cast<T>(lambda);
    const auto above = (x > bound).template cast<T>();
    const auto below = (x < -bound).template cast<T>();
    // The two indicator masks are summed to form the pass-through mask.
    dx.device(d) = dout * (above + below).template cast<T>();
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Forward functor for ELU:
//   out = alpha * (e^x - 1), if x < 0
//   out = x,                 otherwise
template <typename T>
struct ELUFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto is_negative = x < static_cast<T>(0);
    const auto negative_branch =
        static_cast<T>(alpha) * (x.exp() - static_cast<T>(1));
    out.device(d) = is_negative.select(negative_branch, x);
  }
};
// Backward functor for ELU, the form chosen by EluGradKernel when alpha > 0:
//   dx = dout,                 if out > 0
//   dx = dout * (out + alpha), if out <= 0
// NOTE(review): the formula reads `out`, yet FwdDeps() reports kDepX -
// confirm that every caller supplies Out explicitly.
template <typename T>
struct ELUGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const auto out_is_positive = out > static_cast<T>(0);
    const auto non_positive_grad = dout * (out + static_cast<T>(alpha));
    dx.device(d) = out_is_positive.select(dout, non_positive_grad);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Backward functor for ELU, the form chosen by EluGradKernel when alpha <= 0.
// Selects on x rather than out:
//   dx = dout,               if x > 0
//   dx = dout * alpha * e^x, if x <= 0
// `out` is not read.
template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const auto x_is_positive = x > static_cast<T>(0);
    const auto non_positive_grad = dout * static_cast<T>(alpha) * x.exp();
    dx.device(d) = x_is_positive.select(dout, non_positive_grad);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Second-order gradient functor for ELU.
//
// Inputs:  X (forward input), ddX (gradient w.r.t. the first-order dX),
//          dOut (first-order upstream gradient; only read when dX != nullptr).
// Outputs (each optional; skipped when the pointer is null):
//   dX    = ddX * dOut * alpha * e^x * [x <= 0]
//   ddOut = ddX * ([x > 0] + alpha * e^x * [x <= 0])
template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;

  // Exposes `alpha` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* ddX,
                  DenseTensor* ddOut,
                  const DenseTensor* dOut,
                  DenseTensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
      // dOut is a read-only input here, so it is reported as "Input" in the
      // diagnostic (it was previously mislabelled "Output").
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Input", "DOut", "ELUGradGrad"));
      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                      (x <= static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
      ddout.device(*d) = ddx *
                         ((x > static_cast<T>(0)).template cast<T>() +
                          static_cast<T>(alpha) * x.exp() *
                              (x <= static_cast<T>(0)).template cast<T>())
                             .template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Forward functor for SiLU:
//   out = x * (1 / (1 + e^-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    const auto exp_neg_x = (-x).exp();
    const auto sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + exp_neg_x);
    out.device(d) = x * sigmoid_x;
  }
};
// Backward functor for SiLU:
//   dx = dout * (1 / (1 + e^-x)) * (1 + x * e^-x / (1 + e^-x))
// Only the forward input x is read; `out` is unused.
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device,
            typename X,
            typename Out,
            typename dOut,
            typename dX>
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    const auto one_plus_exp_neg_x = static_cast<T>(1) + (-x).exp();
    const auto x_exp_neg_x = x * (-x).exp();
    dx.device(d) =
        dout * ((static_cast<T>(1) / one_plus_exp_neg_x) *
                (static_cast<T>(1) + (x_exp_neg_x / one_plus_exp_neg_x)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T> template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> { struct CudaReluFunctor : public BaseActivationFunctor<T> {
...@@ -1218,6 +1450,209 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> { ...@@ -1218,6 +1450,209 @@ struct CudaLeakyReluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
}; };
// Element-wise device functor for SoftShrink:
//   softshrink(x) = x - lambda, if x >  lambda
//                   x + lambda, if x < -lambda
//                   0,          otherwise
template <typename T>
struct CudaSoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;

  // Exposes `lambda` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  __device__ __forceinline__ T operator()(const T x) const {
    const T bound = static_cast<T>(lambda);
    // 0/1 indicators for the two non-zero regions.
    const T above = static_cast<T>(x > bound);
    const T below = static_cast<T>(x < -bound);
    return above * (x - bound) + below * (x + bound);
  }
};
// Element-wise device functor for the SoftShrink gradient:
//   dx = 0,    if -lambda <= x <= lambda
//   dx = dout, otherwise
template <typename T>
struct CudaSoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float lambda;

  // Exposes `lambda` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T bound = static_cast<T>(lambda);
    if (x >= -bound && x <= bound) {
      return zero;
    }
    return dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Element-wise device functor for TanhShrink:
//   tanhshrink(x) = x - tanh(x)
// The arithmetic is carried out in MPType and the result is cast back to T.
template <typename T>
struct CudaTanhShrinkFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType shrunk = x - tanh(x);
    return static_cast<T>(shrunk);
  }
};
// Element-wise device functor for the TanhShrink gradient:
//   dx = dout * tanh(x)^2
// The arithmetic is carried out in MPType and the result is cast back to T.
template <typename T>
struct CudaTanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType tanh_x = tanh(static_cast<MPType>(arg_x));
    return static_cast<T>(upstream * tanh_x * tanh_x);
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Element-wise device functor for HardShrink:
//   hardshrink(x) = 0, if -threshold < x < threshold
//                   x, otherwise
template <typename T>
struct CudaHardShrinkFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Exposes `threshold` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T x) const {
    const T bound = static_cast<T>(threshold);
    if (x > -bound && x < bound) {
      return zero;
    }
    return x;
  }
};
// Element-wise device functor for the HardShrink gradient:
//   dx = 0,    if -threshold < x < threshold
//   dx = dout, otherwise
template <typename T>
struct CudaHardShrinkGradFunctor : public BaseActivationFunctor<T> {
  T zero = static_cast<T>(0.0f);
  float threshold;

  // Exposes `threshold` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  __device__ __forceinline__ T operator()(const T dout, const T x) const {
    const T bound = static_cast<T>(threshold);
    if (x > -bound && x < bound) {
      return zero;
    }
    return dout;
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Element-wise device functor for ELU:
//   elu(x) = x,                 if x > 0
//   elu(x) = alpha * (e^x - 1), if x <= 0
// The arithmetic is carried out in CT and the result is cast back to T.
template <typename T>
struct CudaELUFunctor : public BaseActivationFunctor<T> {
  using CT = typename phi::dtype::MPTypeTrait<T>::Type;
  CT zero = static_cast<CT>(0.0f);
  CT one = static_cast<CT>(1.0f);
  float alpha;

  // Exposes `alpha` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T arg_x) const {
    const CT x = static_cast<CT>(arg_x);
    if (x > zero) {
      return static_cast<T>(x);
    }
    const CT scaled = static_cast<CT>(alpha) * (exp(x) - one);
    return static_cast<T>(scaled);
  }
};
// Element-wise device functor for the ELU gradient (the alpha >= 0 form),
// expressed in terms of the forward output:
//   dx = dout,                 if out > 0
//   dx = dout * (out + alpha), if out <= 0
// The arithmetic is carried out in MPType and the result is cast back to T.
template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  // Exposes `alpha` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(T arg_dout, T arg_out) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType fwd_out = static_cast<MPType>(arg_out);
    const MPType scale = static_cast<MPType>(alpha);
    // 0/1 indicators; exactly one of them is 1 for any non-NaN output.
    const MPType is_pos = static_cast<MPType>(fwd_out > zero);
    const MPType is_non_pos = static_cast<MPType>(fwd_out <= zero);
    return static_cast<T>(upstream *
                          (is_pos + is_non_pos * (fwd_out + scale)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() {
    return ActBwdOpFwdDeps::kDepOut;
  }
};
// Element-wise device functor for the ELU gradient (the alpha < 0 form).
// Selects on the forward input x while still using the forward output:
//   dx = dout,                 if x > 0
//   dx = dout * (out + alpha), if x <= 0
// The arithmetic is carried out in MPType and the result is cast back to T.
template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType zero = static_cast<MPType>(0.0f);
  float alpha;

  // Exposes `alpha` so the caller can set it by attribute name.
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_out,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType fwd_out = static_cast<MPType>(arg_out);
    const MPType fwd_in = static_cast<MPType>(arg_x);
    const MPType scale = static_cast<MPType>(alpha);
    // 0/1 indicators; exactly one of them is 1 for any non-NaN input.
    const MPType is_pos = static_cast<MPType>(fwd_in > zero);
    const MPType is_non_pos = static_cast<MPType>(fwd_in <= zero);
    return static_cast<T>(upstream *
                          (is_pos + is_non_pos * (fwd_out + scale)));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// Element-wise device functor for SiLU:
//   silu(x) = x / (1 + e^-x)
// The arithmetic is carried out in MPType and the result is cast back to T.
template <typename T>
struct CudaSiluFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T arg_x) const {
    const MPType x = static_cast<MPType>(arg_x);
    const MPType result = x / (one + exp(-x));
    return static_cast<T>(result);
  }
};
// Element-wise device functor for the SiLU gradient. With
// s = 1 / (1 + e^-x):
//   dx = dout * s * (1 + x * (1 - s))
//      = dout * (1 + e^-x + x * e^-x) / (1 + e^-x)^2
// The arithmetic is carried out in MPType and the result is cast back to T.
template <typename T>
struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  MPType one = static_cast<MPType>(1.0f);

  __device__ __forceinline__ T operator()(const T arg_dout,
                                          const T arg_x) const {
    const MPType upstream = static_cast<MPType>(arg_dout);
    const MPType x = static_cast<MPType>(arg_x);
    const MPType sigmoid_x = one / (one + exp(-x));
    return static_cast<T>(upstream *
                          (sigmoid_x * (one + x * (one - sigmoid_x))));
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
#endif #endif
} // namespace funcs } // namespace funcs
......
...@@ -73,7 +73,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ...@@ -73,7 +73,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
} }
} }
#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(name, functor_class) \ #define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& x, \ const DenseTensor& x, \
...@@ -84,7 +84,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ...@@ -84,7 +84,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
dev_ctx, &x, nullptr, &dout, dx, functor); \ dev_ctx, &x, nullptr, &dout, dx, functor); \
} }
#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX( \ #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \
name, functor_class, attr) \ name, functor_class, attr) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
...@@ -99,7 +99,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ...@@ -99,7 +99,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
dev_ctx, &x, nullptr, &dout, dx, functor); \ dev_ctx, &x, nullptr, &dout, dx, functor); \
} }
#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX( \ #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \
name, functor_class, attr1, attr2) \ name, functor_class, attr1, attr2) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
...@@ -116,7 +116,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ...@@ -116,7 +116,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
dev_ctx, &x, nullptr, &dout, dx, functor); \ dev_ctx, &x, nullptr, &dout, dx, functor); \
} }
#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(name, functor_class) \ #define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
const DenseTensor& out, \ const DenseTensor& out, \
...@@ -127,7 +127,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ...@@ -127,7 +127,7 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
dev_ctx, nullptr, &out, &dout, dx, functor); \ dev_ctx, nullptr, &out, &dout, dx, functor); \
} }
#define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepOut( \ #define DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \
name, functor_class, attr) \ name, functor_class, attr) \
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \ void name##GradKernel(const Context& dev_ctx, \
...@@ -142,32 +142,62 @@ void ActivationGradGPUImpl(const Context& dev_ctx, ...@@ -142,32 +142,62 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
dev_ctx, nullptr, &out, &dout, dx, functor); \ dev_ctx, nullptr, &out, &dout, dx, functor); \
} }
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Relu, CudaReluGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepOut(Tanh, CudaTanhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cos, CudaCosGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CudaCosGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Tan, CudaTanGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, CudaTanGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acos, CudaAcosGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, CudaAcosGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sin, CudaSinGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sin, CudaSinGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asin, CudaAsinGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asin, CudaAsinGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atan, CudaAtanGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atan, CudaAtanGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Sinh, CudaSinhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Sinh, CudaSinhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Cosh, CudaCoshGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cosh, CudaCoshGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Asinh, CudaAsinhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Asinh, CudaAsinhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Acosh, CudaAcoshGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, CudaAcoshGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DepX(Atanh, CudaAtanhGradFunctor); DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(LeakyRelu, DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
CudaLeakyReluGradFunctor, CudaLeakyReluGradFunctor,
alpha); alpha);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DepX(ThresholdedRelu, DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
CudaThresholdedReluGradFunctor, CudaThresholdedReluGradFunctor,
threshold); threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
CudaSoftShrinkGradFunctor,
lambda);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
CudaHardShrinkGradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DepX(BRelu, DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
CudaBReluGradFunctor, CudaBReluGradFunctor,
t_min, t_min,
t_max); t_max);
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
float alpha,
DenseTensor* dx) {
dev_ctx.template Alloc<T>(dx);
std::vector<const DenseTensor*> ins = {&dout, &out};
std::vector<DenseTensor*> outs = {dx};
if (alpha > 0) {
funcs::CudaELUGradFunctor<T> functor;
functor.alpha = alpha;
funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
} else {
funcs::CudaELUGradNegativeAlphaFunctor<T> functor;
functor.alpha = alpha;
ins.push_back(&x);
funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, functor);
}
}
} // namespace phi } // namespace phi
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
...@@ -234,3 +264,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad, ...@@ -234,3 +264,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel) LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
ThresholdedReluGradKernel) ThresholdedReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
...@@ -42,8 +42,9 @@ void ActivationGPUImpl(const Context& dev_ctx, ...@@ -42,8 +42,9 @@ void ActivationGPUImpl(const Context& dev_ctx,
template <typename T, typename Context> \ template <typename T, typename Context> \
void name##Kernel( \ void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
functor_class functor; \ funcs::functor_class<T> functor; \
ActivationGPUImpl<T, Context, functor_class>(dev_ctx, x, out, functor); \ ActivationGPUImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, x, out, functor); \
} }
#define DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \ #define DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
...@@ -75,24 +76,31 @@ void ActivationGPUImpl(const Context& dev_ctx, ...@@ -75,24 +76,31 @@ void ActivationGPUImpl(const Context& dev_ctx,
dev_ctx, x, out, functor); \ dev_ctx, x, out, functor); \
} }
DEFINE_GPU_ACTIVATION_KERNEL(Cos, funcs::CudaCosFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Cos, CudaCosFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Tan, funcs::CudaTanFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Tan, CudaTanFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Acos, funcs::CudaAcosFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Acos, CudaAcosFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sin, funcs::CudaSinFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Sin, CudaSinFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Asin, funcs::CudaAsinFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Asin, CudaAsinFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Atan, funcs::CudaAtanFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Atan, CudaAtanFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Sinh, funcs::CudaSinhFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Sinh, CudaSinhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Cosh, funcs::CudaCoshFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Cosh, CudaCoshFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Asinh, funcs::CudaAsinhFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Asinh, CudaAsinhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Acosh, funcs::CudaAcoshFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Acosh, CudaAcoshFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Atanh, funcs::CudaAtanhFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Atanh, CudaAtanhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Relu, funcs::CudaReluFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Relu, CudaReluFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Tanh, funcs::CudaTanhFunctor<T>) DEFINE_GPU_ACTIVATION_KERNEL(Tanh, CudaTanhFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(TanhShrink, CudaTanhShrinkFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Silu, CudaSiluFunctor)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
CudaThresholdedReluFunctor, CudaThresholdedReluFunctor,
threshold) threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
CudaHardShrinkFunctor,
threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max) DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
...@@ -142,3 +150,8 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel) ...@@ -142,3 +150,8 @@ PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel) PD_REGISTER_ACTIVATION_KERNEL(brelu, BReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel) PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
...@@ -202,4 +202,24 @@ void TanhTripleGradKernel(const Context& dev_ctx, ...@@ -202,4 +202,24 @@ void TanhTripleGradKernel(const Context& dev_ctx,
d_ddx); // output d_ddx); // output
} }
template <typename T, typename Context>
void EluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
float alpha,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
funcs::ELUGradGradFunctor<T> functor;
functor.alpha = alpha;
functor(dev_ctx, &x, &ddx, ddout, &dout, dx);
}
} // namespace phi } // namespace phi
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace phi { namespace phi {
#define DefineActGradDepXOpArgMap(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \ return KernelSignature(op_name "_grad", \
...@@ -25,7 +25,7 @@ namespace phi { ...@@ -25,7 +25,7 @@ namespace phi {
{GradVarName("X")}); \ {GradVarName("X")}); \
} }
#define DefineActGradDepOutOpArgMap(func_name, op_name, attrs) \ #define DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \ KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \ const ArgumentMappingContext& ctx) { \
return KernelSignature(op_name "_grad", \ return KernelSignature(op_name "_grad", \
...@@ -36,25 +36,29 @@ namespace phi { ...@@ -36,25 +36,29 @@ namespace phi {
#define comma , #define comma ,
DefineActGradDepXOpArgMap(Cos, "cos", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cos, "cos", ); // NOLINT
DefineActGradDepXOpArgMap(Tan, "tan", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Tan, "tan", ); // NOLINT
DefineActGradDepXOpArgMap(Acos, "acos", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acos, "acos", ); // NOLINT
DefineActGradDepXOpArgMap(Sin, "sin", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sin, "sin", ); // NOLINT
DefineActGradDepXOpArgMap(Asin, "asin", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asin, "asin", ); // NOLINT
DefineActGradDepXOpArgMap(Atan, "atan", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atan, "atan", ); // NOLINT
DefineActGradDepXOpArgMap(Sinh, "sinh", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Sinh, "sinh", ); // NOLINT
DefineActGradDepXOpArgMap(Cosh, "cosh", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cosh, "cosh", ); // NOLINT
DefineActGradDepXOpArgMap(Asinh, "asinh", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Asinh, "asinh", ); // NOLINT
DefineActGradDepXOpArgMap(Acosh, "acosh", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Acosh, "acosh", ); // NOLINT
DefineActGradDepXOpArgMap(Atanh, "atanh", ); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Atanh, "atanh", ); // NOLINT
DefineActGradDepXOpArgMap(BRelu, "brelu", "t_min" comma "t_max"); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(BRelu, "brelu", "t_min" comma "t_max");
DefineActGradDepXOpArgMap(LeakyRelu, "leaky_relu", "alpha"); // NOLINT DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(LeakyRelu, "leaky_relu", "alpha");
DefineActGradDepXOpArgMap(ThresholdedRelu, DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(ThresholdedRelu,
"thresholded_relu", "thresholded_relu",
"threshold"); // NOLINT "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(SoftShrink, "soft_shrink", "lambda");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardShrink, "hard_shrink", "threshold");
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(TanhShrink, "tanh_shrink", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Silu, "silu", ); // NOLINT
DefineActGradDepOutOpArgMap(Relu, "relu", ); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", ); // NOLINT
DefineActGradDepOutOpArgMap(Tanh, "tanh", ); // NOLINT DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", ); // NOLINT
KernelSignature ReluDoubleGradOpArgumentMapping( KernelSignature ReluDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
...@@ -85,11 +89,31 @@ KernelSignature LeakyReluOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -85,11 +89,31 @@ KernelSignature LeakyReluOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("leaky_relu", {"X"}, {"alpha"}, {"Out"}); return KernelSignature("leaky_relu", {"X"}, {"alpha"}, {"Out"});
} }
// Argument mapping for the forward ELU operator.
// Returns a signature named "elu" built from the lists {"X"}, {"alpha"} and
// {"Out"} -- presumably (inputs, attributes, outputs); confirm against the
// KernelSignature constructor. The mapping context `ctx` is not consulted.
KernelSignature EluOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("elu", {"X"}, {"alpha"}, {"Out"});
}
// Argument mapping for the ELU backward operator.
// Returns a signature named "elu_grad" whose first list holds "X", "Out" and
// the gradient variable name of "Out", followed by {"alpha"} and the gradient
// variable name of "X" -- presumably (inputs, attributes, outputs); confirm
// against the KernelSignature constructor. `ctx` is not consulted.
KernelSignature EluGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("elu_grad",
{"X", "Out", GradVarName("Out")},
{"alpha"},
{GradVarName("X")});
}
// Argument mapping for the ELU second-order backward operator.
// Returns a signature named "elu_double_grad" built from the lists
// {"X", "DOut", "DDX"}, {"alpha"} and {"DX", "DDOut"} -- presumably (inputs,
// attributes, outputs); confirm against the KernelSignature constructor.
// `ctx` is not consulted.
KernelSignature EluDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"elu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"});
}
} // namespace phi } // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad); PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(tanh_grad_grad, tanh_double_grad); PD_REGISTER_BASE_KERNEL_NAME(tanh_grad_grad, tanh_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(leaky_relu_grad_grad, leaky_relu_double_grad); PD_REGISTER_BASE_KERNEL_NAME(leaky_relu_grad_grad, leaky_relu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(softshrink, soft_shrink);
PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
...@@ -118,3 +142,13 @@ PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad_grad, ...@@ -118,3 +142,13 @@ PD_REGISTER_ARG_MAPPING_FN(leaky_relu_grad_grad,
phi::LeakyReluDoubleGradOpArgumentMapping); phi::LeakyReluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(thresholded_relu_grad, PD_REGISTER_ARG_MAPPING_FN(thresholded_relu_grad,
phi::ThresholdedReluGradOpArgumentMapping); phi::ThresholdedReluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softshrink_grad,
phi::SoftShrinkGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_shrink_grad,
phi::HardShrinkGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tanh_shrink_grad,
phi::TanhShrinkGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu, phi::EluOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad, phi::EluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elu_grad_grad, phi::EluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(silu_grad, phi::SiluGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册