From a72752263b1847071a749dc46dd98eaebbd668ea Mon Sep 17 00:00:00 2001 From: Chengmo Date: Fri, 4 Sep 2020 19:33:27 +0800 Subject: [PATCH] support heter-xpu-ps (#27018) support heter-xpu-ps --- .../distributed/variable_response.cc | 28 +++++++++++++++++++ .../fleet/runtime/parameter_server_runtime.py | 3 ++ 2 files changed, 31 insertions(+) diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index c9c42e0938d..de77121ee39 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -62,6 +62,34 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input, gpu_dev_ctx.Wait(); #else PADDLE_THROW("Unexpected branch"); +#endif + return true; + } else if (platform::is_xpu_place(place)) { +#ifdef PADDLE_WITH_XPU + auto& xpu_dev_ctx = static_cast(dev_ctx); + platform::CPUPlace cpu; + char* p = reinterpret_cast(dest); + while (total_written < length) { + if (!input->GetDirectBufferPointer(&data, &size_to_write)) { + return false; + } + + if (total_written + size_to_write > length) { + size_to_write = length - total_written; + } + + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place), + reinterpret_cast(p), cpu, data, size_to_write); + p += size_to_write; + total_written += size_to_write; + input->Skip(size_to_write); + } + xpu_dev_ctx.Wait(); +#else + PADDLE_ENFORCE_NOT_NULL( + nullptr, + platform::errors::Unimplemented( + "Not supported XPU, please compile with option WITH_XPU=ON.")); #endif return true; } diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py index 870c3fe8be4..227f8f60210 100644 --- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py +++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py @@ -202,6 +202,9 @@ class ParameterServerRuntime(RuntimeBase): if self.role_maker._get_heter_worker_device() == "GPU": gpu_id = int(os.getenv("FLAGS_selected_gpus", "0")) executor = Executor(fluid.CUDAPlace(gpu_id)) + elif self.role_maker._get_heter_worker_device() == "XPU": + xpu_id = int(os.getenv("FLAGS_selected_xpus", "0")) + executor = Executor(fluid.XPUPlace(xpu_id)) else: raise ValueError("Not Support Device {}".format( self.role_maker._get_heter_worker_device())) -- GitLab