Unverified commit 3bedec8a, authored by Scotty, committed by GitHub

【Complex op】add complex support for sin, cos, tan, tanh (#55380)

* add complex dtype for tanh

* add test case

* support complex for sin, cos and tan

* support gpu

* fix error in cpu

* fix gpu error

* set check_prim to False only for complex type
Parent commit: e2e0d296
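As a quick sanity check of the new dtype coverage, a minimal dygraph sketch (assumes a Paddle build that contains this commit; the tensor values are arbitrary):

import paddle

# Complex inputs now dispatch to the complex kernels registered below.
x = paddle.to_tensor([0.5 + 0.5j, -0.3 + 1.2j], dtype='complex64')
print(paddle.sin(x))
print(paddle.cos(x))
print(paddle.tan(x))
print(paddle.tanh(x))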
......@@ -422,6 +422,36 @@ HOSTDEVICE inline complex<T> sqrt(const complex<T>& a) {
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> sin(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::sin(thrust::complex<T>(a)));
#else
return complex<T>(std::sin(std::complex<T>(a)));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> cos(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::cos(thrust::complex<T>(a)));
#else
return complex<T>(std::cos(std::complex<T>(a)));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> tan(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::tan(thrust::complex<T>(a)));
#else
return complex<T>(std::tan(std::complex<T>(a)));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
......@@ -432,6 +462,16 @@ HOSTDEVICE inline complex<T> tanh(const complex<T>& a) {
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> conj(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
(defined(__CUDA_ARCH__) || defined(__HIPCC__))
return complex<T>(thrust::conj(thrust::complex<T>(a)));
#else
return complex<T>(std::conj(std::complex<T>(a)));
#endif
}
template <typename T>
HOSTDEVICE inline complex<T> log(const complex<T>& a) {
#if defined(PADDLE_WITH_CUDA_OR_HIP_COMPLEX) && \
......
......@@ -255,13 +255,34 @@ PD_REGISTER_KERNEL(
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}
#define PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, CPU, ALL_LAYOUT, phi::func, float, double, phi::dtype::float16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tan_grad, TanGradKernel)
#define PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tan_grad, TanGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acos_grad, AcosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asin_grad, AsinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atan_grad, AtanGradKernel)
......@@ -270,7 +291,7 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(cosh_grad, CoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(thresholded_relu_grad,
......@@ -290,8 +311,8 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(relu_double_grad,
ReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(tanh_double_grad,
TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(tanh_double_grad,
TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(leaky_relu_double_grad,
LeakyReluDoubleGradKernel)
PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(elu_double_grad, EluDoubleGradKernel)
......@@ -308,7 +329,9 @@ PD_REGISTER_KERNEL(tanh_triple_grad,
phi::TanhTripleGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(exp_grad,
CPU,
......@@ -355,7 +378,9 @@ PD_REGISTER_KERNEL(sin_double_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(sin_triple_grad,
CPU,
......@@ -365,7 +390,9 @@ PD_REGISTER_KERNEL(sin_triple_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(cos_double_grad,
CPU,
......@@ -375,7 +402,9 @@ PD_REGISTER_KERNEL(cos_double_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(cos_triple_grad,
CPU,
......@@ -385,7 +414,9 @@ PD_REGISTER_KERNEL(cos_triple_grad,
double,
phi::dtype::float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
......
......@@ -166,9 +166,19 @@ PD_REGISTER_KERNEL(relu, CPU, ALL_LAYOUT, phi::ReluKernel, float, double) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, CPU, ALL_LAYOUT, phi::func, float, double) {}
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel)
#define PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tan, TanKernel)
PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel)
PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel)
PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel)
......@@ -177,7 +187,7 @@ PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
......
......@@ -96,6 +96,17 @@ struct Cosine<dtype::bfloat16> {
}
};
template <typename T>
using ComplexType = phi::dtype::complex<T>;
// ComplexType<T> is phi::dtype::complex<float> or phi::dtype::complex<double>
template <typename T>
struct Conj {
HOSTDEVICE ComplexType<T> operator()(const ComplexType<T>& val) const {
return ComplexType<T>(val.real, -val.imag);
}
};
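Every complex gradient specialization below (and the CUDA variants later in this diff) applies the same rule: multiply the upstream gradient by the complex conjugate of the elementwise derivative, which is what the Conj / conj calls implement (often described as the conjugate Wirtinger convention). A sketch of the formulas the functors encode:

\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y}\cdot\overline{f'(x)},\qquad
f'(x)=\begin{cases}\cos x & (\sin)\\ -\sin x & (\cos)\\ 1/\cos^{2}x & (\tan)\\ 1-\tanh^{2}x & (\tanh, \text{ with } \tanh x = \texttt{out})\end{cases}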
// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
......@@ -111,6 +122,21 @@ struct SinGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct SinGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) =
dout * x.unaryExpr(Cosine<ComplexType<T>>()).unaryExpr(Conj<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// sine(x) = sin(x)
template <typename T>
struct SinFunctor : public BaseActivationFunctor<T> {
......@@ -320,6 +346,22 @@ struct CosGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CosGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
dx.device(d) =
-dout * x.unaryExpr(Sine<ComplexType<T>>()).unaryExpr(Conj<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// cos''(x) = -cos(x)
template <typename T>
struct CosDoubleGradFunctor : public BaseActivationFunctor<T> {
......@@ -584,6 +626,27 @@ struct TanGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct TanGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
// auto dx_ =
// static_cast<ComplexType<T>>(x.unaryExpr(Cosine<T>()).square());
// ComplexType<T> dx_conj_(dx_.real, -dx_.imag);
// dx.device(d) = dout / dx_conj_;
dx.device(d) =
dout /
x.unaryExpr(Cosine<ComplexType<T>>()).square().unaryExpr(Conj<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
......@@ -1217,6 +1280,28 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct TanhGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x UNUSED, Out out, dOut dout, dX dx) const {
// auto dx_ = static_cast<ComplexType<T>>(1) - out * out;
// ComplexType<T> dx_conj_(dx_.real, -dx_.imag);
// dx.device(d) = dout * dx_conj_;
dx.device(d) =
dout *
(static_cast<ComplexType<T>>(1) - out * out).unaryExpr(Conj<T>());
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
......@@ -2675,6 +2760,18 @@ struct CudaCosGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaCosGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = -dout * conj(sin(x))
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(-dout * conj(sin(x)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaExpFunctor : public BaseActivationFunctor<T> {
// exp(x) = expf(x)
......@@ -2847,6 +2944,18 @@ struct CudaSinGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaSinGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = dout * conj(cos(x))
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(dout * conj(cos(x)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaTanFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -2873,6 +2982,18 @@ struct CudaTanGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaTanGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
// dx = dout / conj(cos(x)^2)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> x) const {
return static_cast<ComplexType<T>>(dout / conj(cos(x) * cos(x)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};
template <typename T>
struct CudaAsinFunctor : public BaseActivationFunctor<T> {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
......@@ -3250,6 +3371,22 @@ struct CudaTanhGradFunctor : public BaseActivationFunctor<T> {
}
};
template <typename T>
struct CudaTanhGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> one = static_cast<ComplexType<T>>(1.0f);
// dx = dout * conj(1 - out^2)
__device__ __forceinline__ ComplexType<T> operator()(
const ComplexType<T> dout, const ComplexType<T> out) const {
return dout * conj(one - out * out);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename T>
struct CudaHardTanhFunctor : public BaseActivationFunctor<T> {
float t_min;
......
......@@ -343,9 +343,21 @@ PD_REGISTER_KERNEL(relu_double_grad,
phi::dtype::float16, \
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tan_grad, TanGradKernel)
#define PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
GPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sin_grad, SinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(cos_grad, CosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tan_grad, TanGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acos_grad, AcosGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asin_grad, AsinGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atan_grad, AtanGradKernel)
......@@ -354,9 +366,11 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL(cosh_grad, CoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(asinh_grad, AsinhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(acosh_grad, AcoshGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(atanh_grad, AtanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_double_grad, TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(tanh_triple_grad, TanhTripleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_grad, TanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_double_grad,
TanhDoubleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(tanh_triple_grad,
TanhTripleGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardtanh_grad, HardTanhGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_double_grad,
......@@ -433,7 +447,9 @@ PD_REGISTER_KERNEL(sin_double_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(sin_triple_grad,
GPU,
......@@ -444,7 +460,9 @@ PD_REGISTER_KERNEL(sin_triple_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(cos_double_grad,
GPU,
......@@ -455,7 +473,9 @@ PD_REGISTER_KERNEL(cos_double_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(cos_triple_grad,
GPU,
......@@ -466,7 +486,9 @@ PD_REGISTER_KERNEL(cos_triple_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_GRAD_KERNEL(softsign_grad, SoftsignGradKernel)
PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
......
......@@ -217,9 +217,21 @@ PD_REGISTER_KERNEL(relu,
phi::dtype::float16, \
phi::dtype::bfloat16) {}
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL(tan, TanKernel)
#define PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(name, func) \
PD_REGISTER_KERNEL(name, \
GPU, \
ALL_LAYOUT, \
phi::func, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16, \
phi::dtype::complex<float>, \
phi::dtype::complex<double>) {}
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(cos, CosKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tan, TanKernel)
PD_REGISTER_ACTIVATION_KERNEL(acos, AcosKernel)
PD_REGISTER_ACTIVATION_KERNEL(asin, AsinKernel)
PD_REGISTER_ACTIVATION_KERNEL(atan, AtanKernel)
......@@ -228,7 +240,7 @@ PD_REGISTER_ACTIVATION_KERNEL(cosh, CoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(asinh, AsinhKernel)
PD_REGISTER_ACTIVATION_KERNEL(acosh, AcoshKernel)
PD_REGISTER_ACTIVATION_KERNEL(atanh, AtanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(tanh, TanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardtanh, HardTanhKernel)
PD_REGISTER_ACTIVATION_KERNEL(thresholded_relu, ThresholdedReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
......
......@@ -508,7 +508,10 @@ def cos(x, name=None):
return _C_ops.cos(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'cos'
x,
'x',
['float16', 'float32', 'float64', 'complex64', 'complex128'],
'cos',
)
helper = LayerHelper('cos', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -875,7 +878,17 @@ def sin(x, name=None):
return _C_ops.sin(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'sin'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'sin',
)
helper = LayerHelper('sin', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......@@ -1039,7 +1052,17 @@ def tan(x, name=None):
return _C_ops.tan(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'tan'
x,
'x',
[
'float16',
'uint16',
'float32',
'float64',
'complex64',
'complex128',
],
'tan',
)
helper = LayerHelper('tan', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
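With the widened check_variable_and_dtype lists above, a complex placeholder now passes validation when a graph is built in static mode. A minimal sketch, assuming paddle.static.data accepts the 'complex64' dtype string (the placeholder name is arbitrary):

import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    # Before this change, paddle.tan rejected complex64 here with a dtype check error.
    x = paddle.static.data(name='x', shape=[4], dtype='complex64')
    y = paddle.tan(x)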
......@@ -546,6 +546,11 @@ class TestTanh(TestActivation, TestParameter):
np.random.seed(1024)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = np.tanh(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
......@@ -554,7 +559,11 @@ class TestTanh(TestActivation, TestParameter):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True)
# TODO(ScottWong98): remove `check_prim=False` once `fill_any_like` supports `complex` dtype
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
else:
self.check_grad(['X'], 'Out', check_prim=True)
def init_dtype(self):
# TODO If dtype is float64, the output (Out) has diff at CPUPlace
......@@ -566,6 +575,16 @@ class TestTanh(TestActivation, TestParameter):
pass
class TestTanh_Complex64(TestTanh):
def init_dtype(self):
self.dtype = np.complex64
class TestTanh_Complex128(TestTanh):
def init_dtype(self):
self.dtype = np.complex128
class TestTanh_ZeroDim(TestTanh):
def init_shape(self):
self.shape = []
......@@ -1566,6 +1585,11 @@ class TestCos(TestActivation):
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = np.cos(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
......@@ -1577,12 +1601,29 @@ class TestCos(TestActivation):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True)
# TODO(ScottWong98): remove `check_prim=False` once `fill_any_like` supports `complex` dtype
if self.dtype == np.complex64 or self.dtype == np.complex128:
# Complex64 [GPU]: AssertionError: 0.0057843705 not less than or equal to 0.005
self.check_grad(
['X'], 'Out', check_prim=False, max_relative_error=0.006
)
else:
self.check_grad(['X'], 'Out', check_prim=True)
def if_enable_cinn(self):
pass
class TestCos_Complex64(TestCos):
def init_dtype(self):
self.dtype = np.complex64
class TestCos_Complex128(TestCos):
def init_dtype(self):
self.dtype = np.complex128
class TestCos_ZeroDim(TestCos):
def init_shape(self):
self.shape = []
......@@ -1596,8 +1637,12 @@ class TestTan(TestActivation):
self.init_dtype()
self.init_shape()
self.dtype = 'float32'
self.x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.x_np = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
self.place = (
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
......@@ -1619,6 +1664,21 @@ class TestTan(TestActivation):
self.check_grad(['X'], 'Out')
class TestTan_float32(TestTan):
def init_dtype(self):
self.dtype = "float32"
class TestTan_Complex64(TestTan):
def init_dtype(self):
self.dtype = np.complex64
class TestTan_Complex128(TestTan):
def init_dtype(self):
self.dtype = np.complex128
class TestTan_ZeroDim(TestTan):
def init_shape(self):
self.shape = []
......@@ -1707,6 +1767,11 @@ class TestSin(TestActivation, TestParameter):
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
if self.dtype == np.complex64 or self.dtype == np.complex128:
x = (
np.random.uniform(-1, 1, self.shape)
+ 1j * np.random.uniform(-1, 1, self.shape)
).astype(self.dtype)
out = np.sin(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
......@@ -1718,12 +1783,26 @@ class TestSin(TestActivation, TestParameter):
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_prim=True)
# TODO(ScottWong98): remove `check_prim=False` once `fill_any_like` supports `complex` dtype
if self.dtype == np.complex64 or self.dtype == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
else:
self.check_grad(['X'], 'Out', check_prim=True)
def if_enable_cinn(self):
pass
class TestSin_Complex64(TestSin):
def init_dtype(self):
self.dtype = np.complex64
class TestSin_Complex128(TestSin):
def init_dtype(self):
self.dtype = np.complex128
class TestSin_ZeroDim(TestSin):
def init_shape(self):
self.shape = []
......
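For reference, the analytic gradient the complex TanhGradFunctor computes can be reproduced in NumPy with the same conjugate rule; a small sketch with an all-ones upstream gradient and arbitrary values (not literally what OpTest.check_grad runs, just the same formula):

import numpy as np

x = np.array([0.5 + 0.5j, -0.3 + 1.2j], dtype=np.complex64)
out = np.tanh(x)
dout = np.ones_like(out)            # upstream gradient
dx = dout * np.conj(1 - out * out)  # mirrors dout * conj(1 - out^2)
print(dx)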