未验证 提交 4d649893 编写于 作者: Z ziyoujiyi 提交者: GitHub

bug fix (#43526)

* back fl

* delete ssl cert

* .

* make warning

* .

* unittest parallel degree

* solve unittest

* heter & multi cloud comm ready

* .

* .

* fl-ps v1.0

* .

* support N + N mode

* .

* .

* .

* .

* delete print

* .

* .

* .

* .

* fix bug

* .

* .
上级 6f1d2483
...@@ -333,5 +333,5 @@ Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) { ...@@ -333,5 +333,5 @@ Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) {
} }
} // end namespace framework } // end namespace framework
} // namespace paddle } // end namespace paddle
#endif #endif
...@@ -434,8 +434,8 @@ class DistributedOpsPass(PassBase): ...@@ -434,8 +434,8 @@ class DistributedOpsPass(PassBase):
if op.type in SPARSE_OP_TYPE_DICT.keys() \ if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True: and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0] param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
if attrs['is_heter_ps_mode']: if attrs['is_heter_ps_mode'] and not attrs['is_fl_ps_mode']:
# trick for matchnet, need to modify # TODO: trick for matchnet, need to modify for heter_ps
param_name += op.input("Ids")[0][0] param_name += op.input("Ids")[0][0]
ops = pull_sparse_ops.get(param_name, []) ops = pull_sparse_ops.get(param_name, [])
ops.append(op) ops.append(op)
......
...@@ -1015,14 +1015,8 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1015,14 +1015,8 @@ class TheOnePSRuntime(RuntimeBase):
is_test = bool(int(os.getenv("TEST_MODE", "0"))) is_test = bool(int(os.getenv("TEST_MODE", "0")))
# for GEO # for GEO & heter_ps
if self.role_maker._is_first_worker() and self.is_heter_ps_mode: init_params = dense_map
# for ps-heter mode load all parameters on first_worker
init_params = get_the_one_recv_context(self.context,
split_dense_table=True,
use_origin_program=True)
else:
init_params = dense_map
# if not is_test: # if not is_test:
# self._communicator.init_params(init_params) # self._communicator.init_params(init_params)
...@@ -1053,11 +1047,7 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -1053,11 +1047,7 @@ class TheOnePSRuntime(RuntimeBase):
fleet.util.barrier() # 保证 0 号 worker 参数 push_dense_param over fleet.util.barrier() # 保证 0 号 worker 参数 push_dense_param over
if not self.context['use_ps_gpu']: if not self.context['use_ps_gpu']:
if self.is_heter_ps_mode == True and not self.role_maker._is_first_worker( self._pull_all_dense(scopes, send_ctx, dense_map)
):
self._communicator.pull_dense(init_params)
else:
self._pull_all_dense(scopes, send_ctx, dense_map)
fleet.util.barrier() fleet.util.barrier()
if self.context[ if self.context[
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册