Commit b1b7af40 authored by typhoonzero

support multi node

Parent 7be79231
@@ -51,19 +51,23 @@ Status SendRecvServerImpl::GetVariable(ServerContext *context,
 Status SendRecvServerImpl::Wait(ServerContext *context,
                                 const VoidMessage *in_var,
                                 VoidMessage *out_var) {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  condition_.wait(lock, [=] { return this->done_ == true; });
+  {
+    std::unique_lock<std::mutex> lock(this->mutex_);
+    condition_.wait(lock, [=] { return this->done_ == true; });
+  }
   return Status::OK;
 }
 
 void SendRecvServerImpl::Start() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
+  std::lock_guard<std::mutex> lock(this->mutex_);
   done_ = false;
 }
 
 void SendRecvServerImpl::Done() {
-  std::unique_lock<std::mutex> lock(this->mutex_);
-  done_ = true;
+  {
+    std::lock_guard<std::mutex> lock(this->mutex_);
+    done_ = true;
+  }
   condition_.notify_all();
 }
...
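The Wait/Start/Done trio above acts as a simple barrier between the parameter server and the trainers: Start() clears a done_ flag, trainers block in Wait() until the flag is set, and Done() sets the flag and wakes everyone. Below is a minimal self-contained sketch of the same pattern; the member names mirror the diff, but the wrapping class is hypothetical and only for illustration.

    #include <condition_variable>
    #include <mutex>

    class Barrier {
     public:
      // Re-arm the barrier before a new cluster-batch.
      void Start() {
        std::lock_guard<std::mutex> lock(mutex_);
        done_ = false;
      }
      // Block until Done() is called.
      void Wait() {
        std::unique_lock<std::mutex> lock(mutex_);
        condition_.wait(lock, [this] { return done_; });
      }
      // Release all waiters; notify after the lock is dropped,
      // as the updated Done() in the diff does.
      void Done() {
        {
          std::lock_guard<std::mutex> lock(mutex_);
          done_ = true;
        }
        condition_.notify_all();
      }

     private:
      bool done_ = false;
      std::mutex mutex_;
      std::condition_variable condition_;
    };

Scoping the lock in Wait() and notifying outside the lock in Done() avoids waking a waiter only to have it block again on a mutex the notifier still holds.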
@@ -14,7 +14,6 @@
 #include <stdint.h>
 #include <sys/stat.h>
-#include <iostream>
 #include <ostream>
 #include <thread>
@@ -81,9 +80,9 @@ class RecvOp : public framework::OperatorBase {
     auto grad_list = Attr<std::vector<std::string>>("GradList");
     auto trainer_count = Attr<int>("Trainers");
     size_t param_count = param_list.size();
+    rpc_service_->Start();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     while (true) {
-      rpc_service_->Start();
       // Get from multiple trainers, we don't care about order in which
       // the gradient arrives, just add suffix 0~n then average the gradient.
       for (size_t i = 0; i < param_count * trainer_count; ++i) {
@@ -95,7 +94,7 @@ class RecvOp : public framework::OperatorBase {
         if (it != grad_list.end()) {
           param_var_name = param_list[it - grad_list.begin()];
         }
-        VLOG(10) << "recved grad: " << grad_var_name
+        VLOG(3) << "recved grad: " << grad_var_name
                 << " updating param: " << param_var_name;
         auto *merged_grad = recv_scope.FindVar(grad_var_name);
         if (merged_grad == nullptr) {
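For context, the loop body pairs each received gradient with its parameter by position: GradList and ParamList are parallel attributes, so the index of grad_var_name in grad_list selects the matching entry of param_list. A minimal sketch of that lookup as a free-standing function (the function name is illustrative, not part of the op):

    #include <algorithm>
    #include <string>
    #include <vector>

    // Returns the parameter paired with `grad_name`, or "" if it is unknown.
    // grad_list[i] is assumed to correspond to param_list[i], as in the op above.
    std::string LookupParam(const std::vector<std::string> &param_list,
                            const std::vector<std::string> &grad_list,
                            const std::string &grad_name) {
      auto it = std::find(grad_list.begin(), grad_list.end(), grad_name);
      if (it == grad_list.end()) return "";
      return param_list[it - grad_list.begin()];
    }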
@@ -113,6 +112,7 @@
         // FIXME(typhoonzero): do not copy
         framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
       }
+      rpc_service_->Start();
 
       std::string program_str = Attr<std::string>("OptimizeProgram");
       framework::ProgramDesc program_desc;
@@ -127,14 +127,7 @@
         LOG(ERROR) << "run sub program error " << e.what();
       }
       rpc_service_->Done();
-
-      // for (size_t i = 0; i < param_count; ++i) {
-      //   auto *out_var = recv_scope.FindVar(param_list[i]);
-      //   detail::TensorWithName out;
-      //   out.first = param_list[i];
-      //   out.second = out_var->Get<framework::LoDTensor>();
-      //   rpc_service_->Push(out);
-      // }
+      grads_counter_.clear();
     }  // while(true)
   }
...
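grads_counter_, cleared at the end of each cluster-batch, is what lets the server keep per-trainer copies of the same gradient: as the comment above says, each arriving gradient gets a suffix 0~n before the average is taken. A hedged sketch of how such a counter can hand out suffixes; the exact suffix format is not visible in this diff, so the ".trainer_" form below is an assumption for illustration only.

    #include <string>
    #include <unordered_map>

    // One counter per gradient name: the k-th arrival of "w@GRAD" becomes
    // "w@GRAD.trainer_k". Clearing the map resets numbering for the next batch.
    // NOTE: the ".trainer_" suffix format is assumed, not taken from the diff.
    std::string SuffixedGradName(std::unordered_map<std::string, int> *counter,
                                 const std::string &grad_name) {
      int idx = (*counter)[grad_name]++;
      return grad_name + ".trainer_" + std::to_string(idx);
    }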
@@ -52,7 +52,8 @@ class SendOp : public framework::OperatorBase {
         LOG(ERROR) << "send variable error: " << ins[i];
       }
     }
-    client_map_[0]->Wait();  // TODO(typhoonzero): support async optimization
+    // TODO(typhoonzero): support async optimization
+    client_map_[epmap[0]]->Wait();
     for (size_t i = 0; i < ins.size(); ++i) {
       bool ret = client_map_[epmap[i]]->GetVariable(scope, ins[i]);
       if (!ret) {
...
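On the trainer side the send op now addresses every RPC through epmap, the per-input endpoint list built by the transpiler (see the Python hunk below): gradient i goes to endpoint epmap[i], the trainer waits for the server's optimize step, then pulls each updated parameter back. A condensed sketch of that round trip, assuming a client type with the SendVariable/Wait/GetVariable calls used above; the exact signatures are not shown in this diff, so the stubs are placeholders.

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Hypothetical stand-in for the real gRPC client; signatures are assumed.
    struct RPCClient {
      bool SendVariable(const std::string &) { return true; }
      bool GetVariable(const std::string &) { return true; }
      void Wait() {}
    };

    // One synchronous round: push all gradients, wait for the optimize step,
    // then fetch every updated parameter from the endpoint that owns it.
    void SendRecvRound(std::unordered_map<std::string, RPCClient *> &client_map,
                       const std::vector<std::string> &ins,
                       const std::vector<std::string> &epmap) {
      for (size_t i = 0; i < ins.size(); ++i) {
        client_map[epmap[i]]->SendVariable(ins[i]);
      }
      // As the TODO in the diff notes, blocking on a single endpoint is a
      // stopgap; async optimization would need a wait per endpoint.
      client_map[epmap[0]]->Wait();
      for (size_t i = 0; i < ins.size(); ++i) {
        client_map[epmap[i]]->GetVariable(ins[i]);
      }
    }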
@@ -149,9 +149,8 @@ class DistributeTranspiler:
         epmap = []
         for ep, v in self.param_grad_map.iteritems():
             send_op_ordered_inputs.extend(v["grads"])
-            for i in v:
+            for i in v["grads"]:
                 epmap.append(ep)
         send_op = program.global_block().append_op(
             type="send",
             inputs={"X": send_op_ordered_inputs
...