diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 85092c77cc214c05d3a662fc0390d9d9013e99a6..0e522533c85126ed03b5eef81074a0899b1fe109 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -1080,8 +1080,8 @@ def dot(x, y, name=None):
     is the batch dimension, which means that the vectors of multiple batches are dotted.
 
     Parameters:
-        x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``
-        y(Tensor): 1-D or 2-D ``Tensor``. Its dtype soulde be ``float32``, ``float64``, ``int32``, ``int64``
+        x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
+        y(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
         name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`
 
     Returns:
@@ -1117,13 +1117,31 @@ def dot(x, y, name=None):
         check_variable_and_dtype(
             x,
             'x',
-            ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                'complex64',
+                'complex128',
+            ],
             op_type,
         )
         check_variable_and_dtype(
             y,
             'y',
-            ['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
+            [
+                'float16',
+                'uint16',
+                'float32',
+                'float64',
+                'int32',
+                'int64',
+                'complex64',
+                'complex128',
+            ],
             op_type,
         )
 
diff --git a/test/legacy_test/test_dot_op.py b/test/legacy_test/test_dot_op.py
index 7b89e074c4cc22eb1d22c6176d7ce85449f3d9b0..a8b7d4dd539024f307e3d90240478ad9a8f22f97 100644
--- a/test/legacy_test/test_dot_op.py
+++ b/test/legacy_test/test_dot_op.py
@@ -182,102 +182,36 @@ class TestDygraph(unittest.TestCase):
         )
 
 
-class TestComplexDotOp(OpTest):
-    def setUp(self):
-        self.op_type = "dot"
-        self.python_api = paddle.dot
-        self.init_base_dtype()
-        self.init_input_output()
-
-        self.inputs = {
-            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
-            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
-        }
-        self.outputs = {'Out': self.out}
-
-    def init_base_dtype(self):
-        self.dtype = np.float64
+class TestComplex64DotOp(DotOp):
+    def init_dtype(self):
+        self.dtype = np.complex64
 
     def init_input_output(self):
-        self.x = np.random.random(100).astype(
-            self.dtype
-        ) + 1j * np.random.random(100).astype(self.dtype)
-        self.y = np.random.random(100).astype(
-            self.dtype
-        ) + 1j * np.random.random(100).astype(self.dtype)
-        self.out = np.dot(self.x, self.y)
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad_normal(self):
-        self.check_grad(
-            ['X', 'Y'],
-            'Out',
-        )
-
-    def test_check_grad_ingore_x(self):
-        self.check_grad(
-            ['Y'],
-            'Out',
-            no_grad_set=set("X"),
-        )
-
-    def test_check_grad_ingore_y(self):
-        self.check_grad(
-            ['X'],
-            'Out',
-            no_grad_set=set('Y'),
-        )
-
-
-class TestComplexDotOp2D(OpTest):
-    def setUp(self):
-        self.op_type = "dot"
-        self.python_api = paddle.dot
-        self.init_base_dtype()
-        self.init_input_output()
-
-        self.inputs = {
-            'X': OpTest.np_dtype_to_fluid_dtype(self.x),
-            'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
-        }
-        self.outputs = {'Out': self.out}
+        shape = 100
+        self.x = (
+            np.random.random(shape) + 1j * np.random.random(shape)
+        ).astype(self.dtype)
+        self.y = (
+            np.random.random(shape) + 1j * np.random.random(shape)
+        ).astype(self.dtype)
+        self.out = np.dot(self.x, self.y).astype(self.dtype)
 
-    def init_base_dtype(self):
-        self.dtype = np.float64
 
+class TestComplex64DotOp2D(TestComplex64DotOp):
     def init_input_output(self):
-        self.x = np.random.random((2, 100)).astype(
-            self.dtype
-        ) + 1j * np.random.random((2, 100)).astype(self.dtype)
-        self.y = np.random.random((2, 100)).astype(
-            self.dtype
-        ) + 1j * np.random.random((2, 100)).astype(self.dtype)
+        shape = (2, 100)
+        self.x = (
+            np.random.random(shape) + 1j * np.random.random(shape)
+        ).astype(self.dtype)
+        self.y = (
+            np.random.random(shape) + 1j * np.random.random(shape)
+        ).astype(self.dtype)
         self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1)
 
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad_normal(self):
-        self.check_grad(
-            ['X', 'Y'],
-            'Out',
-        )
-
-    def test_check_grad_ingore_x(self):
-        self.check_grad(
-            ['Y'],
-            'Out',
-            no_grad_set=set("X"),
-        )
-
-    def test_check_grad_ingore_y(self):
-        self.check_grad(
-            ['X'],
-            'Out',
-            no_grad_set=set('Y'),
-        )
+
+class TestComplex128DotOp(TestComplex64DotOp):
+    def init_dtype(self):
+        self.dtype = np.complex128
 
 
 @unittest.skipIf(