From 6d8dcc7407f269c410d0a4e8d0351227b3efed56 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Fri, 8 May 2020 21:52:03 +0800 Subject: [PATCH] Fix np ndarray mul varbase (#24331) * fix numpy ndarray mul var base error; test=develop * add comment for __array_ufunc__ ; test=develop * move unitest from imperative math op path to test_math_op_patch_var_base; test=develop --- python/paddle/fluid/dygraph/math_op_patch.py | 10 ++++++++++ .../tests/unittests/test_math_op_patch_var_base.py | 14 ++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index 8e4752999b6..b15e6388884 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -249,3 +249,13 @@ def monkey_patch_math_varbase(): core.VarBase.__len__ = _len_ core.VarBase.__index__ = _index_ core.VarBase.astype = astype + """ + When code is written like this + y = np.pi * var + ndarray.__mul__(self, var) is called, var will be traced as an array(by using __len__, __getitem__), which is not right. + when var.__array_ufunc__ is set to None, var.__rmul__(self, np) will be called. + + The details can be seen bellow: + https://docs.scipy.org/doc/numpy-1.13.0/neps/ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations + """ + core.VarBase.__array_ufunc__ = None diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py index 4a967d97964..34f14b75952 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py @@ -237,6 +237,20 @@ class TestMathOpPatchesVarBase(unittest.TestCase): str1 = "just test" self.assertTrue(str1[var1] == 's') + def test_np_left_mul(self): + with fluid.dygraph.guard(): + t = np.sqrt(2.0 * np.pi) + x = fluid.layers.ones((2, 2), dtype="float32") + y = t * x + + self.assertTrue( + np.allclose( + y.numpy(), + t * np.ones( + (2, 2), dtype="float32"), + rtol=1e-05, + atol=0.0)) + if __name__ == '__main__': unittest.main() -- GitLab