未验证 提交 4a16d5c6 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Support numpy.ndarray in CastNumpy2Scalar (#42136)

上级 1178f153
...@@ -1025,7 +1025,20 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj, ...@@ -1025,7 +1025,20 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
PyTypeObject* type = obj->ob_type; PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name); auto type_name = std::string(type->tp_name);
VLOG(1) << "type_name: " << type_name; VLOG(1) << "type_name: " << type_name;
if (type_name == "numpy.float64") { if (type_name == "numpy.ndarray" && PySequence_Check(obj)) {
PyObject* item = nullptr;
item = PySequence_GetItem(obj, 0);
if (PyObject_CheckFloatOrToFloat(&item)) {
float value = static_cast<float>(PyFloat_AsDouble(item));
return paddle::experimental::Scalar(value);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) is numpy.ndarry, the inner elements "
"must be "
"numpy.float32/float64 now, but got %s",
op_type, arg_pos + 1, type_name)); // NOLINT
}
} else if (type_name == "numpy.float64") {
double value = CastPyArg2Double(obj, op_type, arg_pos); double value = CastPyArg2Double(obj, op_type, arg_pos);
return paddle::experimental::Scalar(value); return paddle::experimental::Scalar(value);
} else if (type_name == "numpy.float32") { } else if (type_name == "numpy.float32") {
......
...@@ -20,6 +20,7 @@ import paddle ...@@ -20,6 +20,7 @@ import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
from paddle.incubate.optimizer.functional.bfgs import minimize_bfgs from paddle.incubate.optimizer.functional.bfgs import minimize_bfgs
from paddle.fluid.framework import _test_eager_guard
from paddle.fluid.framework import _enable_legacy_dygraph from paddle.fluid.framework import _enable_legacy_dygraph
_enable_legacy_dygraph() _enable_legacy_dygraph()
...@@ -120,7 +121,7 @@ class TestBfgs(unittest.TestCase): ...@@ -120,7 +121,7 @@ class TestBfgs(unittest.TestCase):
results = test_static_graph(func, x0, dtype='float64') results = test_static_graph(func, x0, dtype='float64')
self.assertTrue(np.allclose(0.8, results[2])) self.assertTrue(np.allclose(0.8, results[2]))
def test_rosenbrock(self): def func_rosenbrock(self):
# The Rosenbrock function is a standard optimization test case. # The Rosenbrock function is a standard optimization test case.
a = np.random.random(size=[1]).astype('float32') a = np.random.random(size=[1]).astype('float32')
minimum = [a.item(), (a**2).item()] minimum = [a.item(), (a**2).item()]
...@@ -139,6 +140,11 @@ class TestBfgs(unittest.TestCase): ...@@ -139,6 +140,11 @@ class TestBfgs(unittest.TestCase):
results = test_dynamic_graph(func, x0) results = test_dynamic_graph(func, x0)
self.assertTrue(np.allclose(minimum, results[2])) self.assertTrue(np.allclose(minimum, results[2]))
def test_rosenbrock(self):
    # Run the Rosenbrock optimization check twice: first inside the eager-mode
    # guard (new dygraph), then once more outside it, so both execution modes
    # are covered by a single unittest entry point.
    # NOTE(review): the module calls _enable_legacy_dygraph() at import time,
    # so the second call presumably runs under legacy dygraph — confirm.
    with _test_eager_guard():
        self.func_rosenbrock()
    self.func_rosenbrock()
def test_exception(self): def test_exception(self):
def func(x): def func(x):
return paddle.dot(x, x) return paddle.dot(x, x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册