Unverified commit 1d4e938d, authored by Ruibin Cheung, committed by GitHub

[complex] add complex support for silu (#56903)

Parent commit: 5ed7c8a0
......@@ -496,6 +496,16 @@ HOSTDEVICE inline complex<T> log(const complex<T>& a) {
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> exp(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::exp(thrust::complex<T>(a)));
#else
return complex<T>(std::exp(std::complex<T>(a)));
#endif
}
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const complex<T>& a) {
os << "real:" << a.real << " imag:" << a.imag;
......
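The new exp overload dispatches to thrust::exp when building device code (CUDA/HIP) and to std::exp on the host; both evaluate the standard complex exponential

    e^(a+bi) = e^a * (cos(b) + i*sin(b))

which the silu functors below need in order to form e^(-x) for complex x.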
......@@ -301,7 +301,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(stanh_grad, STanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
......
......@@ -195,7 +195,7 @@ PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(stanh, STanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
......
......@@ -1847,6 +1847,25 @@ struct SiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SiluGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
auto temp1 = static_cast<ComplexType<T>>(1) + (-x).exp(); // 1+e^(-x)
auto temp2 = x * (-x).exp(); // x*e^(-x)
dx.device(d) = dout * ((static_cast<ComplexType<T>>(1) / temp1) *
(static_cast<ComplexType<T>>(1) + (temp2 / temp1)))
.unaryExpr(Conj<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
......@@ -3793,6 +3812,23 @@ struct CudaSiluGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSiluGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
// dx = dout * conj((1 + exp(-x) + x * exp(-x)) / (1 + exp(-x))^2)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> arg_dout, const ComplexType<T> arg_x) const {
ComplexType<T> dout = static_cast<ComplexType<T>>(arg_dout);
ComplexType<T> x = static_cast<ComplexType<T>>(arg_x);
ComplexType<T> temp = one / (one + exp(-x));
return dout * conj(temp * (one + x * (one - temp)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSoftsignFunctor : public BaseActivationFunctor<T> {
T one = static_cast<T>(1.0f);
......
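The CPU and CUDA complex specializations above compute the same quantity and return its conjugate (via .unaryExpr(Conj<T>()) and conj() respectively), the usual convention for backpropagating through a holomorphic function. For reference, with sigmoid(x) = 1 / (1 + e^(-x)):

    silu(x)  = x / (1 + e^(-x)) = x * sigmoid(x)
    silu'(x) = (1 + e^(-x) + x*e^(-x)) / (1 + e^(-x))^2
             = sigmoid(x) * (1 + x * (1 - sigmoid(x)))

so both functors evaluate dx = dout * conj(sigmoid(x) * (1 + x * (1 - sigmoid(x)))).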
......@@ -403,7 +403,7 @@ PD_REGISTER_KERNEL(exp_grad,
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softshrink_grad, SoftShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_shrink_grad, TanhShrinkGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(silu_grad, SiluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_grad, EluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(logit_grad, LogitCUDAGradKernel)
......
......@@ -287,7 +287,7 @@ PD_REGISTER_ACTIVATION_KERNEL(hard_shrink, HardShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(softshrink, SoftShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh_shrink, TanhShrinkKernel)
PD_REGISTER_ACTIVATION_KERNEL(elu, EluKernel)
PD_REGISTER_ACTIVATION_KERNEL(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(silu, SiluKernel)
PD_REGISTER_ACTIVATION_KERNEL(softsign, SoftsignKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(logsigmoid, LogSigmoidKernel)
......
......@@ -1034,7 +1034,7 @@ def silu(x, name=None):
Where :math:`x` is the input Tensor.
Parameters:
x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64.
x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64, complex64, complex128.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
......@@ -1057,7 +1057,17 @@ def silu(x, name=None):
return _C_ops.silu(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'silu'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'silu',
)
helper = LayerHelper("silu", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......
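A minimal usage sketch of what the dtype change enables, in Python (illustrative only; it assumes a Paddle build that includes this commit, and the tolerances are arbitrary):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    # Complex input; silu now accepts complex64/complex128 in addition to the float dtypes.
    x_np = (np.random.uniform(-1, 1, [3, 4])
            + 1j * np.random.uniform(-1, 1, [3, 4])).astype(np.complex64)
    x = paddle.to_tensor(x_np)

    y = F.silu(x)                        # computed as x / (1 + exp(-x))
    ref = x_np / (np.exp(-x_np) + 1)     # NumPy reference, same formula as the test below
    np.testing.assert_allclose(y.numpy(), ref, rtol=1e-5, atol=1e-5)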
......@@ -384,6 +384,11 @@ class TestSilu(TestActivation):
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = x / (np.exp(-x) + 1)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
......@@ -397,7 +402,11 @@ class TestSilu(TestActivation):
pass
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
# TODO(BeingGod): set `check_prim=True` when `fill_constant` supports `complex` dtype
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
else:
self.check_grad(['X'], 'Out', check_prim=True)
class TestSilu_ZeroDim(TestSilu):
......@@ -405,6 +414,16 @@ class TestSilu_ZeroDim(TestSilu):
self.shape = []
class TestSilu_Complex64(TestSilu):
def init_dtype(self):
self.dtype = np.complex64
class TestSilu_Complex128(TestSilu):
def init_dtype(self):
self.dtype = np.complex128
class TestSiluAPI(unittest.TestCase):
# test paddle.nn.Silu, paddle.nn.functional.silu
def setUp(self):
......
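For intuition about what the new complex test cases exercise, here is a standalone NumPy sketch of the forward reference used in the test above together with the conjugated analytic gradient that the functors implement (illustrative only; the shape is arbitrary, and the real test relies on OpTest.check_grad's numeric differentiation rather than this closed form):

    import numpy as np

    rng = np.random.default_rng(1024)
    shape = (11, 17)  # arbitrary
    x = (rng.uniform(-1, 1, shape) + 1j * rng.uniform(-1, 1, shape)).astype(np.complex128)

    # Forward reference, same formula as the test above: silu(x) = x / (1 + exp(-x))
    out = x / (np.exp(-x) + 1)

    # Analytic gradient as implemented in SiluGradFunctor<ComplexType<T>>:
    # dx = dout * conj(sigmoid(x) * (1 + x * (1 - sigmoid(x))))
    dout = np.ones_like(x)
    sig = 1.0 / (1.0 + np.exp(-x))
    dx = dout * np.conj(sig * (1 + x * (1 - sig)))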