diff --git a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc index a5c80cb04108d28381241f4cce6de347e331ddac..8dfb65d9813749ea3529b7168e6ea02d61c9e41d 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupBKCL.cc @@ -105,6 +105,7 @@ void ProcessGroupBKCL::BroadcastUniqueBKCLID(BKCLUniqueId* bkcl_id) { void ProcessGroupBKCL::CreateBKCLEnvCache(const Place& place, const std::string& place_key) { + platform::XPUDeviceGuard guard(place.GetDeviceId()); BKCLUniqueId bkcl_id; if (rank_ == 0) { PADDLE_ENFORCE_XPU_SUCCESS(bkcl_get_unique_id(&bkcl_id)); diff --git a/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py index 8e3944b79c30faf097ff91b5048314e5a54fcd28..aacde58fa7bc2e1d22f54b325eacc30ac6d6fcd7 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py @@ -128,6 +128,15 @@ FUNCTION_SET_DEVICE_TEMPLATE = """{} if (paddle::platform::is_gpu_place(place #else PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( "PaddlePaddle should compile with CUSTOM_DEVICE if use CustomPlace.")); +#endif + }} + if (paddle::platform::is_xpu_place(place)) {{ +#if defined(PADDLE_WITH_XPU) + phi::backends::xpu::SetXPUDeviceId(place.device); + VLOG(4) <<"CurrentDeviceId: " << phi::backends::xpu::GetXPUCurrentDeviceId() << " from " << (int)place.device; +#else + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with XPU if use XPUPlace.")); #endif }} """ diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index dbc4c57c656baac0d1dcb2ce2001cc512ed4d240..52160ea99a083267278318b900bff8d5607f2870 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -1284,7 +1284,7 @@ void BindDistributed(py::module *m) { auto processGroupBKCL = py::class_>( - *m, "ProcessGroupBKCL", ProcessGroup) + *m, "ProcessGroupBKCL", ProcessGroupStream) .def(py::init &, int, int,