Unverified · Commit e522ceb7 authored by GGBond8488, committed by GitHub

add complex support for optest (#53356)

* add complex support for optest

* add complex grad test

* append one

* move some debug info

* move some debug info

* move some debug info

* move some debug info

* add more complex test

* Fix naming ambiguity

* Revert "add more complex test"

This reverts commit dbcb0516b8e53ba42e2d6089878a39b395345969.

* change backward gradient, add TODO
Parent 70180df5
......@@ -35,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/pybind/complex.h"
#include "paddle/phi/kernels/funcs/strided_memcpy.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
......
......@@ -47,4 +47,6 @@ PD_REGISTER_KERNEL(mean_grad,
phi::ReduceMeanGradKernel,
bool,
float,
double) {}
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -36,5 +36,12 @@ void MeanRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
mean_raw, CPU, ALL_LAYOUT, phi::MeanRawKernel, float, double, bool) {}
PD_REGISTER_KERNEL(mean_raw,
CPU,
ALL_LAYOUT,
phi::MeanRawKernel,
float,
double,
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -67,4 +67,6 @@ PD_REGISTER_KERNEL(mean_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -48,5 +48,7 @@ PD_REGISTER_KERNEL(mean_raw,
phi::dtype::bfloat16,
float16,
int,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
......@@ -31,8 +31,15 @@ void MeanKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
mean, CPU, ALL_LAYOUT, phi::MeanKernel, float, double, bool) {}
PD_REGISTER_KERNEL(mean,
CPU,
ALL_LAYOUT,
phi::MeanKernel,
float,
double,
bool,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(mean,
......@@ -45,7 +52,9 @@ PD_REGISTER_KERNEL(mean,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
......
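The hunks above register `phi::dtype::complex<float>` and `phi::dtype::complex<double>` for the `mean`, `mean_raw`, and `mean_grad` kernels on CPU and GPU, so the reduce-mean path can run on complex tensors end to end. A minimal usage sketch, assuming a build that includes these registrations (the values are made up for illustration, not taken from the commit):

```python
import numpy as np
import paddle

# With the complex registrations above, paddle.mean should accept
# complex64 / complex128 inputs on both CPU and GPU.
x = paddle.to_tensor(np.array([1 + 2j, 3 - 4j], dtype=np.complex64))
print(paddle.mean(x))  # expected value: (2 - 1j)
# The mean_grad registration provides the matching backward kernel,
# which OpTest exercises through check_grad further down.
```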
......@@ -200,6 +200,10 @@ def get_numeric_gradient(
return tensor._get_float_element(i)
elif tensor_to_check_dtype == np.float64:
return tensor._get_double_element(i)
elif tensor_to_check_dtype == np.complex64:
return tensor._get_complex64_element(i)
elif tensor_to_check_dtype == np.complex128:
return tensor._get_complex128_element(i)
else:
raise TypeError(
"Unsupported test data type %s." % tensor_to_check_dtype
......@@ -224,6 +228,10 @@ def get_numeric_gradient(
tensor._set_float_element(i, e)
elif tensor_to_check_dtype == np.float64:
tensor._set_double_element(i, e)
elif tensor_to_check_dtype == np.complex64:
tensor._set_complex64_element(i, e)
elif tensor_to_check_dtype == np.complex128:
tensor._set_complex128_element(i, e)
else:
raise TypeError(
"Unsupported test data type %s." % tensor_to_check_dtype
......@@ -242,6 +250,13 @@ def get_numeric_gradient(
__set_elem__(tensor_to_check, i, x_pos)
y_pos = get_output()
if tensor_to_check_dtype in [np.complex64, np.complex128]:
if in_place:
set_input(scope, op, inputs, place)
x_pos_j = origin + 1j * delta
__set_elem__(tensor_to_check, i, x_pos_j)
y_pos_j = get_output()
if in_place:
set_input(scope, op, inputs, place)
......@@ -249,8 +264,44 @@ def get_numeric_gradient(
__set_elem__(tensor_to_check, i, x_neg)
y_neg = get_output()
if tensor_to_check_dtype in [np.complex64, np.complex128]:
if in_place:
set_input(scope, op, inputs, place)
x_neg_j = origin - 1j * delta
__set_elem__(tensor_to_check, i, x_neg_j)
y_neg_j = get_output()
__set_elem__(tensor_to_check, i, origin)
if tensor_to_check_dtype in [np.complex64, np.complex128]:
# Always assume a real-valued output adjoint: this function takes no
# upstream gradient dl/di, so the imaginary part of the adjoint is zero.
# TODO: this is a workaround to stay consistent with the existing OpTest;
# it should be extended to accept user-supplied output gradients.
f_adjoint = np.array(1 + 0j)
df_over_dr = (y_pos - y_neg) / delta / 2
df_over_di = (y_pos_j - y_neg_j) / delta / 2
dl_over_du, dl_over_dv = f_adjoint.real, f_adjoint.imag
du_over_dr, dv_over_dr = df_over_dr.real, df_over_dr.imag
du_over_di, dv_over_di = df_over_di.real, df_over_di.imag
dl_over_dr = np.sum(
dl_over_du * du_over_dr + dl_over_dv * dv_over_dr
)
dl_over_di = np.sum(
dl_over_du * du_over_di + dl_over_dv * dv_over_di
)
gradient_flat[i] = dl_over_dr + 1j * dl_over_di
else:
df_over_dr = y_pos - y_neg
gradient_flat[i] = df_over_dr / delta / 2
__set_elem__(tensor_to_check, i, origin)
gradient_flat[i] = (y_pos - y_neg) / delta / 2
return gradient_flat.reshape(tensor_to_check.shape())
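The new branch implements a central-difference gradient for complex inputs: each element is perturbed separately along the real and the imaginary axis, and the two directional derivatives are combined with a fixed real adjoint (`1 + 0j`) into one complex gradient entry. A self-contained NumPy sketch of the same scheme, with a made-up target function and delta, for illustration only:

```python
import numpy as np

def numeric_complex_grad(f, x, i, delta=1e-5):
    """Central-difference gradient of f w.r.t. x.flat[i], assuming a real
    all-ones upstream gradient, mirroring the OpTest code above."""
    x = x.copy().ravel()
    origin = x[i]

    def eval_at(value):
        x[i] = value
        return f(x)

    y_pos, y_neg = eval_at(origin + delta), eval_at(origin - delta)
    y_pos_j, y_neg_j = eval_at(origin + 1j * delta), eval_at(origin - 1j * delta)
    x[i] = origin

    f_adjoint = np.array(1 + 0j)              # dl/d(out), taken to be real
    df_dr = (y_pos - y_neg) / delta / 2       # derivative along the real axis
    df_di = (y_pos_j - y_neg_j) / delta / 2   # derivative along the imaginary axis
    dl_dr = np.sum(f_adjoint.real * df_dr.real + f_adjoint.imag * df_dr.imag)
    dl_di = np.sum(f_adjoint.real * df_di.real + f_adjoint.imag * df_di.imag)
    return dl_dr + 1j * dl_di

x = np.array([3 + 4j, 1 + 2j], dtype=np.complex128)
y = np.array([3 + 4j, 5 + 6j], dtype=np.complex128)
# For f(v) = sum(v * y), the entry for x[0] comes out as conj(y[0]) = 3 - 4j.
print(numeric_complex_grad(lambda v: np.sum(v * y), x, 0))
```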
......@@ -375,6 +426,13 @@ class OpTest(unittest.TestCase):
def is_custom_device_op_test():
return hasattr(cls, "use_custom_device") and cls.use_custom_device
def is_complex_test():
return (
hasattr(cls, "test_complex")
and cls.test_complex
or (cls.dtype in [np.complex64, np.complex128])
)
if not hasattr(cls, "op_type"):
raise AssertionError(
"This test do not have op_type in class attrs, "
......@@ -382,8 +440,10 @@ class OpTest(unittest.TestCase):
)
# case in NO_FP64_CHECK_GRAD_CASES and op in NO_FP64_CHECK_GRAD_OP_LIST should be fixed
if not hasattr(cls, "no_need_check_grad") and not is_empty_grad_op(
cls.op_type
if (
not hasattr(cls, "no_need_check_grad")
and not is_empty_grad_op(cls.op_type)
and not is_complex_test()
):
if cls.dtype is None or (
cls.dtype == np.float16
......@@ -2496,7 +2556,6 @@ class OpTest(unittest.TestCase):
max_relative_error = (
0.001 if max_relative_error < 0.001 else max_relative_error
)
self._assert_is_close(
numeric_grads,
analytic_grads,
......
......@@ -93,6 +93,55 @@ class ElementwiseMulOp(OpTest):
pass
class TestComplexElementwiseMulOpWithCheckGrad(ElementwiseMulOp):
def setUp(self):
self.op_type = "elementwise_mul"
self.python_api = paddle.multiply
self.public_python_api = paddle.multiply
self.dtype = np.complex128
self.axis = -1
self.init_dtype()
self.init_input_output()
self.init_kernel_type()
self.init_axis()
self.if_enable_cinn()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}
self.attrs = {'axis': self.axis}
def init_input_output(self):
self.x = np.array([3 + 4j, 1 + 2j]).astype(self.dtype)
self.y = np.array([3 + 4j, 5 + 6j]).astype(self.dtype)
self.out = np.multiply(self.x, self.y)
def if_enable_cinn(self):
self.enable_cinn = False
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
)
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
)
class TestElementwiseMulOp_ZeroDim1(ElementwiseMulOp):
def init_input_output(self):
self.x = np.random.uniform(0.1, 1, []).astype(self.dtype)
......
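For reference on what `TestComplexElementwiseMulOpWithCheckGrad` compares: `Out = X * Y` elementwise, and under the real all-ones adjoint used by the numeric scheme earlier, the numeric gradient works out to the conjugate of the other operand. A small NumPy sketch with the same inputs as the test (illustration only, not part of the commit):

```python
import numpy as np

x = np.array([3 + 4j, 1 + 2j], dtype=np.complex128)
y = np.array([3 + 4j, 5 + 6j], dtype=np.complex128)

out = x * y                    # forward: elementwise complex multiply
expected_grad_x = np.conj(y)   # per-element gradient of x under a real adjoint of 1
expected_grad_y = np.conj(x)
print(out, expected_grad_x, expected_grad_y)
```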
......@@ -57,6 +57,20 @@ class TestSumOp(OpTest):
self.check_grad(['X'], 'Out', check_prim=True)
class TestComplexSumOP(TestSumOp):
def init_dtype(self):
self.dtype = np.complex128
def init_input(self):
self.x = np.random.random((3, 4)).astype(self.dtype)
def init_attrs(self):
self.attrs = {'dim': [0]}
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=False)
class TestSumOp_ZeroDim(TestSumOp):
def init_attrs(self):
self.attrs = {'dim': [], 'reduce_all': True}
......
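`TestComplexSumOP` runs the same machinery on a reduce-sum over `dim=[0]`. Because the forward map is linear with real coefficients, the real and imaginary perturbations give `df/dr = 1` and `df/di = 1j` for every contributing element, so under the real adjoint the expected gradient is `1 + 0j` everywhere. A tiny illustrative check (not from the commit; input shape matches the test):

```python
import numpy as np

x = (np.random.random((3, 4)) + 0j).astype(np.complex128)
out = x.sum(axis=0)              # what reduce_sum with dim=[0] computes
expected_grad = np.ones_like(x)  # 1 + 0j for every input element
print(out.shape, expected_grad[0, 0])
```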