diff --git a/paddle/fluid/operators/send_barrier_op.cc b/paddle/fluid/operators/send_barrier_op.cc index 05e262363095d0914d057ace353e32c6a6702413..354eb4fa13913eb6ec01885cf411627bf8cfa61c 100644 --- a/paddle/fluid/operators/send_barrier_op.cc +++ b/paddle/fluid/operators/send_barrier_op.cc @@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { std::vector eps = Attr>("endpoints"); + bool sync_mode = Attr("sync_mode"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -51,12 +52,13 @@ class SendBarrierOp : public framework::OperatorBase { // need to wait before sending send_barrier message PADDLE_ENFORCE(rpc_client->Wait()); - - for (auto& ep : eps) { - VLOG(3) << "send barrier, ep: " << ep; - rpc_client->AsyncSendBatchBarrier(ep); + if (sync_mode) { + for (auto& ep : eps) { + VLOG(3) << "send barrier, ep: " << ep; + rpc_client->AsyncSendBatchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(rpc_client->Wait()); } }; @@ -77,6 +79,7 @@ the Parameter Server would knew all variables have been sent. "(string vector, default 127.0.0.1:6164)" "Server endpoints to send variables to.") .SetDefault({"127.0.0.1:6164"}); + AddAttr("sync_mode", "work in sync_mode or not").SetDefault(true); } }; diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py index 10f8c4f3f0167632bb4a3d454ab026ba73a8f305..fa49bd41a5876847d046682dce5c3d3868a18500 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py +++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py @@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase): def test_transpiler(self): trainer = self.get_trainer() pserver, startup = self.get_pserver(self.current_pserver_ep) - self.assertEqual([op.type for op in trainer.global_block().ops], self.get_expect_trainer_ops()) @@ -67,7 +66,7 @@ class TestDistTranspiler(unittest.TestCase): "fill_constant", "fill_constant", "uniform_random", "uniform_random" ]) - # the variable #fc_w will be split into two blocks + # the variable #fc_w will be split into two blocks fc_w_var = startup.global_block().var("fc_w.block1") self.assertEqual(fc_w_var.shape, (500, 1000)) @@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase): optimize_ops, params_grads = self.net_conf() delete_ops(trainer.global_block(), optimize_ops) - return [op.type for op in trainer.global_block().ops - ] + ["split_byref", "send", "concat"] + ops = [op.type for op in trainer.global_block().ops] + [ + "split_byref", "send_vars", "send_barrier", "recv", "recv", + "fetch_barrier", "concat" + ] + ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars") + return ops def get_trainer(self): return self._transpiler_instance().get_trainer_program() diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 848cb0bd6c71895e6e44343f394f6415ec3d8acc..72a02f24a339ba7d36dbf58a0479e4b4e681cab3 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -348,7 +348,10 @@ class DistributeTranspiler: type="send_barrier", inputs={}, outputs={"RPCClient": rpc_client_var}, - attrs={"endpoints": pserver_endpoints}) + attrs={ + "endpoints": pserver_endpoints, + "sync_mode": self.sync_mode + }) # step 3.2: insert recv op to receive parameters from parameter server recv_vars = [] diff --git a/python/paddle/fluid/transpiler/ps_dispatcher.py b/python/paddle/fluid/transpiler/ps_dispatcher.py index dffe66998a4e89c89df2395d114b0fefab850606..9ba3bf82161c2f105f61e87239c6f3f5477f515d 100644 --- a/python/paddle/fluid/transpiler/ps_dispatcher.py +++ b/python/paddle/fluid/transpiler/ps_dispatcher.py @@ -15,7 +15,7 @@ class PSDispatcher(object): """ - DistributedSpliter is the base class for dispatching vars + PSDispatcher is the base class for dispatching vars into different pserver instance. You need to implement the `dispatch` inferface. """