Commit 40d0fff2 authored by typhoonzero

single pserver workable version

Parent 2b47fb3d
@@ -69,43 +69,47 @@ class RecvOp : public framework::OperatorBase {
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     size_t param_count = param_list.size();
-    for (size_t i = 0; i < param_count; ++i) {
-      // blocking get one var from client.
-      const detail::TensorWithName &v = rpc_service_->Get();
-      auto grad_var_name = v.first;
-      auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
-      std::string param_var_name;
-      if (it != grad_list.end()) {
-        param_var_name = param_list[it - grad_list.begin()];
-      }
-      VLOG(10) << "recved grad: " << grad_var_name
-               << " updating param: " << param_var_name;
-      auto *var = recv_scope.Var(grad_var_name);
-      auto *tensor = var->GetMutable<framework::LoDTensor>();
-      // FIXME(typhoonzero): do not copy
-      framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
-    }
-
-    std::string program_str = Attr<std::string>("OptimizeProgram");
-    framework::ProgramDesc program_desc;
-    program_desc.ParseFromString(program_str);
-    framework::ProgramDescBind program(program_desc);
-    framework::Executor executor(dev_ctx);
-    // Run sub graph to get optimized tensor
-    try {
-      executor.Run(program, &recv_scope, 0, /*global_block*/
-                   false /*create_local_scope*/, false /*create_vars*/);
-    } catch (std::exception &e) {
-      LOG(ERROR) << "run sub program error " << e.what();
-    }
-
-    for (size_t i = 0; i < param_count; ++i) {
-      auto *out_var = recv_scope.FindVar(param_list[i]);
-      detail::TensorWithName out;
-      out.first = param_list[i];
-      out.second = out_var->Get<framework::LoDTensor>();
-      rpc_service_->Push(out);
-    }
+    // TODO(typhoonzero): change this to a while_op for every cluster-batch.
+    while (true) {
+      // TODO(typhoonzero): get from multiple trainers.
+      for (size_t i = 0; i < param_count; ++i) {
+        // blocking get one var from client.
+        const detail::TensorWithName &v = rpc_service_->Get();
+        auto grad_var_name = v.first;
+        auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
+        std::string param_var_name;
+        if (it != grad_list.end()) {
+          param_var_name = param_list[it - grad_list.begin()];
+        }
+        VLOG(10) << "recved grad: " << grad_var_name
+                 << " updating param: " << param_var_name;
+        auto *var = recv_scope.Var(grad_var_name);
+        auto *tensor = var->GetMutable<framework::LoDTensor>();
+        // FIXME(typhoonzero): do not copy
+        framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
+      }
+
+      std::string program_str = Attr<std::string>("OptimizeProgram");
+      framework::ProgramDesc program_desc;
+      program_desc.ParseFromString(program_str);
+      framework::ProgramDescBind program(program_desc);
+      framework::Executor executor(dev_ctx);
+      // Run sub graph to get optimized tensor
+      try {
+        executor.Run(program, &recv_scope, 0, /*global_block*/
+                     false /*create_local_scope*/, false /*create_vars*/);
+      } catch (std::exception &e) {
+        LOG(ERROR) << "run sub program error " << e.what();
+      }
+
+      for (size_t i = 0; i < param_count; ++i) {
+        auto *out_var = recv_scope.FindVar(param_list[i]);
+        detail::TensorWithName out;
+        out.first = param_list[i];
+        out.second = out_var->Get<framework::LoDTensor>();
+        rpc_service_->Push(out);
+      }
+    }  // while(true)
   }
 
  protected:
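Note on the C++ change: the server loop is now the classic parameter-server cycle — block until one gradient per parameter has arrived, run the deserialized OptimizeProgram in recv_scope, then push the updated parameters back to the trainer. Below is a minimal, self-contained sketch of that cycle, not PaddlePaddle code: the queue-based rpc_recv/rpc_send stand-ins and the toy sgd_update are assumptions made only for illustration.

```python
import queue

def sgd_update(params, grads, lr=0.01):
    """Toy optimizer standing in for the serialized OptimizeProgram."""
    return {name: value - lr * grads[name] for name, value in params.items()}

def pserver_loop(rpc_recv, rpc_send, param_list, grad_list, num_batches):
    """Models RecvOp::Run: receive one grad per parameter, optimize, push."""
    grad_to_param = dict(zip(grad_list, param_list))
    params = {name: 0.0 for name in param_list}
    for _ in range(num_batches):          # stands in for `while (true)`
        grads = {}
        for _ in range(len(param_list)):  # blocking get, one var at a time
            grad_name, value = rpc_recv.get()
            grads[grad_to_param[grad_name]] = value
        params = sgd_update(params, grads)
        for name in param_list:           # analogous to rpc_service_->Push(out)
            rpc_send.put((name, params[name]))
    return params

# One "trainer" feeding gradients through in-process queues.
recv_q, send_q = queue.Queue(), queue.Queue()
for pair in [("w@GRAD", 1.0), ("b@GRAD", -2.0)]:
    recv_q.put(pair)
final = pserver_loop(recv_q, send_q, ["w", "b"], ["w@GRAD", "b@GRAD"],
                     num_batches=1)
print(final)         # {'w': -0.01, 'b': 0.02}
print(send_q.get())  # ('w', -0.01)
```

The bounded num_batches replaces the unbounded `while (true)` only so the sketch terminates; as the TODOs note, the real loop is meant to become a while_op and to merge gradients from multiple trainers.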
@@ -93,7 +93,7 @@ class Executor(object):
                 dtype=var.dtype,
                 type=var.type,
                 lod_level=var.lod_level,
-                persistable=True)
+                persistable=var.persistable)
 
     def _optimize_distributed(self, optimize_ops, program, params_and_grads,
                               **kwargs):
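Note on the Python change: when the executor clones variables into the pserver-side program, the old code hard-coded persistable=True on every clone, which would also pin scratch variables in the scope across runs; the fix carries the source variable's flag over. A toy illustration of the difference, using a plain dataclass rather than the framework's variable type:

```python
from dataclasses import dataclass

@dataclass
class Var:
    name: str
    persistable: bool  # parameters: True; intermediate results: False

def clone_var_old(var):
    # Pre-fix behavior: every clone is forced to be persistable.
    return Var(var.name, persistable=True)

def clone_var_fixed(var):
    # Post-fix behavior: the source flag is preserved.
    return Var(var.name, persistable=var.persistable)

tmp = Var("tmp_3", persistable=False)    # a scratch tensor
print(clone_var_old(tmp).persistable)    # True  -> kept alive across runs
print(clone_var_fixed(tmp).persistable)  # False -> can be freed per iteration
```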