From 075a02d24e0f77d37cf628c348311843a1e6698a Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 3 Dec 2021 18:57:26 +0800 Subject: [PATCH] Fix _numel func logic and add test (#37810) --- paddle/fluid/pybind/imperative.cc | 4 ---- python/paddle/fluid/tests/unittests/test_var_base.py | 8 +++++++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index f03acc38084..29a1f0eafcb 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -1963,10 +1963,6 @@ void BindImperative(py::module *m_ptr) { .def("_numel", [](std::shared_ptr &self) { auto *t = self->MutableVar()->GetMutable(); - PADDLE_ENFORCE_EQ( - t->IsInitialized(), true, - platform::errors::InvalidArgument( - "Tensor %s has not been initialized!", self->Name())); return t->numel(); }) .def_property("name", &imperative::VarBase::Name, diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 0e50a20a04e..ab6e8003833 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -1279,7 +1279,7 @@ class TestVarBaseInitVarBaseFromTensorWithDevice(unittest.TestCase): class TestVarBaseNumel(unittest.TestCase): - def test_numel(self): + def test_numel_normal(self): paddle.disable_static() np_x = np.random.random((3, 8, 8)) x = paddle.to_tensor(np_x, dtype="float64") @@ -1287,6 +1287,12 @@ class TestVarBaseNumel(unittest.TestCase): x_expected_numel = np.product((3, 8, 8)) self.assertEqual(x_actual_numel, x_expected_numel) + def test_numel_without_holder(self): + paddle.disable_static() + x_without_holder = core.VarBase() + x_actual_numel = x_without_holder._numel() + self.assertEqual(x_actual_numel, 0) + class TestVarBaseCopyGradientFrom(unittest.TestCase): def test_copy_gradient_from(self): -- GitLab