Commit 40d0fff2 authored by: T typhoonzero

single pserver workable version

Parent commit: 2b47fb3d
...@@ -69,6 +69,9 @@ class RecvOp : public framework::OperatorBase { ...@@ -69,6 +69,9 @@ class RecvOp : public framework::OperatorBase {
auto param_list = Attr<std::vector<std::string>>("ParamList"); auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList"); auto grad_list = Attr<std::vector<std::string>>("GradList");
size_t param_count = param_list.size(); size_t param_count = param_list.size();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
while (true) {
// TODO(typhoonzero): get from multiple trainers.
for (size_t i = 0; i < param_count; ++i) { for (size_t i = 0; i < param_count; ++i) {
// blocking get one var from client. // blocking get one var from client.
const detail::TensorWithName &v = rpc_service_->Get(); const detail::TensorWithName &v = rpc_service_->Get();
...@@ -106,6 +109,7 @@ class RecvOp : public framework::OperatorBase { ...@@ -106,6 +109,7 @@ class RecvOp : public framework::OperatorBase {
out.second = out_var->Get<framework::LoDTensor>(); out.second = out_var->Get<framework::LoDTensor>();
rpc_service_->Push(out); rpc_service_->Push(out);
} }
} // while(true)
} }
protected: protected:
......
...@@ -93,7 +93,7 @@ class Executor(object): ...@@ -93,7 +93,7 @@ class Executor(object):
dtype=var.dtype, dtype=var.dtype,
type=var.type, type=var.type,
lod_level=var.lod_level, lod_level=var.lod_level,
persistable=True) persistable=var.persistable)
def _optimize_distributed(self, optimize_ops, program, params_and_grads, def _optimize_distributed(self, optimize_ops, program, params_and_grads,
**kwargs): **kwargs):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register