未验证 提交 43926c80 编写于 作者: W Wenyu 提交者: GitHub

support `numpy.array/asarray(tensor) -> ndarray`, test=develop (#32300)

上级 f0cc1883
...@@ -434,6 +434,9 @@ def monkey_patch_varbase(): ...@@ -434,6 +434,9 @@ def monkey_patch_varbase():
def __bool__(self): def __bool__(self):
return self.__nonzero__() return self.__nonzero__()
def __array__(self, dtype=None):
return self.numpy().astype(dtype)
for method_name, method in ( for method_name, method in (
("__bool__", __bool__), ("__nonzero__", __nonzero__), ("__bool__", __bool__), ("__nonzero__", __nonzero__),
("_to_static_var", _to_static_var), ("set_value", set_value), ("_to_static_var", _to_static_var), ("set_value", set_value),
...@@ -442,7 +445,7 @@ def monkey_patch_varbase(): ...@@ -442,7 +445,7 @@ def monkey_patch_varbase():
("gradient", gradient), ("register_hook", register_hook), ("gradient", gradient), ("register_hook", register_hook),
("__str__", __str__), ("__repr__", __str__), ("__str__", __str__), ("__repr__", __str__),
("__deepcopy__", __deepcopy__), ("__module__", "paddle"), ("__deepcopy__", __deepcopy__), ("__module__", "paddle"),
("__name__", "Tensor")): ("__name__", "Tensor"), ("__array__", __array__)):
setattr(core.VarBase, method_name, method) setattr(core.VarBase, method_name, method)
# NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class. # NOTE(zhiqiu): pybind11 will set a default __str__ method of enum class.
......
...@@ -502,6 +502,15 @@ class TestVarBase(unittest.TestCase): ...@@ -502,6 +502,15 @@ class TestVarBase(unittest.TestCase):
np.array_equal(var.numpy(), np.array_equal(var.numpy(),
fluid.framework._var_base_to_np(var))) fluid.framework._var_base_to_np(var)))
def test_var_base_as_np(self):
with fluid.dygraph.guard():
var = fluid.dygraph.to_variable(self.array)
self.assertTrue(np.array_equal(var.numpy(), np.array(var)))
self.assertTrue(
np.array_equal(
var.numpy(), np.array(
var, dtype=np.float32)))
def test_if(self): def test_if(self):
with fluid.dygraph.guard(): with fluid.dygraph.guard():
var1 = fluid.dygraph.to_variable(np.array([[[0]]])) var1 = fluid.dygraph.to_variable(np.array([[[0]]]))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册