diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index f6986265e2fbb319529fd8552b2d84f543be2f0c..dee11da4ac9ac195880dba6911b53a1f752e7d14 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -202,12 +202,17 @@ def monkey_patch_math_varbase(): # 2. create varbase for scalar lhs_dtype = self.dtype if not isinstance(other_var, core.VarBase): - if reverse: - other_var = create_tensor( - other_var, dtype=lhs_dtype, shape=self.shape) + if isinstance(other_var, complex): + import paddle + other_var = paddle.to_tensor(other_var, dtype='complex64') else: - # add fill_op - other_var = create_scalar(value=other_var, dtype=lhs_dtype) + if reverse: + other_var = create_tensor( + other_var, dtype=lhs_dtype, shape=self.shape) + else: + # add fill_op + other_var = create_scalar( + value=other_var, dtype=lhs_dtype) # 3. promote types or unify right var type to left var rhs_dtype = other_var.dtype diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py index 7de6148fe73da29fb3b2a44ba9fd57b99cd76b95..0afc9ee6253ea627860aae1be0fa2d0aa3cb2c6f 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py @@ -575,6 +575,13 @@ class TestMathOpPatchesVarBase(unittest.TestCase): self.assertTrue(inspect.ismethod(a.std)) self.assertTrue(inspect.ismethod(a.numel)) + def test_complex_scalar(self): + a_np = np.random.random(self.shape).astype(self.dtype) + with fluid.dygraph.guard(): + a = fluid.dygraph.to_variable(a_np) + res = 1J * a + self.assertTrue(np.array_equal(res.numpy(), 1J * a_np)) + if __name__ == '__main__': unittest.main()