未验证 提交 75d306ff 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix to_uva_tensor without specific gpu number (#44937)

* [Eager] fix to_uva_tensor without specific gpu number

* Update test_tensor_uva.py

update test case
上级 ee9ea48d
...@@ -870,10 +870,13 @@ static PyObject* eager_api_to_uva_tensor(PyObject* self, ...@@ -870,10 +870,13 @@ static PyObject* eager_api_to_uva_tensor(PyObject* self,
PyObject* obj = PyTuple_GET_ITEM(args, 0); PyObject* obj = PyTuple_GET_ITEM(args, 0);
auto array = py::cast<py::array>(py::handle(obj)); auto array = py::cast<py::array>(py::handle(obj));
int device_id = 0; Py_ssize_t args_num = PyTuple_Size(args);
PyObject* Py_device_id = PyTuple_GET_ITEM(args, 1); int64_t device_id = 0;
if (Py_device_id) { if (args_num > 1) {
device_id = CastPyArg2AttrLong(Py_device_id, 1); PyObject* Py_device_id = PyTuple_GET_ITEM(args, 1);
if (Py_device_id) {
device_id = CastPyArg2AttrLong(Py_device_id, 1);
}
} }
if (py::isinstance<py::array_t<int32_t>>(array)) { if (py::isinstance<py::array_t<int32_t>>(array)) {
......
...@@ -47,10 +47,15 @@ class TestUVATensorFromNumpy(unittest.TestCase): ...@@ -47,10 +47,15 @@ class TestUVATensorFromNumpy(unittest.TestCase):
data = np.random.randint(10, size=[4, 5]).astype(dtype) data = np.random.randint(10, size=[4, 5]).astype(dtype)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
tensor = paddle.fluid.core.to_uva_tensor(data, 0) tensor = paddle.fluid.core.to_uva_tensor(data, 0)
tensor2 = paddle.fluid.core.to_uva_tensor(data)
else: else:
tensor = core.eager.to_uva_tensor(data, 0) tensor = core.eager.to_uva_tensor(data, 0)
tensor2 = core.eager.to_uva_tensor(data)
self.assertTrue(tensor.place.is_gpu_place()) self.assertTrue(tensor.place.is_gpu_place())
self.assertTrue(tensor2.place.is_gpu_place())
self.assertTrue(np.allclose(tensor.numpy(), data)) self.assertTrue(np.allclose(tensor.numpy(), data))
self.assertTrue(np.allclose(tensor2.numpy(), data))
def test_uva_tensor_creation(self): def test_uva_tensor_creation(self):
with _test_eager_guard(): with _test_eager_guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册