未验证 提交 940c6ff1 编写于 作者: C Chengmo 提交者: GitHub

Fix communicator slow bug & fix communicator stop bug (#20366)

* test=develop,Fix communicator slow bug

* test=develop, delete if() in stop_worker()

* test=develop

* fix UT, test=develop

* fix bug in fetch handler, test=develop

* fix bug in fetch handler, test=develop

* test=develop, fix fetch barrier bug

* test=develop, bug fix

* test=develop, bug fix

* test=develop, fix bug
上级 69e0b98f
......@@ -144,6 +144,10 @@ void DistMultiTrainer::Run() {
}
}
Scope *DistMultiTrainer::GetWorkerScope(int thread_id) {
return workers_[thread_id]->GetThreadScope();
}
void DistMultiTrainer::Finalize() {
for (auto &th : threads_) {
th.join();
......@@ -199,5 +203,5 @@ void DistMultiTrainer::MergeToRootScope(LoDTensor *root_tensor,
root_data[i] += data[i];
}
}
} // end namespace framework
} // end namespace paddle
} // namespace framework
} // namespace paddle
......@@ -93,8 +93,8 @@ class DistMultiTrainer : public MultiTrainer {
void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
virtual void FinalizeDumpEnv();
virtual void InitDumpEnv();
virtual Scope* GetWorkerScope(int thread_id);
virtual void DumpWork(int tid);
virtual Scope* GetWorkerScope(int thread_id) { return root_scope_; }
protected:
std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
......
......@@ -17,7 +17,7 @@
// default to 3min to avoid temprary network failures.
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
DEFINE_int32(rpc_retry_times, 3, "retry times for rpc");
DEFINE_int32(rpc_retry_times, 0, "retry times for rpc");
namespace paddle {
namespace operators {
......
......@@ -55,6 +55,9 @@ class FetchBarrierOp : public framework::OperatorBase {
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "(Any) Dummy inputs, used for control dependency")
.AsDispensable()
.AsDuplicable();
AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
.AsDuplicable();
AddComment(R"DOC(
......
......@@ -129,8 +129,7 @@ class DistributedTranspiler(Fleet):
Returns:
None
"""
if not self._transpile_config.sync_mode and self._communicator.is_running(
):
if not self._transpile_config.sync_mode:
self._communicator.stop()
self._executor.close()
if isinstance(self._role_maker, MPISymetricRoleMaker):
......
......@@ -67,7 +67,7 @@ class DatasetCtrReader(data_generator.MultiSlotDataGenerator):
return random.random()
def iter():
if get_rand() < 0.1:
if get_rand() < 0.05:
fs = line.strip().split('\t')
dnn_input = load_dnn_input_record(fs[0])
lr_input = load_lr_input_record(fs[1])
......
......@@ -139,7 +139,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
dataset.set_filelist(filelist)
dataset.set_thread(thread_num)
for epoch_id in range(2):
for epoch_id in range(1):
pass_start = time.time()
dataset.set_filelist(filelist)
exe.train_from_dataset(
......@@ -157,7 +157,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
print("{}: \n {}\n".format(self.fetch_target_names[0],
fetch_target_vars[0]))
for epoch_id in range(2):
for epoch_id in range(1):
pass_start = time.time()
dataset.set_filelist(filelist)
exe.train_from_dataset(
......
......@@ -30,7 +30,6 @@ def skip_ci(func):
return __func__
@skip_ci
class TestDistMnist2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False
......
......@@ -84,6 +84,9 @@ class FetchHandlerMonitor(object):
for varname in fetch_target_names
]
if None in fetch_vars:
continue
fetch_tensors = [var.get_tensor() for var in fetch_vars]
if self.fetch_instance.return_np:
......
......@@ -701,6 +701,7 @@ class DistributeTranspiler(object):
send_vars.append(var)
if self.sync_mode:
fetch_barrier_input = []
send_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
if self.has_distributed_lookup_table:
......@@ -718,6 +719,7 @@ class DistributeTranspiler(object):
"trainer_id": self.trainer_id,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
fetch_barrier_input.append(send_barrier_out)
# step 3: insert recv op to receive parameters from parameter server
recv_vars = []
......@@ -788,12 +790,14 @@ class DistributeTranspiler(object):
OP_ROLE_VAR_ATTR_NAME:
[param_varname, recv_op_role_var_name]
})
if self.sync_mode:
fetch_barrier_input.extend(splited_var)
if self.sync_mode:
# form a WAW dependency
program.global_block().append_op(
type="fetch_barrier",
inputs={},
inputs={"X": fetch_barrier_input},
outputs={"Out": all_recv_outputs},
attrs={
"endpoints": pserver_endpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册