diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index b70a752b2eca5bf55ec16dbd3595d9d0ca445679..4a42799764fa8014d06d0d4b6e0c84029bad258c 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1313,6 +1313,17 @@
     func : slogdet_grad
     data_type : out_grad
 
+- backward_op : softplus_double_grad
+  forward : softplus_grad (Tensor x, Tensor grad_out, float beta, float threshold) -> Tensor(grad_x)
+  args : (Tensor x, Tensor grad_out, Tensor grad_x_grad, float beta, float threshold)
+  output : Tensor(x_grad), Tensor(grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [x, x]
+  kernel :
+    func : softplus_double_grad
+  inplace : (grad_x_grad -> grad_out_grad)
+
 - backward_op : softplus_grad
   forward : softplus (Tensor x, float beta, float threshold) -> Tensor(out)
   args : (Tensor x, Tensor out_grad, float beta, float threshold)
@@ -1322,6 +1333,7 @@
     param : [x]
   kernel :
     func : softplus_grad
+  backward : softplus_double_grad
   inplace : (out_grad -> x_grad)
 
 - backward_op : softshrink_grad
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index e0ee3f23d4c47ab2939b39ee202a8812738b8fff..13131bd345bd1a96fa8dea587e7d5958e127b3cb 100644
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1600,7 +1600,7 @@
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
 
 - op : softplus
-  backward : softplus_grad
+  backward : softplus_grad, softplus_double_grad
   inputs :
     x : X
   outputs :
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
index 56cb316640d1923eaf72282b5bb503431f71d051..b65a2304cac47f1bb89abe0558b35222a5c38fb4 100644
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ -257,6 +257,17 @@ void PowTripleGradKernel(const Context& dev_ctx,
                          DenseTensor* out_d_x,
                          DenseTensor* out_d_dout,
                          DenseTensor* out_d_ddx);
+
+template <typename T, typename Context>
+void SoftplusDoubleGradKernel(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& dout,
+                              const DenseTensor& ddx,
+                              float beta,
+                              float threshold,
+                              DenseTensor* dx,
+                              DenseTensor* ddout);
+
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Cos);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Tan);
 DECLARE_ACTIVATION_GRAD_KERNEL_DEPX(Acos);
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 128336d6a5e60ca981ae5185838ca590ae8b2810..1f3e8b4cc7ba3323c00ddc9887f9b71945aa1711 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -291,6 +291,8 @@ PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(sqrt_double_grad,
                                            SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(rsqrt_double_grad,
                                            RsqrtDoubleGradKernel)
+PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(softplus_double_grad,
+                                           SoftplusDoubleGradKernel)
 
 PD_REGISTER_KERNEL(tanh_triple_grad,
                    CPU,
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 35970e2b7df91429b1c0f1e9ce8464c54b0d99cf..4fca3ccc3af4d93622a46417bfc6070974ce51fb 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -693,6 +693,56 @@ struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct SoftplusDoubleGradFunctor : public BaseActivationFunctor<T> {
+  float beta;
+  float threshold;
+  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+  template <typename Device>
+  void operator()(const Device& dev,
+                  const DenseTensor* X,
+                  const DenseTensor* dOut,
+                  const DenseTensor* ddX,
+                  DenseTensor* dX,
+                  DenseTensor* ddOut) const {
+    auto* d = dev.eigen_device();
+    auto x = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(X, "Input", "X", "SoftplusDoubleGrad"));
+    auto x_beta = static_cast<T>(beta) * x;
+    auto ddx = EigenVector<T>::Flatten(
+        GET_DATA_SAFELY(ddX, "Input", "DDX", "SoftplusDoubleGrad"));
+
+    if (dX) {
+      auto dx = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dX, "Output", "DX", "SoftplusDoubleGrad"));
+      auto dout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(dOut, "Output", "DOut", "SoftplusDoubleGrad"));
+      // ddx * dout * beta * exp(x_beta) / (exp(x_beta) + 1) ^ 2, if x_beta
+      // <= threshold
+      // 0, if x_beta > threshold
+      dx.device(*d) =
+          (x_beta > static_cast<T>(threshold))
+              .select(x.constant(static_cast<T>(0)),
+                      ddx * dout * static_cast<T>(beta) * x_beta.exp() /
+                          (x_beta.exp() + static_cast<T>(1))
+                              .pow(static_cast<T>(2)));
+    }
+
+    if (ddOut) {
+      auto ddout = EigenVector<T>::Flatten(
+          GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SoftplusDoubleGrad"));
+      // ddx / (1 + exp(-x_beta)), if x_beta <= threshold
+      // ddx, if x_beta > threshold
+      ddout.device(*d) =
+          (x_beta > static_cast<T>(threshold))
+              .select(ddx, ddx / (static_cast<T>(1) + (-x_beta).exp()));
+    }
+  }
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 // Tangent(x) = tan(x)
 template <typename T>
 struct TanFunctor : public BaseActivationFunctor<T> {
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index 441790aab3ae2167784b1fab6c61ff53b34c4037..fc7bf8b1cc37f56ad4045f64db646d4a7281edc8 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -358,6 +358,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_double_grad,
+                                    SoftplusDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_double_grad, SqrtDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(rsqrt_grad, RsqrtGradKernel)
diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h
index fd6a69f1f2052f880699d6084eaccb8a2cfd8fc7..ffe9ac26c6935407802c5948d529ca69b8f79d15 100644
--- a/paddle/phi/kernels/impl/activation_grad_impl.h
+++ b/paddle/phi/kernels/impl/activation_grad_impl.h
@@ -575,6 +575,30 @@ void CeluDoubleGradKernel(const Context& dev_ctx,
   functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
 }
 
+template <typename T, typename Context>
+void SoftplusDoubleGradKernel(const Context& dev_ctx,
+                              const DenseTensor& x,
+                              const DenseTensor& dout,
+                              const DenseTensor& ddx,
+                              float beta,
+                              float threshold,
+                              DenseTensor* dx,
+                              DenseTensor* ddout) {
+  if (dx) {
+    dx->Resize(x.dims());
+    dev_ctx.template Alloc<T>(dx);
+  }
+  if (ddout) {
+    dev_ctx.template Alloc<T>(ddout);
+  }
+
+  phi::funcs::SoftplusDoubleGradFunctor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = beta;
+  *(attrs[1].second) = threshold;
+  functor(dev_ctx, &x, &dout, &ddx, dx, ddout);
+}
+
 template <typename T, typename Context>
 void SquareDoubleGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
diff --git a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
index 8333da1accfda292da640353f0f005bd29224873..78102130dcb1d8c676fb7c9178db9c0f5c35b562 100644
--- a/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -296,6 +296,41 @@ class TestCELUDoubleGradCheck(unittest.TestCase):
             self.func(p)
 
 
+class TestSoftplusDoubleGradCheck(unittest.TestCase):
+    def softplus_wrapper(self, x):
+        return F.softplus(x[0], beta=1, threshold=20)
+
+    @prog_scope()
+    def func(self, place):
+        shape = [2, 4, 4, 4]
+        eps = 1e-6
+        beta = 1
+        threshold = 20
+        dtype = np.float64
+        SEED = 0
+
+        x = paddle.static.data('x', shape, dtype)
+        x.persistable = True
+
+        y = F.softplus(x, beta=beta, threshold=threshold)
+        np.random.RandomState(SEED)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        gradient_checker.double_grad_check(
+            [x], y, x_init=x_arr, place=place, eps=eps
+        )
+        gradient_checker.double_grad_check_for_dygraph(
+            self.softplus_wrapper, [x], y, x_init=x_arr, place=place
+        )
+
+    def test_grad(self):
+        paddle.enable_static()
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 class TestSqrtDoubleGradCheck(unittest.TestCase):
     def sqrt_wrapper(self, x):
         return paddle.sqrt(x[0])
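For context on the expressions hard-coded in SoftplusDoubleGradFunctor: they follow from differentiating softplus twice on the non-linearized branch (beta * x <= threshold). The derivation sketch below is added here for reference and is not part of the patch; sigma denotes the logistic sigmoid.

```latex
% Softplus forward and its first two derivatives (branch \beta x \le \text{threshold}):
\[
\mathrm{out}(x) = \tfrac{1}{\beta}\log\bigl(1 + e^{\beta x}\bigr),\qquad
\frac{\partial\,\mathrm{out}}{\partial x} = \sigma(\beta x) = \frac{1}{1 + e^{-\beta x}},\qquad
\frac{\partial^{2}\mathrm{out}}{\partial x^{2}} = \frac{\beta\,e^{\beta x}}{\bigl(e^{\beta x} + 1\bigr)^{2}}.
\]
% softplus_grad computes grad_x = dout * sigma(beta x), so the double-grad op,
% given the upstream gradient ddx (= grad_x_grad), returns:
\[
\mathrm{ddout} = \mathrm{ddx}\cdot\sigma(\beta x) = \frac{\mathrm{ddx}}{1 + e^{-\beta x}},\qquad
\mathrm{dx} = \mathrm{ddx}\cdot\mathrm{dout}\cdot\frac{\beta\,e^{\beta x}}{\bigl(e^{\beta x} + 1\bigr)^{2}}.
\]
% On the linearized branch (\beta x > \text{threshold}): ddout = ddx and dx = 0.
```

These two cases are exactly the branches selected by the `.select(...)` calls in the functor.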
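A quick way to sanity-check those closed forms outside of Paddle is to compare them against finite differences of the first derivative. The NumPy sketch below is a hypothetical stand-alone check (names and tolerances are illustrative, not taken from the patch):

```python
# Hypothetical stand-alone check (NumPy only, not part of the patch): verify the
# closed-form double-grad expressions used by SoftplusDoubleGradFunctor against a
# central finite difference of the softplus first derivative.
import numpy as np

beta, threshold = 1.0, 20.0
rng = np.random.RandomState(0)
x = rng.uniform(-1.0, 1.0, 1000)
dout = rng.uniform(-1.0, 1.0, 1000)  # upstream gradient fed to softplus_grad
ddx = rng.uniform(-1.0, 1.0, 1000)   # gradient w.r.t. grad_x (grad_x_grad)


def softplus_dx(x):
    """d softplus / dx: sigmoid(beta * x); identity branch above the threshold."""
    xb = beta * x
    return np.where(xb > threshold, 1.0, 1.0 / (1.0 + np.exp(-xb)))


xb = beta * x
# Formulas mirrored from the functor.
dx_analytic = np.where(
    xb > threshold, 0.0, ddx * dout * beta * np.exp(xb) / (np.exp(xb) + 1.0) ** 2
)
ddout_analytic = np.where(xb > threshold, ddx, ddx / (1.0 + np.exp(-xb)))

# Second derivative via central differences of the first derivative.
eps = 1e-6
d2 = (softplus_dx(x + eps) - softplus_dx(x - eps)) / (2.0 * eps)

np.testing.assert_allclose(dx_analytic, ddx * dout * d2, rtol=1e-5, atol=1e-8)
np.testing.assert_allclose(ddout_analytic, ddx * softplus_dx(x), rtol=1e-10)
print("softplus double-grad formulas match finite differences")
```

The in-tree TestSoftplusDoubleGradCheck added above performs the equivalent verification through gradient_checker.double_grad_check and double_grad_check_for_dygraph on both CPU and CUDA places.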