From fc06222ae91990d6eaece2c9895b869742000eae Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 24 May 2018 12:26:50 +0800 Subject: [PATCH] fix async worker --- paddle/fluid/operators/send_barrier_op.cc | 13 ++++++++----- .../fluid/tests/unittests/test_dist_transpiler.py | 11 +++++++---- .../fluid/transpiler/distribute_transpiler.py | 5 ++++- python/paddle/fluid/transpiler/ps_dispatcher.py | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 05e26236309..354eb4fa139 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { std::vector eps = Attr>("endpoints"); + bool sync_mode = Attr("sync_mode"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -51,12 +52,13 @@ class SendBarrierOp : public framework::OperatorBase { // need to wait before sending send_barrier message PADDLE_ENFORCE(rpc_client->Wait()); - - for (auto& ep : eps) { - VLOG(3) << "send barrier, ep: " << ep; - rpc_client->AsyncSendBatchBarrier(ep); + if (sync_mode) { + for (auto& ep : eps) { + VLOG(3) << "send barrier, ep: " << ep; + rpc_client->AsyncSendBatchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(rpc_client->Wait()); } }; @@ -77,6 +79,7 @@ the Parameter Server would knew all variables have been sent. "(string vector, default 127.0.0.1:6164)" "Server endpoints to send variables to.") .SetDefault({"127.0.0.1:6164"}); + AddAttr("sync_mode", "work in sync_mode or not").SetDefault(true); } }; diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 10f8c4f3f01..fa49bd41a58 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase): def test_transpiler(self): trainer = self.get_trainer() pserver, startup = self.get_pserver(self.current_pserver_ep) - self.assertEqual([op.type for op in trainer.global_block().ops], self.get_expect_trainer_ops()) @@ -67,7 +66,7 @@ class TestDistTranspiler(unittest.TestCase): "fill_constant", "fill_constant", "uniform_random", "uniform_random" ]) - # the variable #fc_w will be split into two blocks + # the variable #fc_w will be split into two blocks fc_w_var = startup.global_block().var("fc_w.block1") self.assertEqual(fc_w_var.shape, (500, 1000)) @@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase): optimize_ops, params_grads = self.net_conf() delete_ops(trainer.global_block(), optimize_ops) - return [op.type for op in trainer.global_block().ops - ] + ["split_byref", "send", "concat"] + ops = [op.type for op in trainer.global_block().ops] + [ + "split_byref", "send_vars", "send_barrier", "recv", "recv", + "fetch_barrier", "concat" + ] + ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars") + return ops def get_trainer(self): return self._transpiler_instance().get_trainer_program() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 848cb0bd6c7..72a02f24a33 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -348,7 +348,10 @@ class DistributeTranspiler: type="send_barrier", inputs={}, outputs={"RPCClient": rpc_client_var}, - attrs={"endpoints": pserver_endpoints}) + attrs={ + "endpoints": pserver_endpoints, + "sync_mode": self.sync_mode + }) # step 3.2: insert recv op to receive parameters from parameter server recv_vars = [] diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py index dffe66998a4..9ba3bf82161 100644 --- a/python/paddle/fluid/transpiler/ps_dispatcher.py +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -15,7 +15,7 @@ class PSDispatcher(object): """ - DistributedSpliter is the base class for dispatching vars + PSDispatcher is the base class for dispatching vars into different pserver instance. You need to implement the `dispatch` inferface. """ -- GitLab