Unverified commit f8cba26d, authored by Ryan, committed by GitHub

【Complex op】add complex support for numel (#56412)

* add complex numel

* change test && add doc
Parent 0668650f
@@ -30,7 +30,9 @@ PD_REGISTER_KERNEL(numel,
                    phi::dtype::bfloat16,
                    float,
                    double,
-                   bool) {
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }
@@ -47,7 +49,9 @@ PD_REGISTER_KERNEL(numel,
                    phi::dtype::bfloat16,
                    float,
                    double,
-                   bool) {
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }
 #endif
@@ -29,6 +29,8 @@ PD_REGISTER_KERNEL(numel,
                    phi::dtype::bfloat16,
                    float,
                    double,
-                   bool) {
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }
@@ -27,6 +27,8 @@ PD_REGISTER_KERNEL(numel,
                    int64_t,
                    phi::dtype::float16,
                    float,
-                   bool) {
+                   bool,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {
   kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
 }
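
The four registration hunks above add phi::dtype::complex<float> and phi::dtype::complex<double> to the numel (size) kernel's dtype list on each backend, while the output dtype stays pinned to INT64. A minimal sketch of the resulting Python-side behavior, assuming a Paddle build that includes this change:

    import numpy as np
    import paddle

    # With the registrations above, complex inputs dispatch to the numel kernel;
    # the result is still a 0-D int64 tensor, per SetDataType(phi::DataType::INT64).
    x = paddle.to_tensor(np.ones((3, 4), dtype=np.complex64))
    n = paddle.numel(x)
    print(n.dtype)   # paddle.int64
    print(n.item())  # 12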
@@ -227,7 +227,7 @@ def numel(x, name=None):
     Returns the number of elements for a tensor, which is a 0-D int64 Tensor with shape [].

     Args:
-        x (Tensor): The input Tensor, it's data type can be bool, float16, float32, float64, int32, int64.
+        x (Tensor): The input Tensor, it's data type can be bool, float16, float32, float64, int32, int64, complex64, complex128.
         name (str, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.
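
A short usage sketch matching the updated docstring (illustrative only; the shapes and values are arbitrary and assume the same patched build):

    import numpy as np
    import paddle

    # complex64 and complex128 now appear in the accepted dtype list for paddle.numel.
    for np_dtype in (np.complex64, np.complex128):
        data = (np.random.random((2, 3)) + 1j * np.random.random((2, 3))).astype(np_dtype)
        x = paddle.to_tensor(data)
        print(paddle.numel(x).item())  # 6 for both dtypes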
@@ -71,6 +71,54 @@ class TestNumelOp2FP16(TestNumelOp):
         self.shape = (0,)


+class TestNumelOpComplex(TestNumelOp):
+    def setUp(self):
+        self.op_type = "size"
+        self.python_api = paddle.numel
+        self.init()
+        x = np.random.random(self.shape).astype(
+            self.dtype
+        ) + 1j * np.random.random(self.shape).astype(self.dtype)
+        self.inputs = {
+            'Input': x,
+        }
+        self.outputs = {'Out': np.array(np.size(x))}
+
+    def init(self):
+        self.dtype = np.complex64
+        self.shape = (6, 56, 8, 55)
+
+
+class Test1NumelOpComplex64(TestNumelOpComplex):
+    def init(self):
+        self.dtype = np.complex64
+        self.shape = (11, 66)
+
+
+class Test2NumelOpComplex64(TestNumelOpComplex):
+    def init(self):
+        self.dtype = np.complex64
+        self.shape = (0,)
+
+
+class Test0NumelOpComplex128(TestNumelOpComplex):
+    def init(self):
+        self.dtype = np.complex128
+        self.shape = (6, 56, 8, 55)
+
+
+class Test1NumelOpComplex128(TestNumelOpComplex):
+    def init(self):
+        self.dtype = np.complex128
+        self.shape = (11, 66)
+
+
+class Test2NumelOpComple128(TestNumelOpComplex):
+    def init(self):
+        self.dtype = np.complex128
+        self.shape = (0,)
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
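
The new tests cover both complex dtypes across several shapes, including the empty shape (0,). A rough standalone sanity check of the same expectation, outside the OpTest harness and illustrative only:

    import numpy as np
    import paddle

    # Mirrors the new tests: for complex inputs, paddle.numel should match np.size,
    # including the degenerate shape (0,).
    for shape in [(6, 56, 8, 55), (11, 66), (0,)]:
        x = np.random.random(shape).astype(np.complex128) + 1j * np.random.random(
            shape
        ).astype(np.complex128)
        assert paddle.numel(paddle.to_tensor(x)).item() == np.size(x)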