Unverified commit 191c441a · authored by YuanRisheng, committed by GitHub

move activation kernel (#42880)

Parent d8b69124
@@ -1659,15 +1659,6 @@ REGISTER_OPERATOR(
    ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(celu, CELU, CELUFunctor, CELUGradFunctor);
REGISTER_OP_CPU_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */
@@ -1687,13 +1678,6 @@ REGISTER_OPERATOR(
    ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<float>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   rsqrt register  =============================
@@ -1714,14 +1698,6 @@ REGISTER_OPERATOR(
    ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
rsqrt_grad_grad,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<float>>,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<double>>,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ==========================   square register  ============================ */
@@ -1742,18 +1718,6 @@ REGISTER_OPERATOR(
    ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<float>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<plat::float16>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<int>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ==========================   pow register  ============================ */
......
@@ -296,9 +296,14 @@ USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR(STanh)
USE_PHI_FUNCTOR(Reciprocal)
USE_PHI_FUNCTOR(Square)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Square)
USE_PHI_FUNCTOR(Sqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sqrt)
USE_PHI_FUNCTOR(Rsqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Rsqrt)
USE_PHI_FUNCTOR(Softplus)
USE_PHI_FUNCTOR(CELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(CELU)

template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
@@ -331,68 +336,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
@@ -498,51 +441,6 @@ class ELUGradKernel : public framework::OpKernel<T> {
  }
};
template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
(x < static_cast<T>(0))
.select(static_cast<T>(alpha) *
((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
x);
}
};
template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
dout * temp_a_neg * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device>
@@ -564,74 +462,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>();
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() +
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
dx.device(*d) = ddx * static_cast<T>(2) * dout;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(2) * x;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// DOut(dy) as input(not output), tensor extraction is different from
// others. Impliment extraction kernel separately here.
@@ -675,29 +505,6 @@ inline void ExtractDoubleGradTensorWithInputDOut(
  }
}
template <typename DeviceContext, typename Functor>
class SquareDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, X, ddX, ddOut, dOut, dX);
}
};
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Out>
@@ -721,153 +528,6 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename DeviceContext, typename Functor>
class CELUDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(place, X, ddX, ddOut, dOut, dX);
}
};
template <typename DeviceContext, typename Functor>
class SqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out, *dX, *ddX;
Out = dX = ddX = nullptr;
framework::Tensor *ddOut, *dOut;
ddOut = dOut = nullptr;
// extract ddx(input), ddout(output)
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
PADDLE_ENFORCE_NOT_NULL(
ddx_var, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
ddOut = ctx.Output<framework::Tensor>("DDOut");
}
PADDLE_ENFORCE_NOT_NULL(
ddX, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
// extract out(input), dout(output)
auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound(
"Cannot get input Variable Out, variable name = %s",
ctx.InputName("Out")));
auto dout_var = ctx.OutputVar("DOut");
Out = ctx.Input<framework::Tensor>("Out");
if (dout_var) {
dOut = ctx.Output<framework::Tensor>("DOut");
}
// extract dx(input)
auto dx_var = ctx.InputVar("DX");
PADDLE_ENFORCE_NOT_NULL(
dx_var, platform::errors::NotFound(
"Cannot get input Variable DX, variable name = %s",
ctx.InputName("DX")));
if (dx_var) {
dX = ctx.Input<framework::Tensor>("DX");
}
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, ddOut, dOut, dX);
}
};
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
template <typename DeviceContext, typename Functor>
class RsqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out, *dX, *ddX;
Out = dX = ddX = nullptr;
framework::Tensor *ddOut, *dOut;
ddOut = dOut = nullptr;
// extract ddx(input), ddout(output)
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
PADDLE_ENFORCE_NOT_NULL(
ddx_var, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
ddOut = ctx.Output<framework::Tensor>("DDOut");
}
PADDLE_ENFORCE_NOT_NULL(
ddX, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
// extract out(input), dout(output)
auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound(
"Cannot get input Variable Out, variable name = %s",
ctx.InputName("Out")));
auto dout_var = ctx.OutputVar("DOut");
Out = ctx.Input<framework::Tensor>("Out");
if (dout_var) {
dOut = ctx.Output<framework::Tensor>("DOut");
}
// extract dx(input)
auto dx_var = ctx.InputVar("DX");
PADDLE_ENFORCE_NOT_NULL(
dx_var, platform::errors::NotFound(
"Cannot get input Variable DX, variable name = %s",
ctx.InputName("DX")));
if (dx_var) {
dX = ctx.Input<framework::Tensor>("DX");
}
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, ddOut, dOut, dX);
}
};
}  // namespace operators
}  // namespace paddle
......
@@ -126,59 +126,6 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__ __forceinline__ T operator()(const T arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};
template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout *
(temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -357,79 +304,35 @@ namespace plat = paddle::platform;
      ops::ActivationGradCudaKernel<plat::CUDADeviceContext,                 \
                                    ops::grad_functor<plat::bfloat16>>);
/* ========================================================================== */
/* ======================== celu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
CudaCELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
    celu_grad_grad, ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                                              ops::CELUGradGradFunctor<float>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<double>>,
    ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
                              ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   sqrt register  ============================= */

REGISTER_OP_CUDA_KERNEL(
    sqrt_grad_grad,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<float>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<double>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::float16>>,
    ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                              ops::SqrtGradGradFunctor<plat::bfloat16>>);
/* ========================================================================== */

/* ===========================   rsqrt register  =============================
 */

REGISTER_OP_CUDA_KERNEL(
    rsqrt_grad_grad,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<float>>,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<double>>,
    ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
                               ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ===========================   square register  ============================ */

REGISTER_OP_CUDA_KERNEL(
    square_grad_grad,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<float>>,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<double>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::float16>>,
    ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
                                ops::SquareGradGradFunctor<plat::bfloat16>>,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int>>,
    ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
                                ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ==========================   logit register  ============================ */
namespace ops = paddle::operators;
/* ========================================================================== */

/* ==========================   exp register  ============================ */
/* ========================================================================== */

/* ==========================   expm1 register  ============================ */
/* ========================================================================== */

REGISTER_OP_CUDA_KERNEL(
    relu6, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                                     ops::CudaRelu6Functor<float>>,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaRelu6Functor<double>>,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaRelu6Functor<int>>,
    ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
                              ops::CudaRelu6Functor<int64_t>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaRelu6Functor<plat::float16>>,
    ops::ActivationCudaKernel<plat::CUDADeviceContext,
                              ops::CudaRelu6Functor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
    relu6_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                              ops::CudaRelu6GradFunctor<float>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<double>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<int>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<int64_t>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<plat::float16>>,
    ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
                                  ops::CudaRelu6GradFunctor<plat::bfloat16>>);

#define FOR_EACH_ACTIVATION_CUDA_OP(__macro)                                  \
  __macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
  __macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor);              \
  __macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);

FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
@@ -452,13 +355,14 @@ REGISTER_OP_KERNEL(
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaZeroGradFunctor<float>>);
REGISTER_OP_KERNEL(celu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                                             ops::CudaCELUFunctor<float>>);
REGISTER_OP_KERNEL(
    celu, KP, plat::XPUPlace,
    ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
                              phi::funcs::CudaCELUFunctor<float>>);

REGISTER_OP_KERNEL(
    celu_grad, KP, plat::XPUPlace,
    ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
                                  ops::CudaCELUGradFunctor<float>>);
                                  phi::funcs::CudaCELUGradFunctor<float>>);
REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace,
                   ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
......
@@ -150,6 +150,39 @@ void LogDoubleGradKernel(const Context& dev_ctx,
                         DenseTensor* dx,
                         DenseTensor* ddout);
template <typename T, typename Context>
void SqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout);
template <typename T, typename Context>
void RsqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout);
template <typename T, typename Context>
void CeluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
float alpha,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void SquareDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
                         const DenseTensor& x,
@@ -200,6 +233,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);

DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
......
@@ -78,6 +78,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(celu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
......
@@ -167,6 +167,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishGradFunctor, beta);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
                                               MishGradFunctor,
                                               threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CELUGradFunctor, alpha);

DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
                                               BReluGradFunctor,
@@ -281,6 +282,10 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
                                          LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
RsqrtDoubleGradKernel)
PD_REGISTER_KERNEL(tanh_triple_grad,
                   CPU,
@@ -317,6 +322,15 @@ PD_REGISTER_KERNEL(square_grad,
                   double,
                   int,
                   int64_t) {}
PD_REGISTER_KERNEL(square_double_grad,
CPU,
ALL_LAYOUT,
phi::SquareDoubleGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
@@ -332,6 +346,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(celu_double_grad,
CeluDoubleGradKernel)
PD_REGISTER_KERNEL(pow_grad,
                   CPU,
......
@@ -90,19 +90,19 @@ DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
                                     ThresholdedReluFunctor,
                                     threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishFunctor, beta)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
                                     HardSigmoidFunctor,
                                     slope,
@@ -181,5 +181,6 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_KERNEL(
    pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}
@@ -1832,6 +1832,196 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
  }
};
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* Out,
const DenseTensor* dX,
const DenseTensor* ddX,
DenseTensor* dOut,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx
if (dOut) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
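As a quick check of the comment above (no new definitions, just restating the code): with $y = \sqrt{x}$ the first backward pass is $dx = \frac{dout}{2y}$, and back-propagating the incoming gradient $ddx$ through that expression gives

$$ddout = \frac{\partial\, dx}{\partial\, dout}\, ddx = \frac{ddx}{2y}, \qquad dOut = \frac{\partial\, dx}{\partial y}\, ddx = -\frac{dout}{2y^{2}}\, ddx = -\frac{dx \cdot ddx}{y},$$

which is exactly what the two assignments in the functor compute.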
template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* Out,
const DenseTensor* dX,
const DenseTensor* ddX,
DenseTensor* dOut,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if (dOut) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
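Similarly for the functor above: $y = x^{-1/2}$ gives $dx = -\tfrac{1}{2}\, dout\, y^{3}$ for the first backward pass, so back-propagating $ddx$ yields

$$ddout = -\tfrac{1}{2}\, y^{3}\, ddx, \qquad dOut = -\tfrac{3}{2}\, dout\, y^{2}\, ddx = \frac{3}{y}\, dx \cdot ddx,$$

matching the comment and the two assignments.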
template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
(x < static_cast<T>(0))
.select(static_cast<T>(alpha) *
((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
x);
}
};
template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
dout * temp_a_neg * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
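The four-term expression above reduces to the piecewise derivative of CELU: with $\mathrm{celu}(x) = \max(0, x) + \min\bigl(0, \alpha(e^{x/\alpha} - 1)\bigr)$, the derivative is $1$ for $x > 0$ and $e^{x/\alpha}$ for $x \le 0$ (for either sign of $\alpha$), so

$$dx = dout \cdot \bigl(\mathbf{1}[x > 0] + e^{x/\alpha}\, \mathbf{1}[x \le 0]\bigr),$$

since the $\alpha > 0$ and $\alpha \le 0$ masks in the code sum to one.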
template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* dOut,
const DenseTensor* ddX,
DenseTensor* dX,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
if (dX) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>();
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() +
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
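For the double grad above, write $f'(x) = \mathbf{1}[x > 0] + e^{x/\alpha}\, \mathbf{1}[x \le 0]$, so the first backward pass is $dx = dout \cdot f'(x)$. Back-propagating $ddx$ gives

$$ddout = ddx \cdot f'(x), \qquad dX = ddx \cdot dout \cdot f''(x) = ddx \cdot \frac{dout}{\alpha}\, e^{x/\alpha}\, \mathbf{1}[x \le 0],$$

which is what the ddout and dx assignments in the functor evaluate.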
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* dOut,
const DenseTensor* ddX,
DenseTensor* dX,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx
if (dX) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
dx.device(*d) = ddx * static_cast<T>(2) * dout;
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(2) * x;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
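For the square double grad above: $y = x^{2}$ gives $dx = 2x \cdot dout$ for the first backward pass, and back-propagating $ddx$ gives $ddout = 2x \cdot ddx$ and $dX = 2\, dout \cdot ddx$, i.e. the identities in the comment (ddy = 2x*ddx, dx = 2dy*ddx).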
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
@@ -3091,6 +3281,59 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
  }
};
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename phi::dtype::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__ __forceinline__ T operator()(const T arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};
template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout *
(temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
#endif
}  // namespace funcs
......
@@ -221,6 +221,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
                                               CudaMishGradFunctor,
                                               threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu,
CudaCELUGradFunctor,
alpha);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
                                               CudaBReluGradFunctor,
@@ -351,7 +354,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
PD_REGISTER_KERNEL(exp_grad,
                   GPU,
@@ -396,6 +401,16 @@ PD_REGISTER_KERNEL(square_grad,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square_double_grad,
GPU,
ALL_LAYOUT,
phi::SquareDoubleGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
@@ -418,6 +433,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_double_grad, CeluDoubleGradKernel)
PD_REGISTER_KERNEL(pow_grad,
                   GPU,
......
@@ -118,8 +118,8 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, CudaSwishFunctor, beta)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Stanh, CudaSTanhFunctor, scale_a, scale_b)
@@ -234,6 +234,7 @@ PD_REGISTER_KERNEL(square,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
@@ -251,6 +252,7 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_KERNEL(pow,
                   GPU,
                   ALL_LAYOUT,
......
@@ -335,4 +335,87 @@ void PowGradKernel(const Context& dev_ctx,
  functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
}
template <typename T, typename Context>
void SqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout) {
if (dout) {
dout->Resize(out.dims());
dev_ctx.template Alloc<T>(dout);
}
if (ddout) {
ddout->Resize(out.dims());
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::SqrtGradGradFunctor<T> functor;
functor(dev_ctx, &out, &dx, &ddx, dout, ddout);
}
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
template <typename T, typename Context>
void RsqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout) {
if (dout) {
dout->Resize(out.dims());
dev_ctx.template Alloc<T>(dout);
}
if (ddout) {
ddout->Resize(out.dims());
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::RsqrtGradGradFunctor<T> functor;
functor(dev_ctx, &out, &dx, &ddx, dout, ddout);
}
template <typename T, typename Context>
void CeluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
float alpha,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::CELUGradGradFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = alpha;
functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
}
template <typename T, typename Context>
void SquareDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::SquareGradGradFunctor<T> functor;
functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
}
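As an illustration of what these double-grad kernels compute, the following is a minimal standalone C++ sketch (not Paddle code; every name in it is local to the example) that checks the square double-grad identities ddout = 2*x*ddx and dx = 2*dout*ddx against central finite differences of the scalar functional L(x, dout) = first_grad(x, dout) * ddx:

#include <cmath>
#include <cstdio>

int main() {
  const double x = 1.3, dout = 0.7, ddx = 0.9, eps = 1e-6;

  // First backward pass of y = x * x.
  auto first_grad = [](double xv, double dv) { return 2.0 * xv * dv; };
  // Scalar functional whose partial derivatives are the double-grad outputs.
  auto L = [&](double xv, double dv) { return first_grad(xv, dv) * ddx; };

  // Analytic outputs, as computed by SquareGradGradFunctor above.
  const double ddout = 2.0 * x * ddx;      // dL/d(dout)
  const double dx_new = 2.0 * dout * ddx;  // dL/dx

  // Central finite differences of L.
  const double ddout_fd = (L(x, dout + eps) - L(x, dout - eps)) / (2.0 * eps);
  const double dx_fd = (L(x + eps, dout) - L(x - eps, dout)) / (2.0 * eps);

  std::printf("ddout: analytic=%g  fd=%g\n", ddout, ddout_fd);
  std::printf("dx   : analytic=%g  fd=%g\n", dx_new, dx_fd);
  return (std::abs(ddout - ddout_fd) < 1e-4 && std::abs(dx_new - dx_fd) < 1e-4)
             ? 0
             : 1;
}

The same style of check applies to the sqrt, rsqrt and celu kernels above, with the corresponding first-order gradients.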
}  // namespace phi
@@ -67,6 +67,7 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", );  // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", );    // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log10, "log10", );  // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log1p, "log1p", );  // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Celu, "celu", "alpha"); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
                               "hard_swish",
                               "threshold" comma "scale" comma
@@ -181,6 +182,30 @@ KernelSignature LogDoubleGradOpArgumentMapping(
      "log_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
}
KernelSignature SqrtDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sqrt_double_grad", {"Out", "DX", "DDX"}, {}, {"DOut", "DDOut"});
}
KernelSignature RsqrtDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"rsqrt_double_grad", {"Out", "DX", "DDX"}, {}, {"DOut", "DDOut"});
}
KernelSignature CeluDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"celu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"});
}
KernelSignature SquareDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"square_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
}
KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
  if (ctx.HasInput("FactorTensor")) {
    return KernelSignature("pow", {"X"}, {"FactorTensor"}, {"Out"});
@@ -209,6 +234,10 @@ PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(sigmoid_grad_grad, sigmoid_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(log_grad_grad, log_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(sqrt_grad_grad, sqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(rsqrt_grad_grad, rsqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(celu_grad_grad, celu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(square_grad_grad, square_double_grad);
PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
@@ -229,7 +258,11 @@ PD_REGISTER_ARG_MAPPING_FN(square_grad, phi::SquareGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reciprocal_grad,
                           phi::ReciprocalGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sqrt_grad, phi::SqrtGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sqrt_grad_grad,
phi::SqrtDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad, phi::RsqrtGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad_grad,
phi::RsqrtDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softplus_grad, phi::SoftplusGradOpArgumentMapping);
@@ -286,3 +319,8 @@ PD_REGISTER_ARG_MAPPING_FN(floor_grad, phi::FloorGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(ceil_grad, phi::CeilGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(celu_grad, phi::CeluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(celu_grad_grad,
phi::CeluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(square_grad_grad,
phi::SquareDoubleGradOpArgumentMapping);