未验证 提交 eeaf04da 编写于 作者: C Chengmo 提交者: GitHub

[cherry-pick]Fix communicator slow bug & fix communicator stop bug (#20366) (#20646)

* Fix communicator slow bug & fix communicator stop bug (#20366)

* test=develop,Fix communicator slow bug

* test=develop, delete if() in stop_worker()

* test=develop

* fix UT, test=develop

* fix bug in fetch handler, test=develop

* fix bug in fetch handler, test=develop

* test=develop, fix fetch barrier bug

* test=develop, bug fix

* test=develop, bug fix

* test=develop, fix bug

* test=develop,test=release/1.6
上级 965b45e8
...@@ -144,6 +144,10 @@ void DistMultiTrainer::Run() { ...@@ -144,6 +144,10 @@ void DistMultiTrainer::Run() {
} }
} }
Scope *DistMultiTrainer::GetWorkerScope(int thread_id) {
return workers_[thread_id]->GetThreadScope();
}
void DistMultiTrainer::Finalize() { void DistMultiTrainer::Finalize() {
for (auto &th : threads_) { for (auto &th : threads_) {
th.join(); th.join();
...@@ -199,5 +203,5 @@ void DistMultiTrainer::MergeToRootScope(LoDTensor *root_tensor, ...@@ -199,5 +203,5 @@ void DistMultiTrainer::MergeToRootScope(LoDTensor *root_tensor,
root_data[i] += data[i]; root_data[i] += data[i];
} }
} }
} // end namespace framework } // namespace framework
} // end namespace paddle } // namespace paddle
...@@ -93,8 +93,8 @@ class DistMultiTrainer : public MultiTrainer { ...@@ -93,8 +93,8 @@ class DistMultiTrainer : public MultiTrainer {
void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor); void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
virtual void FinalizeDumpEnv(); virtual void FinalizeDumpEnv();
virtual void InitDumpEnv(); virtual void InitDumpEnv();
virtual Scope* GetWorkerScope(int thread_id);
virtual void DumpWork(int tid); virtual void DumpWork(int tid);
virtual Scope* GetWorkerScope(int thread_id) { return root_scope_; }
protected: protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_; std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
......
...@@ -923,6 +923,7 @@ void GeoSgdCommunicator::RpcSend(const std::string &origin_var_name, ...@@ -923,6 +923,7 @@ void GeoSgdCommunicator::RpcSend(const std::string &origin_var_name,
auto &cpu_ctx_send = *pool.Get(platform::CPUPlace()); auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
rpc_client->AsyncSendVar(endpoint, cpu_ctx_send, *delta_scope_.get(), rpc_client->AsyncSendVar(endpoint, cpu_ctx_send, *delta_scope_.get(),
splited_var_name); splited_var_name);
} }
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
// default to 3min to avoid temprary network failures. // default to 3min to avoid temprary network failures.
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc"); DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
DEFINE_int32(rpc_retry_times, 3, "retry times for rpc"); DEFINE_int32(rpc_retry_times, 0, "retry times for rpc");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -55,6 +55,9 @@ class FetchBarrierOp : public framework::OperatorBase { ...@@ -55,6 +55,9 @@ class FetchBarrierOp : public framework::OperatorBase {
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(Any) Dummy inputs, used for control dependency")
.AsDispensable()
.AsDuplicable();
AddOutput("Out", "(Any) Dummy outputs, used for control dependency") AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable(); .AsDuplicable();
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -129,8 +129,7 @@ class DistributedTranspiler(Fleet): ...@@ -129,8 +129,7 @@ class DistributedTranspiler(Fleet):
Returns: Returns:
None None
""" """
if not self._transpile_config.sync_mode and self._communicator.is_running( if not self._transpile_config.sync_mode:
):
self._communicator.stop() self._communicator.stop()
self._executor.close() self._executor.close()
if isinstance(self._role_maker, MPISymetricRoleMaker): if isinstance(self._role_maker, MPISymetricRoleMaker):
......
...@@ -67,7 +67,7 @@ class DatasetCtrReader(data_generator.MultiSlotDataGenerator): ...@@ -67,7 +67,7 @@ class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
return random.random() return random.random()
def iter(): def iter():
if get_rand() < 0.1: if get_rand() < 0.05:
fs = line.strip().split('\t') fs = line.strip().split('\t')
dnn_input = load_dnn_input_record(fs[0]) dnn_input = load_dnn_input_record(fs[0])
lr_input = load_lr_input_record(fs[1]) lr_input = load_lr_input_record(fs[1])
......
...@@ -139,7 +139,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -139,7 +139,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
dataset.set_filelist(filelist) dataset.set_filelist(filelist)
dataset.set_thread(thread_num) dataset.set_thread(thread_num)
for epoch_id in range(2): for epoch_id in range(1):
pass_start = time.time() pass_start = time.time()
dataset.set_filelist(filelist) dataset.set_filelist(filelist)
exe.train_from_dataset( exe.train_from_dataset(
...@@ -157,7 +157,7 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -157,7 +157,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
print("{}: \n {}\n".format(self.fetch_target_names[0], print("{}: \n {}\n".format(self.fetch_target_names[0],
fetch_target_vars[0])) fetch_target_vars[0]))
for epoch_id in range(2): for epoch_id in range(1):
pass_start = time.time() pass_start = time.time()
dataset.set_filelist(filelist) dataset.set_filelist(filelist)
exe.train_from_dataset( exe.train_from_dataset(
......
...@@ -30,7 +30,6 @@ def skip_ci(func): ...@@ -30,7 +30,6 @@ def skip_ci(func):
return __func__ return __func__
@skip_ci
class TestDistMnist2x2(TestFleetBase): class TestDistMnist2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
......
...@@ -84,6 +84,9 @@ class FetchHandlerMonitor(object): ...@@ -84,6 +84,9 @@ class FetchHandlerMonitor(object):
for varname in fetch_target_names for varname in fetch_target_names
] ]
if None in fetch_vars:
continue
fetch_tensors = [var.get_tensor() for var in fetch_vars] fetch_tensors = [var.get_tensor() for var in fetch_vars]
if self.fetch_instance.return_np: if self.fetch_instance.return_np:
......
...@@ -701,6 +701,7 @@ class DistributeTranspiler(object): ...@@ -701,6 +701,7 @@ class DistributeTranspiler(object):
send_vars.append(var) send_vars.append(var)
if self.sync_mode: if self.sync_mode:
fetch_barrier_input = []
send_barrier_out = program.global_block().create_var( send_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
...@@ -718,6 +719,7 @@ class DistributeTranspiler(object): ...@@ -718,6 +719,7 @@ class DistributeTranspiler(object):
"trainer_id": self.trainer_id, "trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
fetch_barrier_input.append(send_barrier_out)
# step 3: insert recv op to receive parameters from parameter server # step 3: insert recv op to receive parameters from parameter server
recv_vars = [] recv_vars = []
...@@ -788,12 +790,14 @@ class DistributeTranspiler(object): ...@@ -788,12 +790,14 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME: OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name] [param_varname, recv_op_role_var_name]
}) })
if self.sync_mode:
fetch_barrier_input.extend(splited_var)
if self.sync_mode: if self.sync_mode:
# form a WAW dependency # form a WAW dependency
program.global_block().append_op( program.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
inputs={}, inputs={"X": fetch_barrier_input},
outputs={"Out": all_recv_outputs}, outputs={"Out": all_recv_outputs},
attrs={ attrs={
"endpoints": pserver_endpoints, "endpoints": pserver_endpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册