未验证 提交 3b18d96b 编写于 作者: J james 提交者: GitHub

fix device id issue for xpu eager mode (#48076)

* fix device id issue for xpu eager

xpu device id is not correctly set in eager mode, thus vars are on dev0 unless
XPUDeviceGurad is called, leading to this error message for all node rank != 0:
"NotImplementedError: (Unimplemented) Place Place(xpu:0) is not supported."

* fix typo

* fix pybind error
上级 14a6e67b
...@@ -105,6 +105,7 @@ void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) { ...@@ -105,6 +105,7 @@ void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) {
void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place, void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place,
const std::string& place_key) { const std::string& place_key) {
platform::XPUDeviceGuard guard(place.GetDeviceId());
BKCLUniqueId bkcl_id; BKCLUniqueId bkcl_id;
if (rank_ == 0) { if (rank_ == 0) {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id)); PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id));
......
...@@ -128,6 +128,15 @@ FUNCTION_SET_DEVICE_TEMPLATE = """{} if (paddle::platform::is_gpu_place(place ...@@ -128,6 +128,15 @@ FUNCTION_SET_DEVICE_TEMPLATE = """{} if (paddle::platform::is_gpu_place(place
#else #else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace.")); "PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace."));
#endif
}}
if (paddle::platform::is_xpu_place(place)) {{
#if defined(PADDLE_WITH_XPU)
phi::backends::xpu::SetXPUDeviceId(place.device);
VLOG(4) <<"CurrentDeviceId: " << phi::backends::xpu::GetXPUCurrentDeviceId() << " from " << (int)place.device;
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU if use XPUPlace."));
#endif #endif
}} }}
""" """
......
...@@ -1284,7 +1284,7 @@ void BindDistributed(py::module *m) { ...@@ -1284,7 +1284,7 @@ void BindDistributed(py::module *m) {
auto processGroupBKCL = auto processGroupBKCL =
py::class_<distributed::ProcessGroupBKCL, py::class_<distributed::ProcessGroupBKCL,
std::shared_ptr<distributed::ProcessGroupBKCL>>( std::shared_ptr<distributed::ProcessGroupBKCL>>(
*m, "ProcessGroupBKCL", ProcessGroup) *m, "ProcessGroupBKCL", ProcessGroupStream)
.def(py::init<const std::shared_ptr<distributed::Store> &, .def(py::init<const std::shared_ptr<distributed::Store> &,
int, int,
int, int,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册