提交 fc06222a 编写于 作者: Y Yancey1989

fix async worker

上级 540b4535
......@@ -37,6 +37,7 @@ class SendBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
......@@ -51,13 +52,14 @@ class SendBarrierOp : public framework::OperatorBase {
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
}
};
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -77,6 +79,7 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
}
};
......
......@@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase):
def test_transpiler(self):
trainer = self.get_trainer()
pserver, startup = self.get_pserver(self.current_pserver_ep)
self.assertEqual([op.type for op in trainer.global_block().ops],
self.get_expect_trainer_ops())
......@@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase):
optimize_ops, params_grads = self.net_conf()
delete_ops(trainer.global_block(), optimize_ops)
return [op.type for op in trainer.global_block().ops
] + ["split_byref", "send", "concat"]
ops = [op.type for op in trainer.global_block().ops] + [
"split_byref", "send_vars", "send_barrier", "recv", "recv",
"fetch_barrier", "concat"
]
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
return ops
def get_trainer(self):
return self._transpiler_instance().get_trainer_program()
......
......@@ -348,7 +348,10 @@ class DistributeTranspiler:
type="send_barrier",
inputs={},
outputs={"RPCClient": rpc_client_var},
attrs={"endpoints": pserver_endpoints})
attrs={
"endpoints": pserver_endpoints,
"sync_mode": self.sync_mode
})
# step 3.2: insert recv op to receive parameters from parameter server
recv_vars = []
......
......@@ -15,7 +15,7 @@
class PSDispatcher(object):
"""
DistributedSpliter is the base class for dispatching vars
PSDispatcher is the base class for dispatching vars
into different pserver instance.
You need to implement the `dispatch` inferface.
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册