From 4a16d5c6a03df776b08ff587d01048971fb64b2e Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Mon, 25 Apr 2022 10:27:34 +0800 Subject: [PATCH] [Eager] Support numpy.ndarry in CastNumpy2Scalar (#42136) --- paddle/fluid/pybind/eager_utils.cc | 15 ++++++++++++++- python/paddle/fluid/tests/unittests/test_bfgs.py | 8 +++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index b391274843..d07cbd5ee2 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -1025,7 +1025,20 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj, PyTypeObject* type = obj->ob_type; auto type_name = std::string(type->tp_name); VLOG(1) << "type_name: " << type_name; - if (type_name == "numpy.float64") { + if (type_name == "numpy.ndarray" && PySequence_Check(obj)) { + PyObject* item = nullptr; + item = PySequence_GetItem(obj, 0); + if (PyObject_CheckFloatOrToFloat(&item)) { + float value = static_cast(PyFloat_AsDouble(item)); + return paddle::experimental::Scalar(value); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) is numpy.ndarry, the inner elements " + "must be " + "numpy.float32/float64 now, but got %s", + op_type, arg_pos + 1, type_name)); // NOLINT + } + } else if (type_name == "numpy.float64") { double value = CastPyArg2Double(obj, op_type, arg_pos); return paddle::experimental::Scalar(value); } else if (type_name == "numpy.float32") { diff --git a/python/paddle/fluid/tests/unittests/test_bfgs.py b/python/paddle/fluid/tests/unittests/test_bfgs.py index 4bf6de3eee..1a12913bc7 100644 --- a/python/paddle/fluid/tests/unittests/test_bfgs.py +++ b/python/paddle/fluid/tests/unittests/test_bfgs.py @@ -20,6 +20,7 @@ import paddle import paddle.nn.functional as F from paddle.incubate.optimizer.functional.bfgs import minimize_bfgs +from paddle.fluid.framework import _test_eager_guard from paddle.fluid.framework import _enable_legacy_dygraph _enable_legacy_dygraph() @@ -120,7 +121,7 @@ class TestBfgs(unittest.TestCase): results = test_static_graph(func, x0, dtype='float64') self.assertTrue(np.allclose(0.8, results[2])) - def test_rosenbrock(self): + def func_rosenbrock(self): # The Rosenbrock function is a standard optimization test case. a = np.random.random(size=[1]).astype('float32') minimum = [a.item(), (a**2).item()] @@ -139,6 +140,11 @@ class TestBfgs(unittest.TestCase): results = test_dynamic_graph(func, x0) self.assertTrue(np.allclose(minimum, results[2])) + def test_rosenbrock(self): + with _test_eager_guard(): + self.func_rosenbrock() + self.func_rosenbrock() + def test_exception(self): def func(x): return paddle.dot(x, x) -- GitLab