Unverified commit 54b6c390, authored by zhouweiwei2014, committed by GitHub

fix API bug of Tensor.cuda (#34416)

Parent 1f0f5d3c
@@ -1400,20 +1400,26 @@ void BindImperative(py::module *m_ptr) {
           )DOC")
       .def("cuda",
-           [](const std::shared_ptr<imperative::VarBase> &self, int device_id,
-              bool blocking) {
+           [](const std::shared_ptr<imperative::VarBase> &self,
+              py::handle &handle, bool blocking) {
 #if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
              PADDLE_THROW(platform::errors::PermissionDenied(
                  "Cannot copy this Tensor to GPU in CPU version Paddle, "
                  "Please recompile or reinstall Paddle with CUDA support."));
 #else
              int device_count = platform::GetCUDADeviceCount();
-             if (device_id == -1) {
+             int device_id = 0;
+             if (handle == py::none()) {
                if (platform::is_gpu_place(self->Place())) {
                  return self;
-               } else {
-                 device_id = 0;
                }
+             } else {
+               PyObject *py_obj = handle.ptr();
+               PADDLE_ENFORCE_EQ(
+                   PyCheckInteger(py_obj), true,
+                   platform::errors::InvalidArgument(
+                       " 'device_id' must be a positive integer"));
+               device_id = py::cast<int>(handle);
              }
              PADDLE_ENFORCE_GE(
                  device_id, 0,
@@ -1437,26 +1443,30 @@ void BindImperative(py::module *m_ptr) {
              }
 #endif
            },
-           py::arg("device_id") = -1, py::arg("blocking") = true, R"DOC(
+           py::arg("device_id") = py::none(), py::arg("blocking") = true, R"DOC(
         Returns a copy of this Tensor in GPU memory.

         If this Tensor is already in GPU memory and device_id is default,
         then no copy is performed and the original Tensor is returned.

         Args:
-            device_id(int, optional): The destination GPU device id. Defaults to the current device.
+            device_id(int, optional): The destination GPU device id. Default: None, means current device.
             blocking(bool, optional): If False and the source is in pinned memory, the copy will be
               asynchronous with respect to the host. Otherwise, the argument has no effect. Default: False.

         Examples:
             .. code-block:: python

+              # required: gpu
               import paddle
               x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
               print(x.place)    # CPUPlace

               y = x.cuda()
               print(y.place)    # CUDAPlace(0)

+              y = x.cuda(None)
+              print(y.place)    # CUDAPlace(0)
+
               y = x.cuda(1)
               print(y.place)    # CUDAPlace(1)
...
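Viewed from the Python side, the binding change above replaces the old integer sentinel (device_id = -1) with py::none() as the default, so Tensor.cuda() now distinguishes "no argument given" from an explicit device id, and rejects non-integer arguments outright. A minimal usage sketch of the resulting behavior, assuming a CUDA build with at least two visible GPUs (error text abbreviated):

    import paddle

    x = paddle.to_tensor(1.0, place=paddle.CPUPlace())

    y = x.cuda()        # device_id defaults to None -> copy to GPU 0
    z = x.cuda(1)       # explicit integer -> copy to GPU 1
    w = y.cuda()        # already on GPU and device_id is None -> returns y itself

    # Non-integer device_id values now fail fast: the InvalidArgument error
    # raised by the PADDLE_ENFORCE_EQ guard surfaces in Python as ValueError.
    try:
        x.cuda("0")
    except ValueError as err:
        print(err)      # ... 'device_id' must be a positive integer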
@@ -72,10 +72,17 @@ class TestVarBase(unittest.TestCase):
             if core.is_compiled_with_cuda():
                 y = x.pin_memory()
                 self.assertEqual(y.place.__repr__(), "CUDAPinnedPlace")
+                y = x.cuda()
+                y = x.cuda(None)
+                self.assertEqual(y.place.__repr__(), "CUDAPlace(0)")
+                y = x.cuda(device_id=0)
+                self.assertEqual(y.place.__repr__(), "CUDAPlace(0)")
                 y = x.cuda(blocking=False)
                 self.assertEqual(y.place.__repr__(), "CUDAPlace(0)")
                 y = x.cuda(blocking=True)
                 self.assertEqual(y.place.__repr__(), "CUDAPlace(0)")
+                with self.assertRaises(ValueError):
+                    y = x.cuda("test")

             # support 'dtype' is core.VarType
             x = paddle.rand((2, 2))
...
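For readers skimming the diff, the control flow the new lambda implements can be restated in plain Python. This is an illustrative sketch only, not Paddle internals; resolve_device_id and its tensor_is_on_gpu flag are hypothetical names:

    def resolve_device_id(device_id, tensor_is_on_gpu):
        """Return the target GPU id, or None when no copy is needed."""
        if device_id is None:
            if tensor_is_on_gpu:
                return None            # already on GPU: return self, no copy
            return 0                   # CPU tensor defaults to GPU 0
        if not isinstance(device_id, int):
            # Mirrors the PyCheckInteger guard in the binding above.
            raise ValueError("'device_id' must be a positive integer")
        if device_id < 0:
            # Mirrors PADDLE_ENFORCE_GE(device_id, 0, ...).
            raise ValueError("'device_id' must be a positive integer")
        return device_id

    assert resolve_device_id(None, tensor_is_on_gpu=True) is None
    assert resolve_device_id(None, tensor_is_on_gpu=False) == 0
    assert resolve_device_id(7, tensor_is_on_gpu=False) == 7

Using py::none() rather than -1 as the sentinel keeps arbitrary negative integers from silently meaning "current device", which is exactly what the new test cases above pin down.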