Commit 40d0fff2 authored by typhoonzero

single pserver workable version

Parent 2b47fb3d
@@ -69,43 +69,47 @@ class RecvOp : public framework::OperatorBase {
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     size_t param_count = param_list.size();
-    for (size_t i = 0; i < param_count; ++i) {
-      // blocking get one var from client.
-      const detail::TensorWithName &v = rpc_service_->Get();
-      auto grad_var_name = v.first;
-      auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
-      std::string param_var_name;
-      if (it != grad_list.end()) {
-        param_var_name = param_list[it - grad_list.begin()];
-      }
-      VLOG(10) << "recved grad: " << grad_var_name
-               << " updating param: " << param_var_name;
-      auto *var = recv_scope.Var(grad_var_name);
-      auto *tensor = var->GetMutable<framework::LoDTensor>();
-      // FIXME(typhoonzero): do not copy
-      framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
-    }
-
-    std::string program_str = Attr<std::string>("OptimizeProgram");
-    framework::ProgramDesc program_desc;
-    program_desc.ParseFromString(program_str);
-    framework::ProgramDescBind program(program_desc);
-    framework::Executor executor(dev_ctx);
-    // Run sub graph to get optimized tensor
-    try {
-      executor.Run(program, &recv_scope, 0, /*global_block*/
-                   false /*create_local_scope*/, false /*create_vars*/);
-    } catch (std::exception &e) {
-      LOG(ERROR) << "run sub program error " << e.what();
-    }
-
-    for (size_t i = 0; i < param_count; ++i) {
-      auto *out_var = recv_scope.FindVar(param_list[i]);
-      detail::TensorWithName out;
-      out.first = param_list[i];
-      out.second = out_var->Get<framework::LoDTensor>();
-      rpc_service_->Push(out);
-    }
+    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
+    while (true) {
+      // TODO(typhoonzero): get from multiple trainers.
+      for (size_t i = 0; i < param_count; ++i) {
+        // blocking get one var from client.
+        const detail::TensorWithName &v = rpc_service_->Get();
+        auto grad_var_name = v.first;
+        auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
+        std::string param_var_name;
+        if (it != grad_list.end()) {
+          param_var_name = param_list[it - grad_list.begin()];
+        }
+        VLOG(10) << "recved grad: " << grad_var_name
+                 << " updating param: " << param_var_name;
+        auto *var = recv_scope.Var(grad_var_name);
+        auto *tensor = var->GetMutable<framework::LoDTensor>();
+        // FIXME(typhoonzero): do not copy
+        framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
+      }
+
+      std::string program_str = Attr<std::string>("OptimizeProgram");
+      framework::ProgramDesc program_desc;
+      program_desc.ParseFromString(program_str);
+      framework::ProgramDescBind program(program_desc);
+      framework::Executor executor(dev_ctx);
+      // Run sub graph to get optimized tensor
+      try {
+        executor.Run(program, &recv_scope, 0, /*global_block*/
+                     false /*create_local_scope*/, false /*create_vars*/);
+      } catch (std::exception &e) {
+        LOG(ERROR) << "run sub program error " << e.what();
+      }
+
+      for (size_t i = 0; i < param_count; ++i) {
+        auto *out_var = recv_scope.FindVar(param_list[i]);
+        detail::TensorWithName out;
+        out.first = param_list[i];
+        out.second = out_var->Get<framework::LoDTensor>();
+        rpc_service_->Push(out);
+      }
+    }  // while(true)
   }
 
  protected:
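For context, the server side of this change is essentially a blocking producer/consumer hand-off: rpc_service_->Get() blocks until a trainer has sent one gradient tensor, and rpc_service_->Push() hands the updated parameter back to the RPC layer so the trainer's send can complete. The sketch below illustrates only that hand-off pattern using standard C++ threads; the BlockingQueue class and the string payloads are hypothetical stand-ins for this repository's detail::TensorWithName and its gRPC service, not the actual implementation.

// Minimal sketch of the blocking Get()/Push() hand-off the recv op relies on.
// BlockingQueue is an illustrative stand-in for the RPC service's internal queue.
#include <condition_variable>
#include <deque>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <utility>

template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      q_.push_back(std::move(v));
    }
    cv_.notify_one();
  }
  // Blocks until an element is available, like rpc_service_->Get().
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop_front();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> q_;
};

// (name, payload) pair; the payload is a placeholder for a serialized tensor.
using TensorWithName = std::pair<std::string, std::string>;

int main() {
  BlockingQueue<TensorWithName> grads;   // trainer -> pserver
  BlockingQueue<TensorWithName> params;  // pserver -> trainer

  // "Trainer" thread: send one gradient, then wait for the updated parameter.
  std::thread trainer([&] {
    grads.Push({"fc_0.w@GRAD", "grad-bytes"});
    TensorWithName updated = params.Pop();
    std::cout << "trainer got updated param: " << updated.first << "\n";
  });

  // "Pserver" loop body for one cluster-batch: blocking-get the gradient,
  // run the optimize step (elided here), then push the parameter back.
  TensorWithName g = grads.Pop();
  std::cout << "pserver received grad: " << g.first << "\n";
  params.Push({"fc_0.w", "updated-param-bytes"});

  trainer.join();
  return 0;
}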
@@ -93,7 +93,7 @@ class Executor(object):
             dtype=var.dtype,
             type=var.type,
             lod_level=var.lod_level,
-            persistable=True)
+            persistable=var.persistable)
 
     def _optimize_distributed(self, optimize_ops, program, params_and_grads,
                               **kwargs):