Unverified · Commit be5918e0 · Authored by: Y YuanRisheng · Committed by: GitHub

move activation (#40913)

Parent c33b4f95
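Summary: this commit migrates the pow, swish, hard_swish, round, floor, and ceil activations (forward and gradient) out of the old fluid operator code and into the phi kernel library. Their fluid functors, the PowKernel/PowGradKernel classes, and the per-device operator registrations are deleted; equivalent phi functors, kernels, registrations, and argument mappings are added. As a condensed illustration, the pow CPU registration changes from the removed fluid style to the added phi style (both fragments are quoted verbatim from the hunks below, not new code):

// Removed (fluid-style registration):
REGISTER_OP_CPU_KERNEL(
    pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<int>>,
    ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<int64_t>>);

// Added (phi-style registration):
PD_REGISTER_KERNEL(
    pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}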
@@ -1499,6 +1499,12 @@ REGISTER_ACTIVATION_OP(logsigmoid, LogSigmoid, LogSigmoidFunctor,
REGISTER_ACTIVATION_OP(log2, Log2, Log2Functor, Log2GradFunctor);
REGISTER_ACTIVATION_OP(log10, Log10, Log10Functor, Log10GradFunctor);
REGISTER_ACTIVATION_OP(log1p, Log1p, Log1pFunctor, Log1pGradFunctor);
REGISTER_ACTIVATION_OP(hard_swish, HardSwish, HardSwishFunctor,
HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
REGISTER_ACTIVATION_OP(round, Round, RoundFunctor, ZeroGradFunctor);
REGISTER_ACTIVATION_OP(floor, Floor, FloorFunctor, ZeroGradFunctor);
REGISTER_ACTIVATION_OP(ceil, Ceil, CeilFunctor, ZeroGradFunctor);
/* ========================== sigmoid register =============================
*/
@@ -1778,18 +1784,6 @@ REGISTER_OPERATOR(
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
ops::ActivationGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>,
ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<int>>,
ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<int64_t>>);
REGISTER_OP_CPU_KERNEL(
pow_grad,
ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>,
ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<int>>,
ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== exp register ============================ */
...
@@ -286,10 +286,25 @@ USE_PHI_DOUBLE_GRAD_FUNCTOR(Log)
USE_PHI_FUNCTOR(Log2)
USE_PHI_FUNCTOR(Log10)
USE_PHI_FUNCTOR(Log1p)
USE_PHI_FUNCTOR(Swish)
USE_PHI_FUNCTOR(HardSwish)
USE_PHI_FUNCTOR(Pow)
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
template <typename T>
using RoundFunctor = phi::funcs::RoundFunctor<T>;
template <typename T>
using FloorFunctor = phi::funcs::FloorFunctor<T>;
template <typename T>
using CeilFunctor = phi::funcs::CeilFunctor<T>;
template <typename T>
using ZeroGradFunctor = phi::funcs::ZeroGradFunctor<T>;
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
@@ -391,46 +406,6 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
}
};
// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.ceil();
}
};
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0) * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kNoDeps;
}
};
// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.floor();
}
};
// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.round();
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
@@ -509,51 +484,6 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
}
};
// HardSwish = min(max(0, x+3), 6) * x / 6
template <typename T>
struct HardSwishFunctor : public BaseActivationFunctor<T> {
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = (x + static_cast<T>(offset))
.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(threshold)) *
x / static_cast<T>(scale);
}
};
template <typename T>
struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto tmp = ((x + static_cast<T>(offset)) < static_cast<T>(threshold))
.template cast<T>();
dx.device(d) =
dout *
(((x + static_cast<T>(offset)) > static_cast<T>(0)).template cast<T>() *
(static_cast<T>(2) * x + static_cast<T>(offset)) /
static_cast<T>(scale) * tmp +
static_cast<T>(1) * (static_cast<T>(1) - tmp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
@@ -776,35 +706,6 @@ struct CELUGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.pow(static_cast<T>(factor));
}
};
template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor) - static_cast<T>(1));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct LogitFunctor {
template <typename Device, typename X, typename Out, typename P>
@@ -870,39 +771,6 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
}
};
template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) /
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
auto out = x * temp1;
auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
@@ -1267,110 +1135,6 @@ class RsqrtDoubleGradKernel
}
};
template <typename DeviceContext, typename Functor>
class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* X = nullptr;
framework::Tensor* Out = nullptr;
ExtractActivationTensor(context, &X, &Out);
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "Pow"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Pow"));
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
// get FactorTensor
auto* factor_tensor = context.HasInput("FactorTensor")
? context.Input<framework::Tensor>("FactorTensor")
: nullptr;
if (factor_tensor) {
auto* factor_data = factor_tensor->data<float>();
framework::Tensor cpu_factor_tensor;
if (platform::is_gpu_place(factor_tensor->place())) {
framework::TensorCopySync(*factor_tensor, platform::CPUPlace(),
&cpu_factor_tensor);
factor_data = cpu_factor_tensor.data<float>();
}
auto factor =
std::vector<float>(factor_data, factor_data + factor_tensor->numel());
PADDLE_ENFORCE_EQ(
factor.size(), 1,
platform::errors::InvalidArgument(
"The shape of factor(tensor) must be [1] rather than %d",
factor.size()));
for (auto& attr : attrs) {
*attr.second = factor[0];
}
}
functor(*place, x, out);
}
};
template <typename DeviceContext, typename Functor>
class PowGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor *X, *Out, *dOut;
framework::Tensor* dX = nullptr;
X = Out = dOut = nullptr;
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX);
dX->mutable_data<T>(context.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "PowGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Input", "Out", "PowGrad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "X@GRAD", "PowGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "PowGrad"));
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
// get FactorTensor
auto* factor_tensor =
context.HasInput("FactorTensor")
? context.Input<framework::LoDTensor>("FactorTensor")
: nullptr;
if (factor_tensor) {
auto* factor_data = factor_tensor->data<float>();
framework::Tensor cpu_factor_tensor;
if (platform::is_gpu_place(factor_tensor->place())) {
framework::TensorCopySync(*factor_tensor, platform::CPUPlace(),
&cpu_factor_tensor);
factor_data = cpu_factor_tensor.data<float>();
}
auto factor =
std::vector<float>(factor_data, factor_data + factor_tensor->numel());
PADDLE_ENFORCE_EQ(
factor.size(), 1,
platform::errors::InvalidArgument(
"The shape of factor(tensor) must be [1] rather than %d",
factor.size()));
for (auto& attr : attrs) {
*attr.second = factor[0];
}
}
functor(*place, x, out, dout, dx);
}
};
template <typename DeviceContext, typename T>
class LogitKernel : public framework::OpKernel<T> {
public:
@@ -1418,15 +1182,10 @@ class LogitGradKernel : public framework::OpKernel<T> {
} // namespace paddle
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
__macro(mish, Mish, MishFunctor, MishGradFunctor); \
__macro(hard_swish, HardSwish, HardSwishFunctor, HardSwishGradFunctor);
__macro(mish, Mish, MishFunctor, MishGradFunctor);
@@ -20,51 +20,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// ceil(x) = ceil(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(ceil(x));
}
};
template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// floor(x) = floor(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(floor(x));
}
};
template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
// round(x) = round(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(round(x));
}
};
// GradFunctor for ceil, floor and round
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(0.0f);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kNoDeps;
}
};
template <typename T>
struct CudaReciprocalFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
@@ -395,50 +350,6 @@ struct CudaRelu6GradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// swish(x) = x / (1 + exp(-beta * x))
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
return static_cast<T>(x / (one + exp(-b * x)));
}
};
template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType temp1 = one / (one + exp(-b * x));
MPType out = x * temp1;
MPType temp2 = b * out;
MPType temp3 = temp1 * (one - temp2);
return static_cast<T>(dout * (temp2 + temp3));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaMishFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
@@ -488,58 +399,6 @@ struct CudaMishGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// hard_swish(x) = 0, when x <= -offset
// x , when x >= threshold - offset
// x * (x + offset) / scale, otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T x) const {
T t = static_cast<T>(threshold);
T temp = x + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero;
T temp_min = temp_max < t ? temp_max : t;
return temp_min * x / static_cast<T>(scale);
}
};
template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
T one = static_cast<T>(1.0f);
T two = static_cast<T>(2.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T o = static_cast<T>(offset);
T s = static_cast<T>(scale);
T temp1 = static_cast<T>(x + o > zero);
T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
@@ -684,6 +543,20 @@ USE_PHI_FUNCTOR(CudaLog)
USE_PHI_FUNCTOR(CudaLog2)
USE_PHI_FUNCTOR(CudaLog10)
USE_PHI_FUNCTOR(CudaLog1p)
USE_PHI_FUNCTOR(CudaSwish)
USE_PHI_FUNCTOR(CudaHardSwish)
template <typename T>
using CudaRoundFunctor = phi::funcs::CudaRoundFunctor<T>;
template <typename T>
using CudaFloorFunctor = phi::funcs::CudaFloorFunctor<T>;
template <typename T>
using CudaCeilFunctor = phi::funcs::CudaCeilFunctor<T>;
template <typename T>
using CudaZeroGradFunctor = phi::funcs::CudaZeroGradFunctor<T>;
template <typename T>
using CudaELUGradNegativeAlphaFunctor =
@@ -813,23 +686,6 @@ REGISTER_OP_CUDA_KERNEL(
ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== pow register ============================ */
REGISTER_OP_CUDA_KERNEL(
pow, ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<float>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<double>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<int64_t>>,
ops::PowKernel<plat::CUDADeviceContext, ops::PowFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
pow_grad,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<float>>,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<double>>,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int>>,
ops::PowGradKernel<plat::CUDADeviceContext, ops::PowGradFunctor<int64_t>>,
ops::PowGradKernel<plat::CUDADeviceContext,
ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
@@ -889,9 +745,6 @@ REGISTER_OP_CUDA_KERNEL(
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(softshrink, SoftShrink, CudaSoftShrinkFunctor, \
CudaSoftShrinkGradFunctor); \
__macro(ceil, Ceil, CudaCeilFunctor, CudaZeroGradFunctor); \
__macro(floor, Floor, CudaFloorFunctor, CudaZeroGradFunctor); \
__macro(round, Round, CudaRoundFunctor, CudaZeroGradFunctor); \
__macro(reciprocal, Reciprocal, CudaReciprocalFunctor, \
CudaReciprocalGradFunctor); \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
@@ -903,10 +756,7 @@ REGISTER_OP_CUDA_KERNEL(
CudaTanhShrinkGradFunctor); \
__macro(hard_shrink, HardShrink, CudaHardShrinkFunctor, \
CudaHardShrinkGradFunctor); \
__macro(swish, Swish, CudaSwishFunctor, CudaSwishGradFunctor); \
__macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor); \
__macro(hard_swish, HardSwish, CudaHardSwishFunctor, \
CudaHardSwishGradFunctor);
__macro(mish, Mish, CudaMishFunctor, CudaMishGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
#ifdef PADDLE_WITH_XPU_KP
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
...@@ -50,6 +51,11 @@ namespace phi { ...@@ -50,6 +51,11 @@ namespace phi {
const DenseTensor& dout, \ const DenseTensor& dout, \
DenseTensor* dx); DenseTensor* dx);
#define DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(name) \
template <typename T, typename Context> \
void name##GradKernel( \
const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx);
#define DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(name, attr) \
template <typename T, typename Context> \
void name##GradKernel(const Context& dev_ctx, \
@@ -143,6 +149,22 @@ void LogDoubleGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float threshold,
float scale,
float offset,
DenseTensor* dx);
template <typename T, typename Context>
void PowGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const Scalar& factor,
DenseTensor* dx);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
@@ -166,10 +188,15 @@ DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Round);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Floor);
DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Ceil);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, alpha);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/unary.h" #include "paddle/phi/infermeta/unary.h"
...@@ -60,13 +61,32 @@ DECLARE_ACTIVATION_KERNEL(Log) ...@@ -60,13 +61,32 @@ DECLARE_ACTIVATION_KERNEL(Log)
DECLARE_ACTIVATION_KERNEL(Log2) DECLARE_ACTIVATION_KERNEL(Log2)
DECLARE_ACTIVATION_KERNEL(Log10) DECLARE_ACTIVATION_KERNEL(Log10)
DECLARE_ACTIVATION_KERNEL(Log1p) DECLARE_ACTIVATION_KERNEL(Log1p)
DECLARE_ACTIVATION_KERNEL(Round)
DECLARE_ACTIVATION_KERNEL(Floor)
DECLARE_ACTIVATION_KERNEL(Ceil)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
float scale,
float offset,
DenseTensor* out);
template <typename T, typename Context>
void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out);
} // namespace phi
@@ -107,6 +107,15 @@ namespace phi {
dev_ctx, nullptr, &out, &dout, dx, functor); \
}
#define DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel( \
const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { \
funcs::functor_class<T> functor; \
ActivationGradImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, nullptr, nullptr, &dout, dx, functor); \
}
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CosGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, TanGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, AcosGradFunctor);
@@ -130,6 +139,10 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, ReluGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, TanhGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, SigmoidGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, ZeroGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, ZeroGradFunctor);
DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, ZeroGradFunctor);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
LeakyReluGradFunctor,
alpha);
@@ -142,6 +155,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
HardShrinkGradFunctor,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishGradFunctor, beta);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
BReluGradFunctor,
@@ -183,6 +197,23 @@ void EluGradKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float threshold,
float scale,
float offset,
DenseTensor* dx) {
funcs::HardSwishGradFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold;
*(attrs[1].second) = scale;
*(attrs[2].second) = offset;
ActivationGradImpl<T, Context, funcs::HardSwishGradFunctor<T>>(
dev_ctx, &x, nullptr, &dout, dx, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(
@@ -242,3 +273,17 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
PD_REGISTER_KERNEL(pow_grad,
CPU,
ALL_LAYOUT,
phi::PowGradKernel,
float,
double,
int,
int64_t) {}
@@ -78,6 +78,9 @@ DEFINE_CPU_ACTIVATION_KERNEL(Log, LogFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Log2, Log2Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log10, Log10Functor)
DEFINE_CPU_ACTIVATION_KERNEL(Log1p, Log1pFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Round, RoundFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
@@ -86,6 +89,7 @@ DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishFunctor, beta)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
@@ -93,6 +97,22 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
slope,
offset)
template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
float scale,
float offset,
DenseTensor* out) {
funcs::HardSwishFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold;
*(attrs[1].second) = scale;
*(attrs[2].second) = offset;
ActivationImpl<T, Context, funcs::HardSwishFunctor<T>>(
dev_ctx, x, out, functor);
}
} // namespace phi
PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
@@ -126,3 +146,10 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_KERNEL(
pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}
@@ -1350,6 +1350,165 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// HardSwish = min(max(0, x+3), 6) * x / 6
template <typename T>
struct HardSwishFunctor : public BaseActivationFunctor<T> {
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = (x + static_cast<T>(offset))
.cwiseMax(static_cast<T>(0))
.cwiseMin(static_cast<T>(threshold)) *
x / static_cast<T>(scale);
}
};
template <typename T>
struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto tmp = ((x + static_cast<T>(offset)) < static_cast<T>(threshold))
.template cast<T>();
dx.device(d) =
dout *
(((x + static_cast<T>(offset)) > static_cast<T>(0)).template cast<T>() *
(static_cast<T>(2) * x + static_cast<T>(offset)) /
static_cast<T>(scale) * tmp +
static_cast<T>(1) * (static_cast<T>(1) - tmp));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SwishFunctor : public BaseActivationFunctor<T> {
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
}
};
template <typename T>
struct SwishGradFunctor : public BaseActivationFunctor<T> {
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) /
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
auto out = x * temp1;
auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
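For reference, the forward and backward formulas implemented by the HardSwish and Swish functors above can be written out as follows (this is only a restatement of the code; hard_swish defaults are threshold = scale = 6, offset = 3, and the last hard_swish case assumes threshold = scale, as in the defaults):

\mathrm{hard\_swish}(x) = \frac{x \cdot \min(\max(x + \mathrm{offset},\ 0),\ \mathrm{threshold})}{\mathrm{scale}},
\qquad
\frac{d}{dx}\,\mathrm{hard\_swish}(x) =
\begin{cases}
0, & x + \mathrm{offset} \le 0 \\
(2x + \mathrm{offset})/\mathrm{scale}, & 0 < x + \mathrm{offset} < \mathrm{threshold} \\
1, & x + \mathrm{offset} \ge \mathrm{threshold}
\end{cases}

\mathrm{swish}(x) = \frac{x}{1 + e^{-\beta x}} = x\,\sigma(\beta x),
\qquad
\frac{d}{dx}\,\mathrm{swish}(x) = \beta\,\mathrm{swish}(x) + \sigma(\beta x)\bigl(1 - \beta\,\mathrm{swish}(x)\bigr)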
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.pow(static_cast<T>(factor));
}
};
template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor) - static_cast<T>(1));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.floor();
}
};
// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.round();
}
};
// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.ceil();
}
};
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0) * out;
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kNoDeps;
}
};
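Likewise, PowGradFunctor above is the usual power rule, while round, floor, and ceil are piecewise constant, which is why they can share ZeroGradFunctor and declare kNoDeps (the backward pass needs neither x nor out):

\frac{\partial}{\partial x}\,x^{\mathrm{factor}} = \mathrm{factor}\cdot x^{\mathrm{factor}-1},
\qquad
\frac{d}{dx}\,\mathrm{round}(x) = \frac{d}{dx}\,\mathrm{floor}(x) = \frac{d}{dx}\,\mathrm{ceil}(x) = 0 \ \text{(almost everywhere)}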
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
@@ -2190,6 +2349,147 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSwishFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// swish(x) = x / (1 + exp(-beta * x))
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
return static_cast<T>(x / (one + exp(-b * x)));
}
};
template <typename T>
struct CudaSwishGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float beta;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"beta", &beta}};
}
// dx = dout * (1 + exp(-b * x) + b * x * exp(-b * x) / (1 + exp(-b * x))^2)
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType b = static_cast<MPType>(beta);
MPType temp1 = one / (one + exp(-b * x));
MPType out = x * temp1;
MPType temp2 = b * out;
MPType temp3 = temp1 * (one - temp2);
return static_cast<T>(dout * (temp2 + temp3));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaHardSwishFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// hard_swish(x) = 0, when x <= -offset
// x , when x >= threshold - offset
// x * (x + offset) / scale, otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T x) const {
T t = static_cast<T>(threshold);
T temp = x + static_cast<T>(offset);
T temp_max = temp > zero ? temp : zero;
T temp_min = temp_max < t ? temp_max : t;
return temp_min * x / static_cast<T>(scale);
}
};
template <typename T>
struct CudaHardSwishGradFunctor : public BaseActivationFunctor<T> {
T zero = static_cast<T>(0.0f);
T one = static_cast<T>(1.0f);
T two = static_cast<T>(2.0f);
float threshold;
float scale;
float offset;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
}
// dx = 0, when x <= -offset
// dout , when x >= threshold - offset
// dout * (2 * x / scale + offset / scale), otherwise
// threshold = scale = 6, offset = 3 by default
__device__ __forceinline__ T operator()(const T dout, const T x) const {
T o = static_cast<T>(offset);
T s = static_cast<T>(scale);
T temp1 = static_cast<T>(x + o > zero);
T temp2 = static_cast<T>(x + o < static_cast<T>(threshold));
return dout * (temp1 * temp2 * (two * x + o) / s + one - temp2);
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaCeilFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// ceil(x) = ceil(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(ceil(x));
}
};
template <typename T>
struct CudaFloorFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// floor(x) = floor(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(floor(x));
}
};
template <typename T>
struct CudaRoundFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
// round(x) = round(x)
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
return static_cast<T>(round(x));
}
};
// GradFunctor for ceil, floor and round
template <typename T>
struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
__device__ __forceinline__ T operator()(const T x) const {
return static_cast<T>(0.0f);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kNoDeps;
}
};
#endif
} // namespace funcs
...
@@ -159,10 +159,23 @@ void ActivationGradGPUImpl(const Context& dev_ctx,
dev_ctx, nullptr, &out, &dout, dx, functor); \
}
#define DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(name, functor_class) \
template <typename T, typename Context> \
void name##GradKernel( \
const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { \
funcs::functor_class<T> functor; \
ActivationGradGPUImpl<T, Context, funcs::functor_class<T>>( \
dev_ctx, nullptr, nullptr, &dout, dx, functor); \
}
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, CudaReluGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, CudaTanhGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, CudaSigmoidGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Round, CudaZeroGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Floor, CudaZeroGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, CudaZeroGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Cos, CudaCosGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Tan, CudaTanGradFunctor);
DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Acos, CudaAcosGradFunctor);
@@ -194,6 +207,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink,
CudaHardShrinkGradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
CudaSwishGradFunctor,
beta);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
CudaBReluGradFunctor,
@@ -227,6 +243,23 @@ void EluGradKernel(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
float threshold,
float scale,
float offset,
DenseTensor* dx) {
funcs::CudaHardSwishGradFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold;
*(attrs[1].second) = scale;
*(attrs[2].second) = offset;
ActivationGradGPUImpl<T, Context, funcs::CudaHardSwishGradFunctor<T>>(
dev_ctx, &x, nullptr, &dout, dx, functor);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
@@ -315,3 +348,18 @@ PD_REGISTER_KERNEL(log_double_grad,
float,
double,
phi::dtype::float16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
PD_REGISTER_KERNEL(pow_grad,
GPU,
ALL_LAYOUT,
phi::PowGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
@@ -97,6 +97,9 @@ DEFINE_GPU_ACTIVATION_KERNEL(Log, CudaLogFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Log2, CudaLog2Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log10, CudaLog10Functor)
DEFINE_GPU_ACTIVATION_KERNEL(Log1p, CudaLog1pFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Round, CudaRoundFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Floor, CudaFloorFunctor)
DEFINE_GPU_ACTIVATION_KERNEL(Ceil, CudaCeilFunctor)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
@@ -107,6 +110,7 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, CudaSwishFunctor, beta)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
@@ -114,6 +118,22 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
slope,
offset)
template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
float scale,
float offset,
DenseTensor* out) {
funcs::CudaHardSwishFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold;
*(attrs[1].second) = scale;
*(attrs[2].second) = offset;
ActivationGPUImpl<T, Context, funcs::CudaHardSwishFunctor<T>>(
dev_ctx, x, out, functor);
}
} // namespace phi
#ifdef PADDLE_WITH_HIP
@@ -172,3 +192,17 @@ PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(log2, Log2Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log10, Log10Kernel)
PD_REGISTER_ACTIVATION_KERNEL(log1p, Log1pKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_KERNEL(pow,
GPU,
ALL_LAYOUT,
phi::PowKernel,
float,
double,
int,
int64_t,
phi::dtype::float16) {}
@@ -293,4 +293,28 @@ void LogDoubleGradKernel(const Context& dev_ctx,
functor(dev_ctx, &x, &ddx, ddout, &dout, dx);
}
template <typename T, typename Context>
void PowGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const Scalar& factor,
DenseTensor* dx) {
PADDLE_ENFORCE_NOT_NULL(
dx, errors::NotFound("The output DenseTensor dX can not be nullptr"));
if (dx) {
dev_ctx.template Alloc<T>(dx);
}
auto dout_flatten = EigenVector<T>::Flatten(
GET_DATA_SAFELY(&dout, "Input", "Out@GRAD", "PowGrad"));
auto dx_flatten = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dx, "Output", "X@GRAD", "PowGrad"));
auto x_flatten =
EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad"));
auto* place = dev_ctx.eigen_device();
phi::funcs::PowGradFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = factor.to<float>();
functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
}
} // namespace phi
@@ -47,4 +47,23 @@ void ActivationImpl(const Context& dev_ctx,
}
}
template <typename T, typename Context>
void PowKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& factor,
DenseTensor* out) {
PADDLE_ENFORCE_NOT_NULL(out,
errors::NotFound("Output Out should not be nullptr"));
dev_ctx.template Alloc<T>(out);
auto x_flatten = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(&x, "Input", "X", "Activation"));
auto out_flatten = phi::EigenVector<T>::Flatten(
GET_DATA_SAFELY(out, "Output", "Out", "Activation"));
auto* place = dev_ctx.eigen_device();
phi::funcs::PowFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = factor.to<float>();
functor(*place, x_flatten, out_flatten);
}
} // namespace phi
@@ -34,6 +34,13 @@ namespace phi {
{GradVarName("X")}); \
}
#define DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(func_name, op_name, attrs) \
KernelSignature func_name##GradOpArgumentMapping( \
const ArgumentMappingContext& ctx) { \
return KernelSignature( \
op_name "_grad", {GradVarName("Out")}, {attrs}, {GradVarName("X")}); \
}
#define comma ,
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Cos, "cos", ); // NOLINT
@@ -61,6 +68,11 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log10, "log10", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log1p, "log1p", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
"hard_swish",
"threshold" comma "scale" comma
"offset"); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Swish, "swish", "beta"); // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Relu, "relu", ); // NOLINT
DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(Tanh, "tanh", ); // NOLINT
@@ -69,6 +81,10 @@ DEFINE_ACT_GRAD_DEPOUT_OP_ARGMAP(HardSigmoid,
"hard_sigmoid",
"slope" comma "offset"); // NOLINT
DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(Round, "round", ); // NOLINT
DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(Floor, "floor", ); // NOLINT
DEFINE_ACT_GRAD_NODEP_OP_ARGMAP(Ceil, "ceil", ); // NOLINT
KernelSignature ReluDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("relu_double_grad", {"Out", "DDX"}, {}, {"DDOut"});
@@ -135,6 +151,26 @@ KernelSignature LogDoubleGradOpArgumentMapping(
"log_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
}
KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow", {"X"}, {"FactorTensor"}, {"Out"});
} else {
return KernelSignature("pow", {"X"}, {"factor"}, {"Out"});
}
}
KernelSignature PowGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow_grad",
{"X", GradVarName("Out")},
{"FactorTensor"},
{GradVarName("X")});
} else {
return KernelSignature(
"pow_grad", {"X", GradVarName("Out")}, {"factor"}, {GradVarName("X")});
}
}
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(relu_grad_grad, relu_double_grad);
@@ -197,3 +233,11 @@ PD_REGISTER_ARG_MAPPING_FN(log_grad_grad, phi::LogDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log2_grad, phi::Log2GradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log10_grad, phi::Log10GradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(log1p_grad, phi::Log1pGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(hard_swish_grad,
phi::HardSwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(swish_grad, phi::SwishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(round_grad, phi::RoundGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(floor_grad, phi::FloorGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(ceil_grad, phi::CeilGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);