Commit 21beb082 authored by phlrain

add some grad kernel; test=develop

Parent 4be77e53
@@ -440,19 +440,6 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
-template <typename T>
-struct Expm1GradFunctor : public BaseActivationFunctor<T> {
-  template <typename Device, typename X, typename Out, typename dOut,
-            typename dX>
-  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    dx.device(d) = dout * out + dout;
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() {
-    return ActBwdOpFwdDeps::kDepOut;
-  }
-};
// relu(x) = max(x, 0)
template <typename T>
...
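As a side note (not part of the patch), the backward rule that the removed fluid functor above and the new phi functors below both implement follows directly from the forward definition of expm1:

```latex
\text{out} = \operatorname{expm1}(x) = e^{x} - 1, \qquad
\frac{\partial\,\text{out}}{\partial x} = e^{x} = \text{out} + 1, \qquad
dx = dout \cdot (\text{out} + 1) = dout \cdot \text{out} + dout
```

Because the gradient can be expressed with `out` alone, the functors report `FwdDeps()` as `ActBwdOpFwdDeps::kDepOut` rather than `kDepX`.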
@@ -140,18 +140,6 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
  }
};
-template <typename T>
-struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
-  // dx = dout * out
-  __device__ __forceinline__ T operator()(const T dout, const T out) const {
-    return dout * out + dout;
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() {
-    return ActBwdOpFwdDeps::kDepOut;
-  }
-};
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
  using MPType = typename details::MPTypeTrait<T>::Type;
...
@@ -104,6 +104,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpGradFunctor);
+DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, Expm1GradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
@@ -214,3 +215,11 @@ PD_REGISTER_KERNEL(exp_grad,
                   double,
                   int,
                   int64_t) {}
+PD_REGISTER_KERNEL(expm1_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::Expm1GradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
@@ -120,12 +120,12 @@ PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
-// PD_REGISTER_ACTIVATION_KERNEL(mish, Mish)
+PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
-// PD_REGISTER_ACTIVATION_KERNEL(softplus, Softplus)
+PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_KERNEL(
...
@@ -157,9 +157,10 @@ struct LogitFunctor {
  }
};
-// mish(x) = x * tanh(softplus(x))
-// softplus(x) = x, if x > threshold
-//             = ln(1 + exp(x)), otherwise
+// // mish(x) = x * tanh(softplus(x))
+// // softplus(x) = x, if x > threshold
+// //             = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
  float threshold;
@@ -168,7 +169,7 @@ struct MishFunctor : public BaseActivationFunctor<T> {
  }
  template <typename Device, typename X, typename Out>
-  void operator()(Device d, X x, Out out) {
+  void operator()(Device d, X x, Out out) const {
    auto sp = (x > static_cast<T>(threshold))
                  .select(x, (static_cast<T>(1) + x.exp()).log());
    out.device(d) = x * sp.tanh();
@@ -244,20 +245,41 @@ struct RsqrtFunctor : public BaseActivationFunctor<T> {
  }
};
-// For numerical stability, using the following formula instead of softplus(x) =
-// log(1 + exp(x))
-// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
-// 1, threshold = 20 by default), otherwise x
+// // For numerical stability, using the following formula instead of
+// softplus(x) =
+// // log(1 + exp(x))
+// // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <=
+// threshold(beta =
+// // 1, threshold = 20 by default), otherwise x
+// template <typename T>
+// struct SoftplusFunctor : public BaseActivationFunctor<T> {
+//   float beta;
+//   float threshold;
+//   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+//     return {{"beta", &beta}, {"threshold", &threshold}};
+//   }
+//   template <typename Device, typename X, typename Out>
+//   void operator()(Device d, X x, Out out) {
+//     auto x_beta = static_cast<T>(beta) * x;
+//     out.device(d) = (x_beta > static_cast<T>(threshold))
+//                         .select(x,
+//                                 (static_cast<T>(1) + x_beta.exp()).log() /
+//                                     static_cast<T>(beta));
+//   }
+// };
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
  float beta;
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }
  template <typename Device, typename X, typename Out>
-  void operator()(Device d, X x, Out out) {
+  void operator()(Device d, X x, Out out) const {
    auto x_beta = static_cast<T>(beta) * x;
    out.device(d) = (x_beta > static_cast<T>(threshold))
                        .select(x,
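A minimal standalone sketch (not part of this commit) of why SoftplusFunctor and MishFunctor above guard the log-sum-exp with a threshold: the naive log(1 + exp(beta * x)) overflows exp() for large inputs, while the thresholded branch simply returns x. The defaults beta = 1 and threshold = 20 come from the comment in the diff; the helper names below are invented for illustration only.

```cpp
#include <cmath>
#include <cstdio>

// Naive softplus: exp(beta * x) overflows to inf for large x.
double softplus_naive(double x, double beta = 1.0) {
  return std::log(1.0 + std::exp(beta * x)) / beta;
}

// Thresholded softplus, mirroring the .select() branch in SoftplusFunctor.
double softplus_guarded(double x, double beta = 1.0, double threshold = 20.0) {
  const double x_beta = beta * x;
  return x_beta > threshold ? x : std::log(1.0 + std::exp(x_beta)) / beta;
}

int main() {
  for (double x : {1.0, 25.0, 800.0}) {
    const double sp = softplus_guarded(x);
    // mish(x) = x * tanh(softplus(x)), as in MishFunctor.
    std::printf("x=%6.1f  naive=%g  guarded=%g  mish=%g\n",
                x, softplus_naive(x), sp, x * std::tanh(sp));
  }
  return 0;
}
```

For x = 800 the naive form prints inf while the guarded form returns 800, which is why both functors branch on the threshold before evaluating the exponential.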
@@ -602,6 +624,22 @@ struct Expm1Functor : public BaseActivationFunctor<T> {
  }
};
+template <typename T>
+struct Expm1GradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    dx.device(d) = dout * out + dout;
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
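A small standalone check (a sketch, not part of the patch) that the expression used by Expm1GradFunctor, dx = dout * out + dout, matches a finite-difference estimate of d expm1(x)/dx; std::expm1 from <cmath> stands in here for the Eigen expression used in the functor.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const double x = 0.7;
  const double dout = 1.0;              // upstream gradient
  const double out = std::expm1(x);     // forward: out = e^x - 1
  const double dx = dout * out + dout;  // backward, as in Expm1GradFunctor
  const double eps = 1e-6;
  const double fd =
      (std::expm1(x + eps) - std::expm1(x - eps)) / (2.0 * eps);
  std::printf("analytic dx = %.8f, finite difference = %.8f\n", dx, fd);
  return 0;
}
```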
@@ -822,11 +860,10 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
-// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
-  void operator()(Device d, X x, Out out) {
+  void operator()(Device d, X x, Out out) const {
    out.device(d) = x / (static_cast<T>(1) + x.abs());
  }
};
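For reference (background only, the derivative is not part of this diff), the softsign curve implemented by SoftsignFunctor above and its derivative are:

```latex
\operatorname{softsign}(x) = \frac{x}{1 + |x|}, \qquad
\operatorname{softsign}'(x) = \frac{1}{\left(1 + |x|\right)^{2}}
```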
@@ -1264,6 +1301,18 @@ struct CudaExpm1Functor : public BaseActivationFunctor<T> {
  }
};
+template <typename T>
+struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
+  // dx = dout * out + dout
+  __device__ __forceinline__ T operator()(const T dout, const T out) const {
+    return dout * out + dout;
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() {
+    return ActBwdOpFwdDeps::kDepOut;
+  }
+};
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
...
@@ -158,6 +158,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, CudaExpGradFunctor);
+DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, CudaExpm1GradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
                                               CudaLeakyReluGradFunctor,
@@ -274,9 +275,18 @@ PD_REGISTER_KERNEL(exp_grad,
                   double,
                   int,
                   int64_t) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
+PD_REGISTER_KERNEL(expm1_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::Expm1GradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}