未验证 提交 6d8dcc74 编写于 作者: H hong 提交者: GitHub

Fix np ndarray mul varbase (#24331)

* fix numpy ndarray mul var base error; test=develop

* add comment for __array_ufunc__ ; test=develop

* move unitest from imperative math op path to test_math_op_patch_var_base;
test=develop
上级 67f66f09
......@@ -249,3 +249,13 @@ def monkey_patch_math_varbase():
core.VarBase.__len__ = _len_
core.VarBase.__index__ = _index_
core.VarBase.astype = astype
"""
When code is written like this
y = np.pi * var
ndarray.__mul__(self, var) is called, var will be traced as an array(by using __len__, __getitem__), which is not right.
when var.__array_ufunc__ is set to None, var.__rmul__(self, np) will be called.
The details can be seen bellow:
https://docs.scipy.org/doc/numpy-1.13.0/neps/ufunc-overrides.html#behavior-in-combination-with-python-s-binary-operations
"""
core.VarBase.__array_ufunc__ = None
......@@ -237,6 +237,20 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
str1 = "just test"
self.assertTrue(str1[var1] == 's')
def test_np_left_mul(self):
with fluid.dygraph.guard():
t = np.sqrt(2.0 * np.pi)
x = fluid.layers.ones((2, 2), dtype="float32")
y = t * x
self.assertTrue(
np.allclose(
y.numpy(),
t * np.ones(
(2, 2), dtype="float32"),
rtol=1e-05,
atol=0.0))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册