From 5db0c84b2fb2300da99e0a2285493449ac6676b0 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 22 Jun 2021 15:56:06 +0800 Subject: [PATCH] transform complex scale to tensor (#33699) * transform complex scale to tensor * add test_case for complex scalar * modify import paddle --- python/paddle/fluid/dygraph/math_op_patch.py | 15 ++++++++++----- .../unittests/test_math_op_patch_var_base.py | 7 +++++++ 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index f6986265e2f..dee11da4ac9 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -202,12 +202,17 @@ def monkey_patch_math_varbase(): # 2. create varbase for scalar lhs_dtype = self.dtype if not isinstance(other_var, core.VarBase): - if reverse: - other_var = create_tensor( - other_var, dtype=lhs_dtype, shape=self.shape) + if isinstance(other_var, complex): + import paddle + other_var = paddle.to_tensor(other_var, dtype='complex64') else: - # add fill_op - other_var = create_scalar(value=other_var, dtype=lhs_dtype) + if reverse: + other_var = create_tensor( + other_var, dtype=lhs_dtype, shape=self.shape) + else: + # add fill_op + other_var = create_scalar( + value=other_var, dtype=lhs_dtype) # 3. promote types or unify right var type to left var rhs_dtype = other_var.dtype diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py index 7de6148fe73..0afc9ee6253 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py @@ -575,6 +575,13 @@ class TestMathOpPatchesVarBase(unittest.TestCase): self.assertTrue(inspect.ismethod(a.std)) self.assertTrue(inspect.ismethod(a.numel)) + def test_complex_scalar(self): + a_np = np.random.random(self.shape).astype(self.dtype) + with fluid.dygraph.guard(): + a = fluid.dygraph.to_variable(a_np) + res = 1J * a + self.assertTrue(np.array_equal(res.numpy(), 1J * a_np)) + if __name__ == '__main__': unittest.main() -- GitLab