未验证 提交 a7275226 编写于 作者: C Chengmo 提交者: GitHub

support heter-xpu-ps (#27018)

support heter-xpu-ps
上级 d8437062
......@@ -62,6 +62,34 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input,
gpu_dev_ctx.Wait();
#else
PADDLE_THROW("Unexpected branch");
#endif
return true;
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
auto& xpu_dev_ctx = static_cast<const platform::XPUDeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
char* p = reinterpret_cast<char*>(dest);
while (total_written < length) {
if (!input->GetDirectBufferPointer(&data, &size_to_write)) {
return false;
}
if (total_written + size_to_write > length) {
size_to_write = length - total_written;
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, place),
reinterpret_cast<void*>(p), cpu, data, size_to_write);
p += size_to_write;
total_written += size_to_write;
input->Skip(size_to_write);
}
xpu_dev_ctx.Wait();
#else
PADDLE_ENFORCE_NOT_NULL(
nullptr,
platform::errors::Unimplemented(
"Not supported XPU, please compile with option WITH_XPU=ON."));
#endif
return true;
}
......
......@@ -202,6 +202,9 @@ class ParameterServerRuntime(RuntimeBase):
if self.role_maker._get_heter_worker_device() == "GPU":
gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
executor = Executor(fluid.CUDAPlace(gpu_id))
elif self.role_maker._get_heter_worker_device() == "XPU":
xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
executor = Executor(fluid.XPUPlace(xpu_id))
else:
raise ValueError("Not Support Device {}".format(
self.role_maker._get_heter_worker_device()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册