Unverified · Commit 542844b4 authored by will-jl944, committed by GitHub

Add softplus double grad (#50261)

* add softplus double grad

* use constant method
Parent 1a966db2
@@ -1313,6 +1313,17 @@
    func : slogdet_grad
    data_type : out_grad

- backward_op : softplus_double_grad
  forward : softplus_grad (Tensor x, Tensor grad_out, float beta, float threshold) -> Tensor(grad_x)
  args : (Tensor x, Tensor grad_out, Tensor grad_x_grad, float beta, float threshold)
  output : Tensor(x_grad), Tensor(grad_out_grad)
  infer_meta :
    func : GeneralBinaryGradInferMeta
    param : [x, x]
  kernel :
    func : softplus_double_grad
  inplace : (grad_x_grad -> grad_out_grad)
- backward_op : softplus_grad
  forward : softplus (Tensor x, float beta, float threshold) -> Tensor(out)
  args : (Tensor x, Tensor out_grad, float beta, float threshold)
@@ -1322,6 +1333,7 @@
    param : [x]
  kernel :
    func : softplus_grad
  backward : softplus_double_grad
  inplace : (out_grad -> x_grad)

- backward_op : softshrink_grad
......
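For reference, the new `softplus_double_grad` op above takes the forward input `x`, the first-order incoming gradient `grad_out`, and `grad_x_grad` (the gradient flowing into `grad_x`), and returns gradients for both `x` and `grad_out`; `GeneralBinaryGradInferMeta` with `param : [x, x]` lays out both outputs like `x`, and the `inplace` mapping lets `grad_out_grad` reuse the `grad_x_grad` buffer. Since the first-order rule is `grad_x = grad_out * f'(x)` for softplus `f`, the second-order outputs follow from the product and chain rules. A sketch in the notation of this diff, writing `ddx` for `grad_x_grad`:

```latex
% First-order backward (existing softplus_grad):
%   grad_x = grad_out * f'(x)
% Second-order backward (new softplus_double_grad), with ddx = grad_x_grad:
\mathrm{grad\_out\_grad} = \mathrm{ddx} \cdot f'(x),
\qquad
\mathrm{x\_grad} = \mathrm{ddx} \cdot \mathrm{grad\_out} \cdot f''(x)
```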
@@ -1600,7 +1600,7 @@
    attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]

- op : softplus
  backward : softplus_grad, softplus_double_grad
  inputs :
    x : X
  outputs :
......
@@ -257,6 +257,17 @@ void PowTripleGradKernel(const Context& dev_ctx,
                         DenseTensor* out_d_x,
                         DenseTensor* out_d_dout,
                         DenseTensor* out_d_ddx);

template <typename T, typename Context>
void SoftplusDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& x,
                              const DenseTensor& dout,
                              const DenseTensor& ddx,
                              float beta,
                              float threshold,
                              DenseTensor* dx,
                              DenseTensor* ddout);

DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
......
@@ -291,6 +291,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
                                          SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
                                          RsqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
                                          SoftplusDoubleGradKernel)

PD_REGISTER_KERNEL(tanh_triple_grad,
                   CPU,
......
@@ -693,6 +693,56 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct SoftplusDoubleGradFunctor : public BaseActivationFunctor<T> {
  float beta;
  float threshold;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  template <typename Device>
  void operator()(const Device& dev,
                  const DenseTensor* X,
                  const DenseTensor* dOut,
                  const DenseTensor* ddX,
                  DenseTensor* dX,
                  DenseTensor* ddOut) const {
    auto* d = dev.eigen_device();
    auto x = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(X, "Input", "X", "SoftplusDoubleGrad"));
    auto x_beta = static_cast<T>(beta) * x;
    auto ddx = EigenVector<T>::Flatten(
        GET_DATA_SAFELY(ddX, "Input", "DDX", "SoftplusDoubleGrad"));

    if (dX) {
      auto dx = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dX, "Output", "DX", "SoftplusDoubleGrad"));
      auto dout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(dOut, "Output", "DOut", "SoftplusDoubleGrad"));
      // ddx * dout * beta * exp(x_beta) / (exp(x_beta) + 1) ^ 2, if x_beta
      // <= threshold
      // 0, if x_beta > threshold
      dx.device(*d) =
          (x_beta > static_cast<T>(threshold))
              .select(x.constant(static_cast<T>(0)),
                      ddx * dout * static_cast<T>(beta) * x_beta.exp() /
                          (x_beta.exp() + static_cast<T>(1))
                              .pow(static_cast<T>(2)));
    }

    if (ddOut) {
      auto ddout = EigenVector<T>::Flatten(
          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SoftplusDoubleGrad"));
      // ddx / (1 + exp(-x_beta)), if x_beta <= threshold
      // ddx, if x_beta > threshold
      ddout.device(*d) =
          (x_beta > static_cast<T>(threshold))
              .select(ddx, ddx / (static_cast<T>(1) + (-x_beta).exp()));
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// Tangent(x) = tan(x)
template <typename T>
struct TanFunctor : public BaseActivationFunctor<T> {
......
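The two `.select` branches in `SoftplusDoubleGradFunctor` above mirror the two regimes of Paddle's softplus, which switches to the identity once `beta * x` exceeds `threshold` for numerical stability. A sketch of the derivatives the functor applies elementwise, with sigma the logistic sigmoid:

```latex
% f(x) = (1/beta) * ln(1 + exp(beta * x))  if beta * x <= threshold,  f(x) = x otherwise
f'(x) = \sigma(\beta x) = \frac{1}{1 + e^{-\beta x}},
\qquad
f''(x) = \beta\,\sigma(\beta x)\bigl(1 - \sigma(\beta x)\bigr)
       = \frac{\beta\, e^{\beta x}}{\left(e^{\beta x} + 1\right)^{2}}
```

In the linear regime `f'` is 1 and `f''` is 0, which is exactly what the branches return: `ddout` passes `ddx` through unchanged and `dx` is zero.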
@@ -358,6 +358,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
                                   SoftplusDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
......
@@ -575,6 +575,30 @@ void CeluDoubleGradKernel(const Context& dev_ctx,
  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
}

template <typename T, typename Context>
void SoftplusDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& x,
                              const DenseTensor& dout,
                              const DenseTensor& ddx,
                              float beta,
                              float threshold,
                              DenseTensor* dx,
                              DenseTensor* ddout) {
  if (dx) {
    dx->Resize(x.dims());
    dev_ctx.template Alloc<T>(dx);
  }
  if (ddout) {
    dev_ctx.template Alloc<T>(ddout);
  }

  phi::funcs::SoftplusDoubleGradFunctor<T> functor;
  auto attrs = functor.GetAttrs();
  *(attrs[0].second) = beta;
  *(attrs[1].second) = threshold;
  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
}

template <typename T, typename Context>
void SquareDoubleGradKernel(const Context& dev_ctx,
                            const DenseTensor& x,
......
@@ -296,6 +296,41 @@ class TestCELUDoubleGradCheck(unittest.TestCase):
            self.func(p)

class TestSoftplusDoubleGradCheck(unittest.TestCase):
    def softplus_wrapper(self, x):
        return F.softplus(x[0], beta=1, threshold=20)

    @prog_scope()
    def func(self, place):
        shape = [2, 4, 4, 4]
        eps = 1e-6
        beta = 1
        threshold = 20
        dtype = np.float64
        SEED = 0

        x = paddle.static.data('x', shape, dtype)
        x.persistable = True

        y = F.softplus(x, beta=beta, threshold=threshold)
        np.random.RandomState(SEED)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps
        )
        gradient_checker.double_grad_check_for_dygraph(
            self.softplus_wrapper, [x], y, x_init=x_arr, place=place
        )

    def test_grad(self):
        paddle.enable_static()
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)

class TestSqrtDoubleGradCheck(unittest.TestCase):
    def sqrt_wrapper(self, x):
        return paddle.sqrt(x[0])
......
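Beyond the gradient-checker test above, the new second-order path can also be exercised directly in dygraph mode by calling `paddle.grad` twice. The snippet below is a minimal sketch, not part of this PR; the shape, tolerances, and the `beta = 1` form of the analytic second derivative are illustrative.

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# Illustrative dygraph check of the softplus second-order gradient (not part of this PR).
paddle.disable_static()
x_np = np.random.uniform(-1, 1, [2, 3]).astype("float64")
x = paddle.to_tensor(x_np, stop_gradient=False)

y = F.softplus(x, beta=1.0, threshold=20.0)
# First-order gradient; keep the graph so it can be differentiated again.
(dy_dx,) = paddle.grad(y, x, create_graph=True)
# Second-order gradient; this is the path served by softplus_double_grad.
(d2y_dx2,) = paddle.grad(dy_dx.sum(), x)

# Analytic second derivative for beta = 1: sigmoid(x) * (1 - sigmoid(x)).
sig = 1.0 / (1.0 + np.exp(-x_np))
np.testing.assert_allclose(d2y_dx2.numpy(), sig * (1.0 - sig), rtol=1e-6, atol=1e-8)
```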