diff --git a/paddle/phi/common/complex.h b/paddle/phi/common/complex.h
index 130047f850426e2bcae09ba49f5b308a0cc17677..e0ff7f11ac5427959ad719fe37012d02843b2c53 100644
--- a/paddle/phi/common/complex.h
+++ b/paddle/phi/common/complex.h
@@ -476,6 +476,16 @@ HOSTDEVICE inline complex<T> conj(const complex<T>& a) {
 #endif
 }
 
+template <typename T>
+HOSTDEVICE inline complex<T> exp(const complex<T>& a) {
+#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
+    (defined(__CUDA_ARCH__) || defined(__HIPCC__))
+  return complex<T>(thrust::exp(thrust::complex<T>(a)));
+#else
+  return complex<T>(std::exp(std::complex<T>(a)));
+#endif
+}
+
 template <typename T>
 HOSTDEVICE inline complex<T> log(const complex<T>& a) {
 #if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 1216801e0eee02908a237a6b21dbd9830bb6eaf0..68f3fce76a8dcccab6335f689b812eb8eb26e8f8 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -423,7 +423,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(logsigmoid_grad, LogSigmoidGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
+                                                LogSigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 204947572ce2a0d507f5363c0b0f7685f1ad7409..b3f86c7c908db3f1bffe64e15bdb482a78cb3f51 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -228,7 +228,7 @@ PD_REGISTER_KERNEL(
     square, CPU, ALL_LAYOUT, phi::SquareKernel, float, double, int, int64_t) {}
 PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 926920dfa93dddbd3f67fd9adefab14958bb6993..e0ee4ea7d7a2b2f44bf1e6546e450edd912492f1 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -2051,6 +2051,25 @@ struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct LogSigmoidGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  template <typename Device,
+            typename X,
+            typename Out,
+            typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
+    auto temp =
+        (-x).cwiseMax(static_cast<ComplexType<T>>(0));  // temp = max(-x, 0)
+    dx.device(d) =
+        dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()))
+                   .unaryExpr(Conj<T>());
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
   float slope;
@@ -3862,6 +3881,28 @@ struct CudaLogSigmoidGradFunctor : public BaseActivationFunctor<T> {
   static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
 };
 
+template <typename T>
+struct CudaLogSigmoidGradFunctor<ComplexType<T>>
+    : public BaseActivationFunctor<ComplexType<T>> {
+  ComplexType<T> zero = static_cast<ComplexType<T>>(0.0f);
+
+  // dx = dout * exp(-x) / (1 + exp(-x))
+  // For numerical stability:
+  // dx = dout * exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) + exp(-x - max(-x,
+  // 0)))
+  __device__ __forceinline__ ComplexType<T> operator()(
+      const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
+    ComplexType<T> dout = static_cast<ComplexType<T>>(arg_dout);
+    ComplexType<T> x = static_cast<ComplexType<T>>(arg_x);
+    ComplexType<T> temp1 = x > zero ? zero : -x;
+    ComplexType<T> temp2 = exp(-x - temp1);
+    return static_cast<ComplexType<T>>(dout *
+                                       conj(temp2 / (exp(-temp1) + temp2)));
+  }
+
+  static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
+};
+
 template <typename T>
 struct CudaHardSigmoidFunctor : public BaseActivationFunctor<T> {
   T zero = static_cast<T>(0.0f);
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index ea4a88683c0bfe332f475aac792f0890f5928f8e..43460eb10a41153f64bb5eb2b22533a983f88ab0 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -495,7 +495,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_double_grad, SigmoidDoubleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_triple_grad, SigmoidTripleGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel)
-PD_REGISTER_ACTIVATION_GRAD_KERNEL(logsigmoid_grad, LogSigmoidGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad,
+                                                LogSigmoidGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel)
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index d29df5758931574fceb8285dce73f47605133542..aefe04385d7f758aa037733d7d09dff2eb884a08 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -290,7 +290,7 @@ PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
 PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
 PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
-PD_REGISTER_ACTIVATION_KERNEL(logsigmoid, LogSigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardsigmoid, HardSigmoidKernel)
 PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
 PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index aacc0232be2f3f545b6694b7a52c82a16ec6d093..131201b1bc60c4b905662c484c0480b2559813ed 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -790,7 +790,7 @@ def log_sigmoid(x, name=None):
         log\_sigmoid(x) = log \frac{1}{1 + e^{-x}}
 
     Parameters:
-        x (Tensor): The input Tensor with data type float32, float64.
+        x (Tensor): The input Tensor with data type float32, float64, complex64, complex128.
         name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
 
     Returns:
@@ -813,7 +813,10 @@
         return _C_ops.logsigmoid(x)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'log_sigmoid'
+            x,
+            'x',
+            ['float16', 'float32', 'float64', 'complex64', 'complex128'],
+            'log_sigmoid',
         )
         helper = LayerHelper("log_sigmoid", **locals())
         out = helper.create_variable_for_type_inference(x.dtype)
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 144b7fdcaa4e5875d25ad195ad90d92e19369dd9..a0bb12264d61df74424ec82e78f6fbf94fa96cd2 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -464,7 +464,13 @@ class TestLogSigmoid(TestActivation):
         self.init_shape()
 
         np.random.seed(2048)
-        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
+        if self.dtype is np.complex64 or self.dtype is np.complex128:
+            x = (
+                np.random.uniform(-1, 1, self.shape)
+                + 1j * np.random.uniform(-1, 1, self.shape)
+            ).astype(self.dtype)
+        else:
+            x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
         out = np.log(1 / (1 + np.exp(-x)))
         self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
         self.outputs = {'Out': out}
@@ -477,6 +483,16 @@ class TestLogSigmoid(TestActivation):
         self.check_grad(['X'], 'Out', max_relative_error=0.008)
 
 
+class TestLogSigmoidComplex64(TestLogSigmoid):
+    def init_dtype(self):
+        self.dtype = np.complex64
+
+
+class TestLogSigmoidComplex128(TestLogSigmoid):
+    def init_dtype(self):
+        self.dtype = np.complex128
+
+
 class TestLogSigmoid_ZeroDim(TestLogSigmoid):
     def init_shape(self):
         self.shape = []
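
Note: as a quick sanity check of what the new complex kernels compute, the minimal NumPy sketch below reproduces the forward value used by the unit test and the conjugated gradient described in the functor comments. It is an illustration, not Paddle code: it omits the max(-x, 0) rearrangement the kernels use for numerical stability, and the helper names (log_sigmoid_ref, log_sigmoid_grad_ref) are invented for this example.

    import numpy as np

    def log_sigmoid_ref(x):
        # forward: log(1 / (1 + exp(-x))), same reference the test computes
        return np.log(1.0 / (1.0 + np.exp(-x)))

    def log_sigmoid_grad_ref(x, dout):
        # backward: dx = dout * exp(-x) / (1 + exp(-x)); for complex inputs the
        # kernels return the conjugate of this derivative (Conj<T>() / conj()).
        return dout * np.conj(np.exp(-x) / (1.0 + np.exp(-x)))

    x = (
        np.random.uniform(-1, 1, 4) + 1j * np.random.uniform(-1, 1, 4)
    ).astype(np.complex64)
    print(log_sigmoid_ref(x))
    print(log_sigmoid_grad_ref(x, np.ones_like(x)))

The stable form used in the kernels, dx = dout * exp(-x - t) / (exp(-t) + exp(-x - t)) with t = max(-x, 0), is algebraically the same expression with exp(-t) factored out of numerator and denominator, which avoids overflow of exp(-x) for large negative real parts.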