提交 fc06222a 编写于 作者: Y Yancey1989

fix async worker

上级 540b4535
...@@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints"); std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
...@@ -51,13 +52,14 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -51,13 +52,14 @@ class SendBarrierOp : public framework::OperatorBase {
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep; VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
} }
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
} }
}
}; };
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -77,6 +79,7 @@ the Parameter Server would knew all variables have been sent. ...@@ -77,6 +79,7 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.") "Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
} }
}; };
......
...@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase): ...@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase):
def test_transpiler(self): def test_transpiler(self):
trainer = self.get_trainer() trainer = self.get_trainer()
pserver, startup = self.get_pserver(self.current_pserver_ep) pserver, startup = self.get_pserver(self.current_pserver_ep)
self.assertEqual([op.type for op in trainer.global_block().ops], self.assertEqual([op.type for op in trainer.global_block().ops],
self.get_expect_trainer_ops()) self.get_expect_trainer_ops())
...@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase): ...@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase):
optimize_ops, params_grads = self.net_conf() optimize_ops, params_grads = self.net_conf()
delete_ops(trainer.global_block(), optimize_ops) delete_ops(trainer.global_block(), optimize_ops)
return [op.type for op in trainer.global_block().ops ops = [op.type for op in trainer.global_block().ops] + [
] + ["split_byref", "send", "concat"] "split_byref", "send_vars", "send_barrier", "recv", "recv",
"fetch_barrier", "concat"
]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
return ops
def get_trainer(self): def get_trainer(self):
return self._transpiler_instance().get_trainer_program() return self._transpiler_instance().get_trainer_program()
......
...@@ -348,7 +348,10 @@ class DistributeTranspiler: ...@@ -348,7 +348,10 @@ class DistributeTranspiler:
type="send_barrier", type="send_barrier",
inputs={}, inputs={},
outputs={"RPCClient": rpc_client_var}, outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints}) attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode
})
# step 3.2: insert recv op to receive parameters from parameter server # step 3.2: insert recv op to receive parameters from parameter server
recv_vars = [] recv_vars = []
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
class PSDispatcher(object): class PSDispatcher(object):
""" """
DistributedSpliter is the base class for dispatching vars PSDispatcher is the base class for dispatching vars
into different pserver instance. into different pserver instance.
You need to implement the `dispatch` inferface. You need to implement the `dispatch` inferface.
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册