未验证 提交 920c66e9 编写于 作者: H huangjiyi 提交者: GitHub

【complex op】 No.34 add complex support for dot (#56349)

* update

* fix codestyle

* update

* update
上级 488071af
......@@ -1080,8 +1080,8 @@ def dot(x, y, name=None):
is the batch dimension, which means that the vectors of multiple batches are dotted.
Parameters:
x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``
y(Tensor): 1-D or 2-D ``Tensor``. Its dtype soulde be ``float32``, ``float64``, ``int32``, ``int64``
x(Tensor): 1-D or 2-D ``Tensor``. Its dtype should be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
y(Tensor): 1-D or 2-D ``Tensor``. Its dtype soulde be ``float32``, ``float64``, ``int32``, ``int64``, ``complex64``, ``complex128``
name(str, optional): Name of the output. Default is None. It's used to print debug info for developers. Details: :ref:`api_guide_Name`
Returns:
......@@ -1117,13 +1117,31 @@ def dot(x, y, name=None):
check_variable_and_dtype(
x,
'x',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
[
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
op_type,
)
check_variable_and_dtype(
y,
'y',
['float16', 'uint16', 'float32', 'float64', 'int32', 'int64'],
[
'float16',
'uint16',
'float32',
'float64',
'int32',
'int64',
'complex64',
'complex128',
],
op_type,
)
......
......@@ -182,102 +182,36 @@ class TestDygraph(unittest.TestCase):
)
class TestComplexDotOp(OpTest):
def setUp(self):
self.op_type = "dot"
self.python_api = paddle.dot
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}
def init_base_dtype(self):
self.dtype = np.float64
class TestComplex64DotOp(DotOp):
def init_dtype(self):
self.dtype = np.complex64
def init_input_output(self):
self.x = np.random.random(100).astype(
self.dtype
) + 1j * np.random.random(100).astype(self.dtype)
self.y = np.random.random(100).astype(
self.dtype
) + 1j * np.random.random(100).astype(self.dtype)
self.out = np.dot(self.x, self.y)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
)
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
)
class TestComplexDotOp2D(OpTest):
def setUp(self):
self.op_type = "dot"
self.python_api = paddle.dot
self.init_base_dtype()
self.init_input_output()
self.inputs = {
'X': OpTest.np_dtype_to_fluid_dtype(self.x),
'Y': OpTest.np_dtype_to_fluid_dtype(self.y),
}
self.outputs = {'Out': self.out}
shape = 100
self.x = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.y = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.out = np.dot(self.x, self.y).astype(self.dtype)
def init_base_dtype(self):
self.dtype = np.float64
class TestComplex64DotOp2D(TestComplex64DotOp):
def init_input_output(self):
self.x = np.random.random((2, 100)).astype(
self.dtype
) + 1j * np.random.random((2, 100)).astype(self.dtype)
self.y = np.random.random((2, 100)).astype(
self.dtype
) + 1j * np.random.random((2, 100)).astype(self.dtype)
shape = (2, 100)
self.x = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.y = (
np.random.random(shape) + 1j * np.random.random(shape)
).astype(self.dtype)
self.out = np.diag(np.dot(self.x, self.y.T)).reshape(-1)
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(
['X', 'Y'],
'Out',
)
def test_check_grad_ingore_x(self):
self.check_grad(
['Y'],
'Out',
no_grad_set=set("X"),
)
def test_check_grad_ingore_y(self):
self.check_grad(
['X'],
'Out',
no_grad_set=set('Y'),
)
class TestComplex128DotOp(TestComplex64DotOp):
def init_dtype(self):
self.dtype = np.complex128
@unittest.skipIf(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册