Unverified commit 191c441a authored by YuanRisheng, committed by GitHub

move activation kernel (#42880)

Parent d8b69124
......@@ -1659,15 +1659,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(celu, CELU, CELUFunctor, CELUGradFunctor);
REGISTER_OP_CPU_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
......@@ -1687,13 +1678,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<float>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== rsqrt register =============================
......@@ -1714,14 +1698,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::RsqrtGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
rsqrt_grad_grad,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<float>>,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<double>>,
ops::RsqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== square register ============================ */
......@@ -1742,18 +1718,6 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_OP_CPU_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<float>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<plat::float16>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<int>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== pow register ============================ */
......
......@@ -296,9 +296,14 @@ USE_PHI_FUNCTOR(Mish)
USE_PHI_FUNCTOR(STanh)
USE_PHI_FUNCTOR(Reciprocal)
USE_PHI_FUNCTOR(Square)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Square)
USE_PHI_FUNCTOR(Sqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Sqrt)
USE_PHI_FUNCTOR(Rsqrt)
USE_PHI_DOUBLE_GRAD_FUNCTOR(Rsqrt)
USE_PHI_FUNCTOR(Softplus)
USE_PHI_FUNCTOR(CELU)
USE_PHI_DOUBLE_GRAD_FUNCTOR(CELU)
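// Note: judging from how the aliased names are used below, USE_PHI_DOUBLE_GRAD_FUNCTOR
// presumably mirrors USE_PHI_FUNCTOR and pulls the phi::funcs double-grad functor into
// this namespace. A rough sketch of what such a macro could expand to (the macro body
// is an assumption; only its usage appears in this diff):
//
//   #define USE_PHI_DOUBLE_GRAD_FUNCTOR(name)      \
//     template <typename T>                        \
//     using name##GradGradFunctor =                \
//         phi::funcs::name##GradGradFunctor<T>;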
template <typename T>
using ELUGradNegativeAlphaFunctor = phi::funcs::ELUGradNegativeAlphaFunctor<T>;
......@@ -331,68 +336,6 @@ using ReluGradGradFunctor = phi::funcs::ReluGradGradFunctor<T>;
template <typename T>
using ReluCUDAFunctor = phi::funcs::ReluCUDAFunctor<T>;
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
// relu6(x) = min(max(0, x), 6)
template <typename T>
struct Relu6Functor : public BaseActivationFunctor<T> {
......@@ -498,51 +441,6 @@ class ELUGradKernel : public framework::OpKernel<T> {
}
};
template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
(x < static_cast<T>(0))
.select(static_cast<T>(alpha) *
((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
x);
}
};
template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
dout * temp_a_neg * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
......@@ -564,74 +462,6 @@ struct AbsGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>();
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() +
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
dx.device(*d) = ddx * static_cast<T>(2) * dout;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(2) * x;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
// TODO(dengkaipeng): double gradient calculation for Square/Sqrt need
// DOut(dy) as input(not output), tensor extraction is different from
// others. Impliment extraction kernel separately here.
......@@ -675,29 +505,6 @@ inline void ExtractDoubleGradTensorWithInputDOut(
}
}
template <typename DeviceContext, typename Functor>
class SquareDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, X, ddX, ddOut, dOut, dX);
}
};
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
......@@ -721,153 +528,6 @@ struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename DeviceContext, typename Functor>
class CELUDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(place, X, ddX, ddOut, dOut, dX);
}
};
template <typename DeviceContext, typename Functor>
class SqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out, *dX, *ddX;
Out = dX = ddX = nullptr;
framework::Tensor *ddOut, *dOut;
ddOut = dOut = nullptr;
// extract ddx(input), ddout(output)
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
PADDLE_ENFORCE_NOT_NULL(
ddx_var, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
ddOut = ctx.Output<framework::Tensor>("DDOut");
}
PADDLE_ENFORCE_NOT_NULL(
ddX, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
// extract out(input), dout(output)
auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound(
"Cannot get input Variable Out, variable name = %s",
ctx.InputName("Out")));
auto dout_var = ctx.OutputVar("DOut");
Out = ctx.Input<framework::Tensor>("Out");
if (dout_var) {
dOut = ctx.Output<framework::Tensor>("DOut");
}
// extract dx(input)
auto dx_var = ctx.InputVar("DX");
PADDLE_ENFORCE_NOT_NULL(
dx_var, platform::errors::NotFound(
"Cannot get input Variable DX, variable name = %s",
ctx.InputName("DX")));
if (dx_var) {
dX = ctx.Input<framework::Tensor>("DX");
}
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, ddOut, dOut, dX);
}
};
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
template <typename DeviceContext, typename Functor>
class RsqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out, *dX, *ddX;
Out = dX = ddX = nullptr;
framework::Tensor *ddOut, *dOut;
ddOut = dOut = nullptr;
// extract ddx(input), ddout(output)
auto ddx_var = ctx.InputVar("DDX");
auto ddo_var = ctx.OutputVar("DDOut");
PADDLE_ENFORCE_NOT_NULL(
ddx_var, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
ddX = ctx.Input<framework::Tensor>("DDX");
if (ddo_var) {
ddOut = ctx.Output<framework::Tensor>("DDOut");
}
PADDLE_ENFORCE_NOT_NULL(
ddX, platform::errors::NotFound(
"Cannot get input Variable DDX, variable name = %s",
ctx.InputName("DDX")));
// extract out(input), dout(output)
auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound(
"Cannot get input Variable Out, variable name = %s",
ctx.InputName("Out")));
auto dout_var = ctx.OutputVar("DOut");
Out = ctx.Input<framework::Tensor>("Out");
if (dout_var) {
dOut = ctx.Output<framework::Tensor>("DOut");
}
// extract dx(input)
auto dx_var = ctx.InputVar("DX");
PADDLE_ENFORCE_NOT_NULL(
dx_var, platform::errors::NotFound(
"Cannot get input Variable DX, variable name = %s",
ctx.InputName("DX")));
if (dx_var) {
dX = ctx.Input<framework::Tensor>("DX");
}
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, ddOut, dOut, dX);
}
};
} // namespace operators
} // namespace paddle
......
......@@ -126,59 +126,6 @@ struct CudaSoftsignGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__ __forceinline__ T operator()(const T arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};
template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout *
(temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......@@ -357,79 +304,35 @@ namespace plat = paddle::platform;
ops::ActivationGradCudaKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::bfloat16>>);
/* ========================================================================== */
/* ======================== celu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
CudaCELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_OP_CUDA_KERNEL(
sqrt_grad_grad,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<float>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>,
ops::SqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SqrtGradGradFunctor<plat::bfloat16>>);
/* ========================================================================== */
/* =========================== rsqrt register =============================
*/
relu6, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6Functor<float>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6Functor<double>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6Functor<int>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaRelu6Functor<int64_t>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6Functor<plat::float16>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6Functor<plat::bfloat16>>);
REGISTER_OP_CUDA_KERNEL(
rsqrt_grad_grad,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::RsqrtGradGradFunctor<float>>,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::RsqrtGradGradFunctor<double>>,
ops::RsqrtDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::RsqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== square register ============================ */
REGISTER_OP_CUDA_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<float>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
ops::SquareGradGradFunctor<plat::float16>>,
ops::SquareDoubleGradKernel<plat::CUDADeviceContext,
ops::SquareGradGradFunctor<plat::bfloat16>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<int>>,
ops::SquareDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::SquareGradGradFunctor<int64_t>>);
/* ========================================================================== */
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
/* ========================================================================== */
/* ========================== exp register ============================ */
/* ========================================================================== */
/* ========================== expm1 register ============================ */
/* ========================================================================== */
relu6_grad, ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6GradFunctor<float>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6GradFunctor<double>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6GradFunctor<int>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6GradFunctor<int64_t>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6GradFunctor<plat::float16>>,
ops::ActivationGradCudaKernel<plat::CUDADeviceContext,
ops::CudaRelu6GradFunctor<plat::bfloat16>>);
#define FOR_EACH_ACTIVATION_CUDA_OP(__macro) \
__macro(soft_relu, SoftRelu, CudaSoftReluFunctor, CudaSoftReluGradFunctor); \
__macro(relu6, Relu6, CudaRelu6Functor, CudaRelu6GradFunctor); \
__macro(softsign, Softsign, CudaSoftsignFunctor, CudaSoftsignGradFunctor);
FOR_EACH_ACTIVATION_CUDA_OP(REGISTER_ACTIVATION_CUDA_KERNEL)
......@@ -452,13 +355,14 @@ REGISTER_OP_KERNEL(
ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
ops::CudaZeroGradFunctor<float>>);
REGISTER_OP_KERNEL(celu, KP, plat::XPUPlace,
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
ops::CudaCELUFunctor<float>>);
REGISTER_OP_KERNEL(
celu, KP, plat::XPUPlace,
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
phi::funcs::CudaCELUFunctor<float>>);
REGISTER_OP_KERNEL(
celu_grad, KP, plat::XPUPlace,
ops::ActivationGradCudaKernel<paddle::platform::XPUDeviceContext,
ops::CudaCELUGradFunctor<float>>);
phi::funcs::CudaCELUGradFunctor<float>>);
REGISTER_OP_KERNEL(elu, KP, plat::XPUPlace,
ops::ActivationCudaKernel<paddle::platform::XPUDeviceContext,
......
......@@ -150,6 +150,39 @@ void LogDoubleGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void SqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout);
template <typename T, typename Context>
void RsqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout);
template <typename T, typename Context>
void CeluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
float alpha,
DenseTensor* dx,
DenseTensor* ddout);
template <typename T, typename Context>
void SquareDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout);
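// Minimal usage sketch for the double-grad kernels declared above, assuming an
// already-configured device context and initialized tensors. The context/tensor
// setup below is hypothetical; only the kernel signature comes from this header.
//
//   phi::CPUContext dev_ctx;                     // assumed configured
//   phi::DenseTensor x, dout, ddx, dx, ddout;    // assumed initialized
//   phi::SquareDoubleGradKernel<float, phi::CPUContext>(
//       dev_ctx, x, dout, ddx, &dx, &ddout);     // ddout = 2*x*ddx, dx = 2*dout*ddx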
template <typename T, typename Context>
void HardSwishGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -200,6 +233,7 @@ DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, beta);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, alpha);
DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu, t_min, t_max);
......
......@@ -78,6 +78,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Elu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Swish, beta)
DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Celu, alpha)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(BRelu, t_min, t_max)
DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
......
......@@ -167,6 +167,7 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish, SwishGradFunctor, beta);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
MishGradFunctor,
threshold);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu, CELUGradFunctor, alpha);
DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
BReluGradFunctor,
......@@ -281,6 +282,10 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
RsqrtDoubleGradKernel)
PD_REGISTER_KERNEL(tanh_triple_grad,
CPU,
......@@ -317,6 +322,15 @@ PD_REGISTER_KERNEL(square_grad,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(square_double_grad,
CPU,
ALL_LAYOUT,
phi::SquareDoubleGradKernel,
float,
double,
phi::dtype::float16,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
......@@ -332,6 +346,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(celu_double_grad,
CeluDoubleGradKernel)
PD_REGISTER_KERNEL(pow_grad,
CPU,
......
......@@ -90,19 +90,19 @@ DEFINE_CPU_ACTIVATION_KERNEL(Floor, FloorFunctor)
DEFINE_CPU_ACTIVATION_KERNEL(Ceil, CeilFunctor)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
ThresholdedReluFunctor,
threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, ELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, SwishFunctor, beta)
DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CELUFunctor, alpha)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, BReluFunctor, t_min, t_max)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(STanh, STanhFunctor, scale_a, scale_b)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(Softplus, SoftplusFunctor, beta, threshold)
DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
HardSigmoidFunctor,
slope,
......@@ -181,5 +181,6 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_KERNEL(
pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}
......@@ -1832,6 +1832,196 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* Out,
const DenseTensor* dX,
const DenseTensor* ddX,
DenseTensor* dOut,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx / y
// calculate dy first, so ddy can inplace ddx
if (dOut) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
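// Derivation sketch for the formulas above: the first-order sqrt backward rule is
// DX = 0.5 * DOut / Out (with Out = sqrt(x)). Treating DX as a function of
// (DOut, Out) and differentiating that rule:
//   d(DX)/d(DOut) = 0.5 / Out               =>  DDOut = 0.5 * DDX / Out
//   d(DX)/d(Out)  = -0.5 * DOut / Out^2
//                 = -DX / Out               =>  DOut  = -DX * DDX / Out
// which is exactly what the two branches compute.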
template <typename T>
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* Out,
const DenseTensor* dX,
const DenseTensor* ddX,
DenseTensor* dOut,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
auto out = EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
if (dOut) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
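// Derivation sketch: the first-order rsqrt backward rule is
// DX = -0.5 * DOut * Out^3 (with Out = x^(-1/2)). Differentiating that rule:
//   d(DX)/d(DOut) = -0.5 * Out^3            =>  DDOut = -0.5 * DDX * Out^3
//   d(DX)/d(Out)  = -1.5 * DOut * Out^2
//                 = 3 * DX / Out            =>  DOut  = (3 / Out) * DX * DDX
// matching the two branches above.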
template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
(x < static_cast<T>(0))
.select(static_cast<T>(alpha) *
((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
x);
}
};
template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
dout * temp_a_neg * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
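// Worked example for the piecewise gradient above. The four commented cases
// collapse to two, since the sign of alpha only selects which term of the sum
// is non-zero:
//   dx = dout                   for x >  0
//   dx = dout * exp(x / alpha)  for x <= 0
// e.g. with alpha = 1.0 and x = -1.0: celu(x) = exp(-1) - 1 ~= -0.632 and
// d(celu)/dx = exp(-1) ~= 0.368, so dx ~= 0.368 * dout.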
template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* dOut,
const DenseTensor* ddX,
DenseTensor* dX,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
if (dX) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>();
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() +
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
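// Derivation sketch: the first-order CELU rule above is
//   DX = DOut * [x > 0] + DOut * exp(x / alpha) * [x <= 0].
// Differentiating w.r.t. DOut gives the DDOut branch,
//   DDOut = DDX * ([x > 0] + exp(x / alpha) * [x <= 0]),
// and differentiating w.r.t. X gives the DX branch,
//   DX = DDX * DOut * (1 / alpha) * exp(x / alpha) * [x <= 0],
// which matches the two expressions computed above.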
template <typename T>
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* X,
const DenseTensor* dOut,
const DenseTensor* ddX,
DenseTensor* dX,
DenseTensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
auto x = EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx
if (dX) {
auto dx = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
auto dout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
dx.device(*d) = ddx * static_cast<T>(2) * dout;
}
if (ddOut) {
auto ddout = EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(2) * x;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
template <typename T>
struct CudaReluFunctor : public BaseActivationFunctor<T> {
......@@ -3091,6 +3281,59 @@ struct CudaZeroGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename phi::dtype::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__ __forceinline__ T operator()(const T arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};
template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout , if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <=0
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout *
(temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
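// Note on precision: MPTypeTrait selects the type used for the intermediate math.
// It presumably maps reduced-precision element types to float and leaves other
// types unchanged, so exp(x / a) above is evaluated in float even when T is
// phi::dtype::float16. A sketch of that assumption (requires <type_traits>):
//
//   static_assert(
//       std::is_same<typename phi::dtype::MPTypeTrait<phi::dtype::float16>::Type,
//                    float>::value,
//       "assumed MPTypeTrait mapping: float16 -> float");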
#endif
} // namespace funcs
......
......@@ -221,6 +221,9 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
CudaMishGradFunctor,
threshold);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Celu,
CudaCELUGradFunctor,
alpha);
DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(BRelu,
CudaBReluGradFunctor,
......@@ -351,7 +354,9 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_double_grad, RsqrtDoubleGradKernel)
PD_REGISTER_KERNEL(exp_grad,
GPU,
......@@ -396,6 +401,16 @@ PD_REGISTER_KERNEL(square_grad,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(square_double_grad,
GPU,
ALL_LAYOUT,
phi::SquareDoubleGradKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
......@@ -418,6 +433,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(round_grad, RoundGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(floor_grad, FloorGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(ceil_grad, CeilGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_grad, CeluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(celu_double_grad, CeluDoubleGradKernel)
PD_REGISTER_KERNEL(pow_grad,
GPU,
......
......@@ -118,8 +118,8 @@ DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink,
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, CudaSoftShrinkFunctor, lambda)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Elu, CudaELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Swish, CudaSwishFunctor, beta)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, CudaMishFunctor, threshold)
DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(Celu, CudaCELUFunctor, alpha)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(BRelu, CudaBReluFunctor, t_min, t_max)
DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Stanh, CudaSTanhFunctor, scale_a, scale_b)
......@@ -234,6 +234,7 @@ PD_REGISTER_KERNEL(square,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(soft_shrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
......@@ -251,6 +252,7 @@ PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel)
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(ceil, CeilKernel)
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
PD_REGISTER_KERNEL(pow,
GPU,
ALL_LAYOUT,
......
......@@ -335,4 +335,87 @@ void PowGradKernel(const Context& dev_ctx,
functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
}
template <typename T, typename Context>
void SqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout) {
if (dout) {
dout->Resize(out.dims());
dev_ctx.template Alloc<T>(dout);
}
if (ddout) {
ddout->Resize(out.dims());
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::SqrtGradGradFunctor<T> functor;
functor(dev_ctx, &out, &dx, &ddx, dout, ddout);
}
// rsqrt Grad: dx = -0.5 * dy * y * y * y
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
template <typename T, typename Context>
void RsqrtDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& ddx,
DenseTensor* dout,
DenseTensor* ddout) {
if (dout) {
dout->Resize(out.dims());
dev_ctx.template Alloc<T>(dout);
}
if (ddout) {
ddout->Resize(out.dims());
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::RsqrtGradGradFunctor<T> functor;
functor(dev_ctx, &out, &dx, &ddx, dout, ddout);
}
template <typename T, typename Context>
void CeluDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
float alpha,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::CELUGradGradFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = alpha;
functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
}
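// Note on the attribute plumbing above: GetAttrs() on the functor returns
// {name, float*} pairs that alias the functor's own fields, so writing through
// attrs[0].second sets functor.alpha before the call. The same pattern in
// isolation (a minimal sketch; the value is arbitrary):
//
//   phi::funcs::CELUGradGradFunctor<float> f;
//   auto f_attrs = f.GetAttrs();
//   *(f_attrs[0].second) = 1.2f;  // equivalent to setting f.alpha = 1.2f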
template <typename T, typename Context>
void SquareDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dx,
DenseTensor* ddout) {
if (dx) {
dx->Resize(x.dims());
dev_ctx.template Alloc<T>(dx);
}
if (ddout) {
dev_ctx.template Alloc<T>(ddout);
}
phi::funcs::SquareGradGradFunctor<T> functor;
functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
}
} // namespace phi
......@@ -67,6 +67,7 @@ DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log, "log", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log2, "log2", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log10, "log10", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Log1p, "log1p", ); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(Celu, "celu", "alpha"); // NOLINT
DEFINE_ACT_GRAD_DEPX_OP_ARGMAP(HardSwish,
"hard_swish",
"threshold" comma "scale" comma
......@@ -181,6 +182,30 @@ KernelSignature LogDoubleGradOpArgumentMapping(
"log_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
}
KernelSignature SqrtDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sqrt_double_grad", {"Out", "DX", "DDX"}, {}, {"DOut", "DDOut"});
}
KernelSignature RsqrtDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"rsqrt_double_grad", {"Out", "DX", "DDX"}, {}, {"DOut", "DDOut"});
}
KernelSignature CeluDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"celu_double_grad", {"X", "DOut", "DDX"}, {"alpha"}, {"DX", "DDOut"});
}
KernelSignature SquareDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"square_double_grad", {"X", "DOut", "DDX"}, {}, {"DX", "DDOut"});
}
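// The mappings above all follow the same KernelSignature shape:
// (phi kernel name, {input slots}, {attribute slots}, {output slots}); celu is the
// only one of the four that forwards an attribute ("alpha"). A hypothetical mapping
// written in the same style, with the op and slot names invented purely for
// illustration:
//
//   KernelSignature MyActDoubleGradOpArgumentMapping(
//       const ArgumentMappingContext& ctx) {
//     return KernelSignature(
//         "my_act_double_grad", {"X", "DDX"}, {"scale"}, {"DX", "DDOut"});
//   }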
KernelSignature PowOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("FactorTensor")) {
return KernelSignature("pow", {"X"}, {"FactorTensor"}, {"Out"});
......@@ -209,6 +234,10 @@ PD_REGISTER_BASE_KERNEL_NAME(softshrink_grad, soft_shrink_grad);
PD_REGISTER_BASE_KERNEL_NAME(elu_grad_grad, elu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(sigmoid_grad_grad, sigmoid_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(log_grad_grad, log_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(sqrt_grad_grad, sqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(rsqrt_grad_grad, rsqrt_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(celu_grad_grad, celu_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(square_grad_grad, square_double_grad);
PD_REGISTER_ARG_MAPPING_FN(cos_grad, phi::CosGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(tan_grad, phi::TanGradOpArgumentMapping);
......@@ -229,7 +258,11 @@ PD_REGISTER_ARG_MAPPING_FN(square_grad, phi::SquareGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(reciprocal_grad,
phi::ReciprocalGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sqrt_grad, phi::SqrtGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(sqrt_grad_grad,
phi::SqrtDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad, phi::RsqrtGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(rsqrt_grad_grad,
phi::RsqrtDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(mish_grad, phi::MishGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(stanh_grad, phi::STanhGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(softplus_grad, phi::SoftplusGradOpArgumentMapping);
......@@ -286,3 +319,8 @@ PD_REGISTER_ARG_MAPPING_FN(floor_grad, phi::FloorGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(ceil_grad, phi::CeilGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow_grad, phi::PowGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(pow, phi::PowOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(celu_grad, phi::CeluGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(celu_grad_grad,
phi::CeluDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(square_grad_grad,
phi::SquareDoubleGradOpArgumentMapping);