未验证 提交 920c66e9 编写于 作者: H huangjiyi 提交者: GitHub

【complex op】 No.34 add complex support for dot (#56349)

* update

* fix codestyle

* update

* update
上级 488071af
...@@ -1080,8 +1080,8 @@ def dot(x, y, name=None): ...@@ -1080,8 +1080,8 @@ def dot(x, y, name=None):
is the batch dimension, which means that the vectors of multiple batches are dotted. is the batch dimension, which means that the vectors of multiple batches are dotted.
Parameters: Parameters:
x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64`` x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
y(Tensor): 1-D or 2-D ``Tensor``. Its dtype soulde be ``float32``, ``float64``, ``int32``, ``int64`` y(Tensor): 1-D or 2-D ``Tensor``. Its dtype soulde be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name` name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`
Returns: Returns:
...@@ -1117,13 +1117,31 @@ def dot(x, y, name=None): ...@@ -1117,13 +1117,31 @@ def dot(x, y, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'x', 'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], [
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
op_type, op_type,
) )
check_variable_and_dtype( check_variable_and_dtype(
y, y,
'y', 'y',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'], [
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
op_type, op_type,
) )
......
...@@ -182,102 +182,36 @@ class TestDygraph(unittest.TestCase): ...@@ -182,102 +182,36 @@ class TestDygraph(unittest.TestCase):
) )
class TestComplexDotOp(OpTest): class TestComplex64DotOp(DotOp):
def setUp(self): def init_dtype(self):
self.op_type = "dot" self.dtype = np.complex64
self.python_api = paddle.dot
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
def init_input_output(self): def init_input_output(self):
self.x = np.random.random(100).astype( shape = 100
self.dtype self.x = (
) + 1j * np.random.random(100).astype(self.dtype) np.random.random(shape) + 1j * np.random.random(shape)
self.y = np.random.random(100).astype( ).astype(self.dtype)
self.dtype self.y = (
) + 1j * np.random.random(100).astype(self.dtype) np.random.random(shape) + 1j * np.random.random(shape)
self.out = np.dot(self.x, self.y) ).astype(self.dtype)
self.out = np.dot(self.x, self.y).astype(self.dtype)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
)
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
)
class TestComplexDotOp2D(OpTest):
def setUp(self):
self.op_type = "dot"
self.python_api = paddle.dot
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
class TestComplex64DotOp2D(TestComplex64DotOp):
def init_input_output(self): def init_input_output(self):
self.x = np.random.random((2, 100)).astype( shape = (2, 100)
self.dtype self.x = (
) + 1j * np.random.random((2, 100)).astype(self.dtype) np.random.random(shape) + 1j * np.random.random(shape)
self.y = np.random.random((2, 100)).astype( ).astype(self.dtype)
self.dtype self.y = (
) + 1j * np.random.random((2, 100)).astype(self.dtype) np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1) self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
)
def test_check_grad_ingore_x(self): class TestComplex128DotOp(TestComplex64DotOp):
self.check_grad( def init_dtype(self):
['Y'], self.dtype = np.complex128
'Out',
no_grad_set=set("X"),
)
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
)
@unittest.skipIf( @unittest.skipIf(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册