未验证 提交 b2b78f8e 编写于 作者: S Sing_chan 提交者: GitHub

[Bug Fix] set device_id=current_id when calling Tensor.cuda() without device_id (#43510)

* default device_id to the current device id when it is not given

* use the current tracer's expected place to get the current device id
上级 76b02b7c
...@@ -1535,40 +1535,40 @@ void BindImperative(py::module *m_ptr) { ...@@ -1535,40 +1535,40 @@ void BindImperative(py::module *m_ptr) {
"Cannot copy this Tensor to GPU in CPU version Paddle, " "Cannot copy this Tensor to GPU in CPU version Paddle, "
"Please recompile or reinstall Paddle with CUDA support.")); "Please recompile or reinstall Paddle with CUDA support."));
#else #else
int device_count = platform::GetGPUDeviceCount(); int device_count = platform::GetGPUDeviceCount();
int device_id = 0; int device_id = 0;
if (handle == py::none()) { if (handle == py::none()) {
if (platform::is_gpu_place(self->Place())) { auto default_place =
return self; imperative::GetCurrentTracer()->ExpectedPlace();
} device_id = default_place.GetDeviceId();
} else { } else {
PyObject *py_obj = handle.ptr(); PyObject *py_obj = handle.ptr();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
PyCheckInteger(py_obj), true, PyCheckInteger(py_obj), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
" 'device_id' must be a positive integer")); " 'device_id' must be a positive integer"));
device_id = py::cast<int>(handle); device_id = py::cast<int>(handle);
} }
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
device_id, 0, device_id, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id " "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)", "must inside [0, %d)",
device_id, device_count)); device_id, device_count));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
device_id, device_count, device_id, device_count,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id " "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)", "must inside [0, %d)",
device_id, device_count)); device_id, device_count));
platform::CUDAPlace place = platform::CUDAPlace(device_id); platform::CUDAPlace place = platform::CUDAPlace(device_id);
if (platform::is_same_place(self->Place(), place)) { if (platform::is_same_place(self->Place(), place)) {
return self; return self;
} else { } else {
auto new_var = self->NewVarBase(place, blocking); auto new_var = self->NewVarBase(place, blocking);
new_var->SetOverridedStopGradient(self->OverridedStopGradient()); new_var->SetOverridedStopGradient(self->OverridedStopGradient());
return new_var; return new_var;
} }
#endif #endif
}, },
py::arg("device_id") = py::none(), py::arg("blocking") = true, R"DOC( py::arg("device_id") = py::none(), py::arg("blocking") = true, R"DOC(
...@@ -1588,16 +1588,17 @@ void BindImperative(py::module *m_ptr) { ...@@ -1588,16 +1588,17 @@ void BindImperative(py::module *m_ptr) {
# required: gpu # required: gpu
import paddle import paddle
x = paddle.to_tensor(1.0, place=paddle.CPUPlace()) x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
print(x.place) # CPUPlace print(x.place) # Place(cpu)
y = x.cuda() y = x.cuda()
print(y.place) # CUDAPlace(0) print(y.place) # Place(gpu:0)
y = x.cuda(None) y = x.cuda(None)
print(y.place) # CUDAPlace(0) print(y.place) # Place(gpu:0)
y = x.cuda(1) paddle.device.set_device("gpu:1")
print(y.place) # CUDAPlace(1) y = x.cuda(None)
print(y.place) # Place(gpu:1)
)DOC") )DOC")
.def( .def(
"_share_memory", "_share_memory",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册