未验证 提交 5db0c84b 编写于 作者: C chentianyu03 提交者: GitHub

transform complex scale to tensor (#33699)

* transform complex scale to tensor

* add test_case for complex scalar

* modify import paddle
上级 480b284c
...@@ -202,12 +202,17 @@ def monkey_patch_math_varbase(): ...@@ -202,12 +202,17 @@ def monkey_patch_math_varbase():
# 2. create varbase for scalar # 2. create varbase for scalar
lhs_dtype = self.dtype lhs_dtype = self.dtype
if not isinstance(other_var, core.VarBase): if not isinstance(other_var, core.VarBase):
if reverse: if isinstance(other_var, complex):
other_var = create_tensor( import paddle
other_var, dtype=lhs_dtype, shape=self.shape) other_var = paddle.to_tensor(other_var, dtype='complex64')
else: else:
# add fill_op if reverse:
other_var = create_scalar(value=other_var, dtype=lhs_dtype) other_var = create_tensor(
other_var, dtype=lhs_dtype, shape=self.shape)
else:
# add fill_op
other_var = create_scalar(
value=other_var, dtype=lhs_dtype)
# 3. promote types or unify right var type to left var # 3. promote types or unify right var type to left var
rhs_dtype = other_var.dtype rhs_dtype = other_var.dtype
......
...@@ -575,6 +575,13 @@ class TestMathOpPatchesVarBase(unittest.TestCase): ...@@ -575,6 +575,13 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
self.assertTrue(inspect.ismethod(a.std)) self.assertTrue(inspect.ismethod(a.std))
self.assertTrue(inspect.ismethod(a.numel)) self.assertTrue(inspect.ismethod(a.numel))
def test_complex_scalar(self):
a_np = np.random.random(self.shape).astype(self.dtype)
with fluid.dygraph.guard():
a = fluid.dygraph.to_variable(a_np)
res = 1J * a
self.assertTrue(np.array_equal(res.numpy(), 1J * a_np))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册