提交 21beb082 编写于 作者: P phlrain

add some grad kernel; test=develop

上级 4be77e53
......@@ -440,19 +440,6 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Expm1GradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out + dout;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// relu(x) = max(x, 0)
template <typename T>
......
......@@ -140,18 +140,6 @@ struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return dout * out + dout;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
......
......@@ -104,6 +104,7 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, Expm1GradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
......@@ -214,3 +215,11 @@ PD_REGISTER_KERNEL(exp_grad,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(expm1_grad,
CPU,
ALL_LAYOUT,
phi::Expm1GradKernel,
float,
double,
phi::dtype::float16) {}
......@@ -120,12 +120,12 @@ PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
// PD_REGISTER_ACTIVATION_KERNEL(mish, Mish)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(rsqrt, RsqrtKernel)
// PD_REGISTER_ACTIVATION_KERNEL(softplus, Softplus)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_KERNEL(
......
......@@ -157,9 +157,10 @@ struct LogitFunctor {
}
};
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
// // mish(x) = x * tanh(softplus(x))
// // softplus(x) = x, if x > threshold
// // = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
float threshold;
......@@ -168,7 +169,7 @@ struct MishFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
void operator()(Device d, X x, Out out) const {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
out.device(d) = x * sp.tanh();
......@@ -244,20 +245,41 @@ struct RsqrtFunctor : public BaseActivationFunctor<T> {
}
};
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
// 1, threshold = 20 by default), otherwise x
// // For numerical stability, using the following formula instead of
// softplus(x) =
// // log(1 + exp(x))
// // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <=
// threshold(beta =
// // 1, threshold = 20 by default), otherwise x
// template <typename T>
// struct SoftplusFunctor : public BaseActivationFunctor<T> {
// float beta;
// float threshold;
// typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
// return {{"beta", &beta}, {"threshold", &threshold}};
// }
// template <typename Device, typename X, typename Out>
// void operator()(Device d, X x, Out out) {
// auto x_beta = static_cast<T>(beta) * x;
// out.device(d) = (x_beta > static_cast<T>(threshold))
// .select(x,
// (static_cast<T>(1) + x_beta.exp()).log() /
// static_cast<T>(beta));
// }
// };
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
void operator()(Device d, X x, Out out) const {
auto x_beta = static_cast<T>(beta) * x;
out.device(d) = (x_beta > static_cast<T>(threshold))
.select(x,
......@@ -602,6 +624,22 @@ struct Expm1Functor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct Expm1GradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out + dout;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluCPUFunctor : public BaseActivationFunctor<T> {
......@@ -822,11 +860,10 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
......@@ -1264,6 +1301,18 @@ struct CudaExpm1Functor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaExpm1GradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return dout * out + dout;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaSinFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......
......@@ -158,6 +158,7 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, CudaExpGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, CudaExpm1GradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
CudaLeakyReluGradFunctor,
......@@ -274,9 +275,18 @@ PD_REGISTER_KERNEL(exp_grad,
double,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_KERNEL(expm1_grad,
GPU,
ALL_LAYOUT,
phi::Expm1GradKernel,
float,
double,
phi::dtype::float16) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册