diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc
index c9c42e0938d51991c53b74ac6ad59c350f4a3ced..de77121ee3990366771723e3c43e53362c832ef7 100644
--- a/paddle/fluid/operators/distributed/variable_response.cc
+++ b/paddle/fluid/operators/distributed/variable_response.cc
@@ -62,6 +62,37 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
     gpu_dev_ctx.Wait();
 #else
     PADDLE_THROW("Unexpected branch");
+#endif
+    return true;
+  } else if (platform::is_xpu_place(place)) {
+#ifdef PADDLE_WITH_XPU
+    // Drain the coded input stream chunk-by-chunk into XPU device memory,
+    // staging each direct-buffer chunk through host memory (mirrors the
+    // CUDA branch above).
+    auto& xpu_dev_ctx = static_cast<const platform::XPUDeviceContext&>(dev_ctx);
+    platform::CPUPlace cpu;
+    char* p = reinterpret_cast<char*>(dest);
+    while (total_written < length) {
+      if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
+        return false;
+      }
+
+      if (total_written + size_to_write > length) {
+        size_to_write = length - total_written;
+      }
+
+      memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place),
+                   reinterpret_cast<void*>(p), cpu, data, size_to_write);
+      p += size_to_write;
+      total_written += size_to_write;
+      input->Skip(size_to_write);
+    }
+    xpu_dev_ctx.Wait();
+#else
+    PADDLE_ENFORCE_NOT_NULL(
+        nullptr,
+        platform::errors::Unimplemented(
+            "Not supported XPU, please compile with option WITH_XPU=ON."));
 #endif
     return true;
   }
diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
index 870c3fe8be4c87b8c7bde9b690c3ee1eded39393..227f8f60210ee8a44ab9e87ed7b88337c79ac7f1 100644
--- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
+++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
@@ -202,6 +202,9 @@ class ParameterServerRuntime(RuntimeBase):
             if self.role_maker._get_heter_worker_device() == "GPU":
                 gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
                 executor = Executor(fluid.CUDAPlace(gpu_id))
+            elif self.role_maker._get_heter_worker_device() == "XPU":
+                xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
+                executor = Executor(fluid.XPUPlace(xpu_id))
             else:
                 raise ValueError("Not Support Device {}".format(
                     self.role_maker._get_heter_worker_device()))