未验证 提交 2c945737 编写于 作者: S ShenLiang 提交者: GitHub

Add wait_server_ready for dygraph parallel (#34207)

* add wait_server_ready

* fix remove bug
上级 8b59f5e0
@@ -40,6 +40,7 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
 platform::errors::InvalidArgument("npu only accept LoDTensor"));
 output_t->mutable_data<T>(ctx.GetPlace());
+// add copy ids to ensure ids_t is prepared.
 std::vector<int> ids;
 TensorToVector(*ids_t, ctx.device_context(), &ids);
......
@@ -193,6 +193,12 @@ def init_parallel_env():
 elif core.is_compiled_with_xpu():
 parallel_helper._set_parallel_ctx(
 core.BKCLParallelContext(strategy, place))
+other_endpoints = strategy.trainer_endpoints[:]
+other_endpoints.remove(strategy.current_endpoint)
+if strategy.local_rank == 0:
+wait_server_ready(other_endpoints)
 parallel_helper._init_parallel_ctx()
 # 5: init gloo context (step 2: gloo init)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册