Commit 77812c05 authored by phlrain

move gpu kernel; test=develop

Parent 21beb082
......@@ -1641,9 +1641,7 @@ REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker,
ops::LogitGradOpMaker<paddle::framework::OpDesc>,
ops::LogitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(logit_grad, ops::LogitGradOp);
REGISTER_OP_CPU_KERNEL(
logit_grad, ops::LogitGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogitGradKernel<paddle::platform::CPUDeviceContext, double>);
/* ========================================================================== */
/* ======================== celu register ============================
......@@ -1692,7 +1690,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<float>>,
......@@ -1720,7 +1717,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
rsqrt_grad_grad,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
......@@ -1749,25 +1745,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<double>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<int>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::SquareFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<int>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::SquareGradFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
......@@ -1829,16 +1806,7 @@ REGISTER_OPERATOR(
paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::Expm1GradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(expm1_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
expm1_grad, ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<float>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<double>>,
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext,
ops::Expm1GradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== Log register ==================================*/
......
......@@ -287,6 +287,16 @@ USE_PHI_FUNCTOR(Silu)
USE_PHI_FUNCTOR(ELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(ELU)
USE_PHI_FUNCTOR(Expm1)
USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR(STanh)
USE_PHI_FUNCTOR(Reciprocal)
USE_PHI_FUNCTOR(Square)
USE_PHI_FUNCTOR(Sqrt)
USE_PHI_FUNCTOR(Rsqrt)
USE_PHI_FUNCTOR(Softplus)
USE_PHI_FUNCTOR(Softsign)
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
......@@ -454,26 +464,62 @@ template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;
template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0.5) * dout / out;
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
}
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
......@@ -519,19 +565,6 @@ struct RoundFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(-1) * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
......@@ -614,17 +647,6 @@ struct Log1pGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(2) * x;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
......@@ -706,66 +728,6 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// For numerical stability, using the following formula instead of
// d(softplus(x))/dx = 1 / (1 + exp(-x))
// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
// = 1, threshold = 20 by default), otherwise x
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
auto x_beta = static_cast<T>(beta) * x;
dx.device(d) =
(x_beta > static_cast<T>(threshold))
.select(dout, dout / (static_cast<T>(1) + (-x_beta).exp()));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
template <typename T>
struct MishGradFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
auto gsp = static_cast<T>(1) - (-sp).exp();
auto tsp = sp.tanh();
dx.device(d) = dout * (tsp + x * (static_cast<T>(1) - tsp * tsp) * gsp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
dx.device(d) =
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SoftReluFunctor : public BaseActivationFunctor<T> {
float threshold;
......@@ -909,38 +871,6 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct LogitGradFunctor {
template <typename Device, typename X, typename dOut, typename dX, typename P>
void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
// logit(x)' = 1/(x*(1-x))
dx.device(d) =
(x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
.select(p.constant(static_cast<T>(0)),
dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
}
};
template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto a = static_cast<T>(scale_a);
auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
float slope;
......@@ -1071,68 +1001,6 @@ struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
......@@ -1718,24 +1586,7 @@ class PowGradKernel
template <typename DeviceContext, typename T>
class LogitGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto eps = context.Attr<float>("eps");
dx->mutable_data<T>(dout->place());
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_dout = framework::EigenVector<T>::Flatten(*dout);
auto eigen_dx = framework::EigenVector<T>::Flatten(*dx);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto eigen_p = framework::EigenVector<T>::Flatten(*x);
LogitGradFunctor<T> functor;
functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps);
}
void Compute(const framework::ExecutionContext& context) const override {}
};
template <typename T>
......@@ -1776,17 +1627,12 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \
__macro(log2, Log2, Log2Functor, Log2GradFunctor); \
__macro(log10, Log10, Log10Functor, Log10GradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \
HardSigmoidGradFunctor); \
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
__macro(mish, Mish, MishFunctor, MishGradFunctor); \
__macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor);
......@@ -128,18 +128,6 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
// dx = -dout * out^2
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return -dout * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaLogFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
......@@ -161,46 +149,6 @@ struct CudaLogGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
T two = static_cast<T>(2.0f);
// dx = dout * 2 * x
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout * two * x;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
T one_half = static_cast<T>(0.5f);
// dx = dout * 0.5 / out
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return one_half * dout / out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
T minus_one_half = static_cast<T>(-0.5f);
// dx = -0.5 * dout * out^3
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return minus_one_half * dout * out * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaLog1pFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
......@@ -320,69 +268,6 @@ struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
// dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b);
MPType temp = tanh(a * x);
return static_cast<T>(dout * a * b * (one - temp * temp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
// dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta;
return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout / (1 + abs(x))^2
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T temp = one + abs(x);
return dout / (temp * temp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaRelu6Functor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
......@@ -506,34 +391,6 @@ struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
MPType gsp =
(x > static_cast<MPType>(threshold)) ? one : one / (one + exp(-x));
MPType tsp = tanh(sp);
return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
......@@ -831,8 +688,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, CudaSqrtFunctor,
CudaSqrtGradFunctor);
REGISTER_OP_CUDA_KERNEL(
sqrt_grad_grad,
......@@ -848,8 +703,6 @@ REGISTER_OP_CUDA_KERNEL(
/* =========================== rsqrt register =============================
*/
REGISTER_ACTIVATION_CUDA_KERNEL(rsqrt, Rsqrt, CudaRsqrtFunctor,
CudaRsqrtGradFunctor);
REGISTER_OP_CUDA_KERNEL(
rsqrt_grad_grad,
......@@ -862,8 +715,6 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* =========================== square register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL_INT(square, Square, CudaSquareFunctor,
CudaSquareGradFunctor);
REGISTER_OP_CUDA_KERNEL(
square_grad_grad,
......@@ -900,53 +751,12 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
logit_grad,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
/* ========================================================================== */
/* ========================== exp register ============================ */
REGISTER_OP_CUDA_KERNEL(
exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpFunctor<float>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpFunctor<double>>,
ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int>>,
ops::ActivationKernel<plat::CUDADeviceContext, ops::ExpFunctor<int64_t>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
exp_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpGradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpGradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpGradFunctor<int>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpGradFunctor<int64_t>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== expm1 register ============================ */
REGISTER_OP_CUDA_KERNEL(
expm1, ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1Functor<float>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1Functor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1Functor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
expm1_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1GradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1GradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaExpm1GradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== Log register ==================================*/
......@@ -969,15 +779,10 @@ REGISTER_OP_CUDA_KERNEL(
__macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \
__macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \
__macro(log1p, Log1p, CudaLog1pFunctor, CudaLog1pGradFunctor); \
__macro(log2, Log2, CudaLog2Functor, CudaLog2GradFunctor); \
__macro(log10, Log10, CudaLog10Functor, CudaLog10GradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(stanh, STanh, CudaSTanhFunctor, CudaSTanhGradFunctor); \
__macro(softplus, Softplus, CudaSoftplusFunctor, CudaSoftplusGradFunctor); \
__macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor); \
__macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor); \
__macro(tanh_shrink, TanhShrink, CudaTanhShrinkFunctor, \
CudaTanhShrinkGradFunctor); \
......@@ -986,7 +791,6 @@ REGISTER_OP_CUDA_KERNEL(
__macro(hard_sigmoid, HardSigmoid, CudaHardSigmoidFunctor, \
CudaHardSigmoidGradFunctor); \
__macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \
__macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \
__macro(hard_swish, HardSwish, CudaHardSwishFunctor, \
CudaHardSwishGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
......
......@@ -103,8 +103,14 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, AcoshGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, AtanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, TanhShrinkGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, SiluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Square, SquareGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, SoftsignGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, ExpGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, Expm1GradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, ReciprocalGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, SqrtGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, RsqrtGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
......@@ -122,11 +128,25 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
HardShrinkGradFunctor,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
MishGradFunctor,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
BReluGradFunctor,
t_min,
t_max);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh,
STanhGradFunctor,
scale_a,
scale_b);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus,
SoftplusGradFunctor,
beta,
threshold);
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -190,6 +210,13 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
ReluDoubleGradKernel)
......@@ -223,3 +250,14 @@ PD_REGISTER_KERNEL(expm1_grad,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
logit_grad, CPU, ALL_LAYOUT, phi::LogitGradKernel, float, double) {}
PD_REGISTER_KERNEL(square_grad,
CPU,
ALL_LAYOUT,
phi::SquareGradKernel,
float,
double,
int,
int64_t) {}
......@@ -115,6 +115,22 @@ struct ReciprocalFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(-1) * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// cosine'(x) = -sin(x)
template <typename T>
struct CosGradFunctor : public BaseActivationFunctor<T> {
......@@ -157,9 +173,9 @@ struct LogitFunctor {
}
};
// // mish(x) = x * tanh(softplus(x))
// // softplus(x) = x, if x > threshold
// // = ln(1 + exp(x)), otherwise
// mish(x) = x * tanh(softplus(x))
// softplus(x) = x, if x > threshold
// = ln(1 + exp(x)), otherwise
template <typename T>
struct MishFunctor : public BaseActivationFunctor<T> {
......@@ -176,6 +192,33 @@ struct MishFunctor : public BaseActivationFunctor<T> {
}
};
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
template <typename T>
struct MishGradFunctor : public BaseActivationFunctor<T> {
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto sp = (x > static_cast<T>(threshold))
.select(x, (static_cast<T>(1) + x.exp()).log());
auto gsp = static_cast<T>(1) - (-sp).exp();
auto tsp = sp.tanh();
dx.device(d) = dout * (tsp + x * (static_cast<T>(1) - tsp * tsp) * gsp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
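
// The mish derivative above is easy to mis-transcribe, so a standalone
// sanity check can help. The sketch below is illustrative only, not part of
// this commit, and uses hypothetical helper names; it compares the
// closed-form gradient with a central finite difference of
// mish(x) = x * tanh(softplus(x)).
#include <cmath>
#include <cstdio>

static double softplus(double x, double threshold) {
  // softplus(x) = x, if x > threshold; ln(1 + exp(x)) otherwise
  return x > threshold ? x : std::log1p(std::exp(x));
}

static double mish_forward(double x, double threshold) {
  return x * std::tanh(softplus(x, threshold));
}

// dx = dout * (tanh(sp) + x * (1 - tanh(sp)^2) * (1 - exp(-sp))), sp = softplus(x)
static double mish_grad(double dout, double x, double threshold) {
  double sp = softplus(x, threshold);
  double tsp = std::tanh(sp);
  double gsp = 1.0 - std::exp(-sp);
  return dout * (tsp + x * (1.0 - tsp * tsp) * gsp);
}

int main() {
  const double threshold = 20.0, h = 1e-6;
  const double xs[] = {-3.0, -0.5, 0.7, 5.0};
  for (double x : xs) {
    double numeric =
        (mish_forward(x + h, threshold) - mish_forward(x - h, threshold)) /
        (2.0 * h);
    std::printf("x=%5.2f analytic=%.6f numeric=%.6f\n", x,
                mish_grad(1.0, x, threshold), numeric);
  }
  return 0;
}
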
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
float scale_a;
......@@ -191,6 +234,29 @@ struct STanhFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct STanhGradFunctor : public BaseActivationFunctor<T> {
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto a = static_cast<T>(scale_a);
auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct Tangent {
HOSTDEVICE T operator()(const T& val) const { return tan(val); }
......@@ -227,6 +293,20 @@ struct SquareFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(2) * x;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
......@@ -236,6 +316,22 @@ struct SqrtFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0.5) * dout / out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// rsqrt(x) = x^(-1/2)
template <typename T>
struct RsqrtFunctor : public BaseActivationFunctor<T> {
......@@ -245,29 +341,28 @@ struct RsqrtFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
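
// Both SqrtGradFunctor and RsqrtGradFunctor above are written in terms of
// the forward output (out = x^(1/2) or out = x^(-1/2)), which is why they
// report kDepOut. A minimal standalone check, illustrative only and not part
// of this commit, confirms 0.5 * dout / out and -0.5 * dout * out^3 against
// finite differences.
#include <cmath>
#include <cstdio>

int main() {
  const double h = 1e-6;
  const double xs[] = {0.25, 1.0, 4.0, 9.0};
  for (double x : xs) {
    // sqrt: out = x^(1/2), dx = 0.5 * dout / out (dout = 1)
    double out_sqrt = std::sqrt(x);
    double analytic_sqrt = 0.5 / out_sqrt;
    double numeric_sqrt = (std::sqrt(x + h) - std::sqrt(x - h)) / (2.0 * h);
    // rsqrt: out = x^(-1/2), dx = -0.5 * dout * out^3 (dout = 1)
    double out_rsqrt = 1.0 / std::sqrt(x);
    double analytic_rsqrt = -0.5 * out_rsqrt * out_rsqrt * out_rsqrt;
    double numeric_rsqrt =
        (1.0 / std::sqrt(x + h) - 1.0 / std::sqrt(x - h)) / (2.0 * h);
    std::printf("x=%4.2f sqrt: %.6f vs %.6f  rsqrt: %.6f vs %.6f\n", x,
                analytic_sqrt, numeric_sqrt, analytic_rsqrt, numeric_rsqrt);
  }
  return 0;
}
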
// // For numerical stability, using the following formula instead of
// softplus(x) =
// // log(1 + exp(x))
// // softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <=
// threshold(beta =
// // 1, threshold = 20 by default), otherwise x
// template <typename T>
// struct SoftplusFunctor : public BaseActivationFunctor<T> {
// float beta;
// float threshold;
// typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
// return {{"beta", &beta}, {"threshold", &threshold}};
// }
// template <typename Device, typename X, typename Out>
// void operator()(Device d, X x, Out out) {
// auto x_beta = static_cast<T>(beta) * x;
// out.device(d) = (x_beta > static_cast<T>(threshold))
// .select(x,
// (static_cast<T>(1) + x_beta.exp()).log() /
// static_cast<T>(beta));
// }
// };
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
......@@ -288,6 +383,33 @@ struct SoftplusFunctor : public BaseActivationFunctor<T> {
}
};
// For numerical stability, using the following formula instead of
// d(softplus(x))/dx = 1 / (1 + exp(-x))
// d(softplus(x))/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
// = 1, threshold = 20 by default), otherwise x
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto x_beta = static_cast<T>(beta) * x;
dx.device(d) =
(x_beta > static_cast<T>(threshold))
.select(dout, dout / (static_cast<T>(1) + (-x_beta).exp()));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
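
// For the softplus gradient above, the stable branch evaluates
// 1 / (1 + exp(-beta * x)) and the saturated branch (beta * x > threshold)
// passes dout through, i.e. a derivative of 1. A short standalone check,
// illustrative only and not part of this commit, compares both branches with
// a finite difference of the forward formula.
#include <cmath>
#include <cstdio>

static double softplus_fwd(double x, double beta, double threshold) {
  double xb = beta * x;
  return xb > threshold ? x : std::log1p(std::exp(xb)) / beta;
}

// dsoftplus/dx = 1 / (1 + exp(-beta * x)) when beta * x <= threshold, else 1
static double softplus_grad(double dout, double x, double beta,
                            double threshold) {
  double xb = beta * x;
  return xb > threshold ? dout : dout / (1.0 + std::exp(-xb));
}

int main() {
  const double beta = 1.0, threshold = 20.0, h = 1e-6;
  const double xs[] = {-4.0, 0.0, 3.0, 30.0};
  for (double x : xs) {
    double numeric = (softplus_fwd(x + h, beta, threshold) -
                      softplus_fwd(x - h, beta, threshold)) / (2.0 * h);
    std::printf("x=%6.2f analytic=%.6f numeric=%.6f\n", x,
                softplus_grad(1.0, x, beta, threshold), numeric);
  }
  return 0;
}
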
// Tangent(x) = tan(x)
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
......@@ -479,6 +601,18 @@ struct AtanGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct LogitGradFunctor {
template <typename Device, typename X, typename dOut, typename dX, typename P>
void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
// logit(x)' = 1/(x*(1-x))
dx.device(d) =
(x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
.select(p.constant(static_cast<T>(0)),
dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
}
};
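
// LogitGradFunctor above clips the gradient to zero outside [eps, 1 - eps]
// and otherwise applies logit'(x) = 1 / (x * (1 - x)). A standalone check,
// illustrative only and not part of this commit, compares that expression
// with a finite difference of logit(x) = log(x / (1 - x)).
#include <cmath>
#include <cstdio>

static double logit_fwd(double x) { return std::log(x / (1.0 - x)); }

// dx = dout / (x * (1 - x)) inside [eps, 1 - eps], 0 outside
static double logit_grad(double dout, double x, double eps) {
  if (x < eps || x > 1.0 - eps) return 0.0;
  return dout / (x * (1.0 - x));
}

int main() {
  const double eps = 1e-6, h = 1e-6;
  const double xs[] = {0.1, 0.5, 0.9};
  for (double x : xs) {
    double numeric = (logit_fwd(x + h) - logit_fwd(x - h)) / (2.0 * h);
    std::printf("x=%.2f analytic=%.6f numeric=%.6f\n", x,
                logit_grad(1.0, x, eps), numeric);
  }
  return 0;
}
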
template <typename T>
struct Acosh {
HOSTDEVICE T operator()(const T& val) const { return acosh(val); }
......@@ -868,6 +1002,37 @@ struct SoftsignFunctor : public BaseActivationFunctor<T> {
}
};
// template <typename T>
// struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
// template <typename Device, typename X, typename Out, typename dOut,
// typename dX>
// void operator()(Device d, X x, Out out, dOut dout, dX dx) {
// dx.device(d) =
// dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
// }
// static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX;
// }
// };
// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) =
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
float alpha;
......@@ -1270,6 +1435,18 @@ struct CudaSquareFunctor : public BaseActivationFunctor<T> {
__device__ __forceinline__ T operator()(const T x) const { return x * x; }
};
template <typename T>
struct CudaSquareGradFunctor : public BaseActivationFunctor<T> {
T two = static_cast<T>(2.0f);
// dx = dout * 2 * x
__device__ __forceinline__ T operator()(const T dout, const T x) const {
return dout * two * x;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaExpGradFunctor : public BaseActivationFunctor<T> {
// dx = dout * out
......@@ -1290,6 +1467,18 @@ struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
__device__ __forceinline__ T operator()(const T x) const { return one / x; }
};
template <typename T>
struct CudaReciprocalGradFunctor : public BaseActivationFunctor<T> {
// dx = -dout * out^2
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return -dout * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaExpm1Functor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1334,6 +1523,19 @@ struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
// dx = dout / (1 + abs(x))^2
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T temp = one + abs(x);
return dout / (temp * temp);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1564,6 +1766,31 @@ struct CudaSTanhFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSTanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float scale_a;
float scale_b;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
// dx = dout * a * b * (1 - tanh(a * x) * tanh(a * x))
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(scale_a);
MPType b = static_cast<MPType>(scale_b);
MPType temp = tanh(a * x);
return static_cast<T>(dout * a * b * (one - temp * temp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1585,6 +1812,31 @@ struct CudaSoftplusFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSoftplusGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}, {"threshold", &threshold}};
}
// dx = x * beta > threshold ? dout : dout / (1 + exp(-beta * x))
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType t = static_cast<MPType>(threshold);
MPType x_beta = x * beta;
return x_beta > t ? arg_dout : static_cast<T>(dout / (one + exp(-x_beta)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAtanhGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1611,6 +1863,20 @@ struct CudaSqrtFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSqrtGradFunctor : public BaseActivationFunctor<T> {
T one_half = static_cast<T>(0.5f);
// dx = dout * 0.5 / out
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return one_half * dout / out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1622,6 +1888,20 @@ struct CudaRsqrtFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaRsqrtGradFunctor : public BaseActivationFunctor<T> {
T minus_one_half = static_cast<T>(-0.5f);
// dx = -0.5 * dout * out^3
__device__ __forceinline__ T operator()(const T dout, const T out) const {
return minus_one_half * dout * out * out * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaAtanFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -1710,6 +1990,34 @@ struct CudaMishFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = dout * (tanh(sp) + x * (1 - tanh(sp) ** 2) * (1 - exp(-sp)))
// sp = softplus(x)
// Inputs: args[0], the input dout
// args[1], the input x
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType sp = (x > static_cast<MPType>(threshold)) ? x : log(one + exp(x));
MPType gsp =
(x > static_cast<MPType>(threshold)) ? one : one / (one + exp(-x));
MPType tsp = tanh(sp);
return static_cast<T>(dout * (tsp + x * (one - tsp * tsp) * gsp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaBReluGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
......
......@@ -157,8 +157,14 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acosh, CudaAcoshGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Atanh, CudaAtanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(TanhShrink, CudaTanhShrinkGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Silu, CudaSiluGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Softsign, CudaSoftsignGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Square, CudaSquareGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Exp, CudaExpGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Expm1, CudaExpm1GradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, CudaReciprocalGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, CudaSqrtGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Rsqrt, CudaRsqrtGradFunctor);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
CudaLeakyReluGradFunctor,
......@@ -173,11 +179,25 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
CudaHardShrinkGradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
CudaMishGradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
CudaBReluGradFunctor,
t_min,
t_max);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh,
CudaSTanhGradFunctor,
scale_a,
scale_b);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus,
CudaSoftplusGradFunctor,
beta,
threshold);
template <typename T, typename Context>
void EluGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -266,6 +286,11 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
ThresholdedReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_KERNEL(exp_grad,
GPU,
......@@ -290,3 +315,14 @@ PD_REGISTER_KERNEL(expm1_grad,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
logit_grad, GPU, ALL_LAYOUT, phi::LogitGradKernel, float, double) {}
PD_REGISTER_KERNEL(square_grad,
GPU,
ALL_LAYOUT,
phi::SquareGradKernel,
float,
double,
int,
int64_t) {}
......@@ -222,4 +222,22 @@ void EluDoubleGradKernel(const Context& dev_ctx,
functor(dev_ctx, &x, &ddx, ddout, &dout, dx);
}
template <typename T, typename Context>
void LogitGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
float eps,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
auto eigen_x = EigenVector<T>::Flatten(x);
auto eigen_dout = EigenVector<T>::Flatten(out_grad);
auto eigen_dx = EigenVector<T>::Flatten(*x_grad);
auto& place = *dev_ctx.eigen_device();
auto eigen_p = EigenVector<T>::Flatten(x);
funcs::LogitGradFunctor<T> functor;
functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps);
}
} // namespace phi