From 43926c80e6f640a9c2ee6e7a9c11c7eec23d66f0 Mon Sep 17 00:00:00 2001 From: Wenyu Date: Tue, 20 Apr 2021 13:49:15 +0800 Subject: [PATCH] support `numpy.array/asarray(tensor) -> ndarray`, test=develop (#32300) --- python/paddle/fluid/dygraph/varbase_patch_methods.py | 5 ++++- python/paddle/fluid/tests/unittests/test_var_base.py | 9 +++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index ac59470986..64209aee87 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -434,6 +434,9 @@ def monkey_patch_varbase(): def __bool__(self): return self.__nonzero__() + def __array__(self, dtype=None): + return self.numpy().astype(dtype) + for method_name, method in ( ("__bool__", __bool__), ("__nonzero__", __nonzero__), ("_to_static_var", _to_static_var), ("set_value", set_value), @@ -442,7 +445,7 @@ def monkey_patch_varbase(): ("gradient", gradient), ("register_hook", register_hook), ("__str__", __str__), ("__repr__", __str__), ("__deepcopy__", __deepcopy__), ("__module__", "paddle"), - ("__name__", "Tensor")): + ("__name__", "Tensor"), ("__array__", __array__)): setattr(core.VarBase, method_name, method) # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class. diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 1fea193547..76c871f372 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -502,6 +502,15 @@ class TestVarBase(unittest.TestCase): np.array_equal(var.numpy(), fluid.framework._var_base_to_np(var))) + def test_var_base_as_np(self): + with fluid.dygraph.guard(): + var = fluid.dygraph.to_variable(self.array) + self.assertTrue(np.array_equal(var.numpy(), np.array(var))) + self.assertTrue( + np.array_equal( + var.numpy(), np.array( + var, dtype=np.float32))) + def test_if(self): with fluid.dygraph.guard(): var1 = fluid.dygraph.to_variable(np.array([[[0]]])) -- GitLab