Commit b1b7af40 authored by T typhoonzero

support multi node

Parent 7be79231
@@ -51,19 +51,23 @@ Status SendRecvServerImpl::GetVariable(ServerContext *context,
Status SendRecvServerImpl::Wait(ServerContext *context,
const VoidMessage *in_var,
VoidMessage *out_var) {
-std::unique_lock<std::mutex> lock(this->mutex_);
-condition_.wait(lock, [=] { return this->done_ == true; });
+{
+  std::unique_lock<std::mutex> lock(this->mutex_);
+  condition_.wait(lock, [=] { return this->done_ == true; });
+}
return Status::OK;
}
void SendRecvServerImpl::Start() {
-std::unique_lock<std::mutex> lock(this->mutex_);
+std::lock_guard<std::mutex> lock(this->mutex_);
done_ = false;
}
void SendRecvServerImpl::Done() {
-std::unique_lock<std::mutex> lock(this->mutex_);
-done_ = true;
+{
+  std::lock_guard<std::mutex> lock(this->mutex_);
+  done_ = true;
+}
condition_.notify_all();
}
......
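
The hunk above tightens the locking in the Wait/Start/Done handshake: `Wait()` now releases its lock before returning, and `Done()` sets the flag inside a scoped lock and calls `notify_all()` only after the lock is dropped. A minimal standalone sketch of the same protocol, using only the members visible in the diff (the class name `BarrierLikeService` and the `main` driver are invented for illustration):

```cpp
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

// Sketch of the Start/Wait/Done protocol used by SendRecvServerImpl.
class BarrierLikeService {
 public:
  void Start() {  // reset before a new round
    std::lock_guard<std::mutex> lock(mutex_);
    done_ = false;
  }
  void Done() {  // release all waiters
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }  // unlock before notifying
    condition_.notify_all();
  }
  void Wait() {  // block until Done() is called
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [this] { return done_; });
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  bool done_ = false;
};

int main() {
  BarrierLikeService svc;
  svc.Start();
  std::thread waiter([&] {
    svc.Wait();
    std::cout << "released\n";
  });
  std::this_thread::sleep_for(std::chrono::milliseconds(10));
  svc.Done();
  waiter.join();
  return 0;
}
```

Notifying outside the critical section avoids waking a waiter only to have it block again on the still-held mutex; `std::lock_guard` suffices in `Start()` and `Done()` because neither waits on the condition.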
@@ -14,7 +14,6 @@
#include <stdint.h>
#include <sys/stat.h>
-#include <iostream>
#include <ostream>
#include <thread>
@@ -81,9 +80,9 @@ class RecvOp : public framework::OperatorBase {
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers");
size_t param_count = param_list.size();
-rpc_service_->Start();
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
while (true) {
+rpc_service_->Start();
// Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient.
for (size_t i = 0; i < param_count * trainer_count; ++i) {
@@ -95,8 +94,8 @@ class RecvOp : public framework::OperatorBase {
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
}
-VLOG(10) << "recved grad: " << grad_var_name
-         << " updating param: " << param_var_name;
+VLOG(3) << "recved grad: " << grad_var_name
+        << " updating param: " << param_var_name;
auto *merged_grad = recv_scope.FindVar(grad_var_name);
if (merged_grad == nullptr) {
// create output of merged var.
@@ -113,6 +112,7 @@ class RecvOp : public framework::OperatorBase {
// FIXME(typhoonzero): do not copy
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
}
+rpc_service_->Start();
std::string program_str = Attr<std::string>("OptimizeProgram");
framework::ProgramDesc program_desc;
@@ -127,14 +127,7 @@ class RecvOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what();
}
rpc_service_->Done();
-// for (size_t i = 0; i < param_count; ++i) {
-//   auto *out_var = recv_scope.FindVar(param_list[i]);
-//   detail::TensorWithName out;
-//   out.first = param_list[i];
-//   out.second = out_var->Get<framework::LoDTensor>();
-//   rpc_service_->Push(out);
-// }
grads_counter_.clear();
} // while(true)
}
......
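
Taken together, the recv_op.cc hunks move `rpc_service_->Start()` inside the `while (true)` loop so that every pass handles one cluster-wide step: reset the service, receive `param_count * trainer_count` gradients (repeated names are distinguished with a per-trainer suffix tracked by `grads_counter_`), run the OptimizeProgram, call `Done()`, and clear the counters. A self-contained sketch of the receive-and-match bookkeeping; the list and variable names mirror the diff, but the data, the exact suffix format, and the console output are made up for illustration:

```cpp
#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Simplified sketch of one RecvOp round: gradients arriving from several
// trainers are stored under suffixed names and matched back to the parameter
// they update; the real op then runs the OptimizeProgram over them.
int main() {
  std::vector<std::string> param_list = {"w", "b"};
  std::vector<std::string> grad_list = {"w@GRAD", "b@GRAD"};
  size_t param_count = param_list.size();
  int trainer_count = 2;
  // Pretend these arrived over RPC from two trainers, in arbitrary order.
  std::vector<std::string> arrived = {"b@GRAD", "w@GRAD", "w@GRAD", "b@GRAD"};

  std::map<std::string, int> grads_counter;
  for (size_t i = 0; i < param_count * trainer_count; ++i) {
    const std::string &grad_var_name = arrived[i];
    // Suffix repeated arrivals 0~n so each trainer's copy gets its own
    // variable (the real op's naming scheme is an internal detail).
    std::string stored =
        grad_var_name + "_" + std::to_string(grads_counter[grad_var_name]++);
    // Map the gradient back to the parameter it updates.
    auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
    std::string param_var_name =
        it != grad_list.end() ? param_list[it - grad_list.begin()] : "<none>";
    std::cout << "store " << stored << " -> updates " << param_var_name << "\n";
  }
  // The real op now runs the OptimizeProgram on the recv scope, calls
  // rpc_service_->Done() so trainers blocked in SendOp can fetch the updated
  // parameters, and clears grads_counter_ for the next round.
  return 0;
}
```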
@@ -52,7 +52,8 @@ class SendOp : public framework::OperatorBase {
LOG(ERROR) << "send variable error: " << ins[i];
}
}
-client_map_[0]->Wait(); // TODO(typhoonzero): support async optimization
+// TODO(typhoonzero): support async optimization
+client_map_[epmap[0]]->Wait();
for (size_t i = 0; i < ins.size(); ++i) {
bool ret = client_map_[epmap[i]]->GetVariable(scope, ins[i]);
if (!ret) {
......
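
In send_op.cc the hard-coded `client_map_[0]` becomes a lookup by the first endpoint in `epmap`, so the blocking wait goes through a real per-endpoint client. A toy sketch of the resulting call pattern, assuming `client_map_` is keyed by endpoint string as the fix implies (`FakeClient`, the endpoints, and the variable names are invented; the real op talks to the detail RPC client and reads variables from the scope):

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Stand-in for the real RPC client; only the call pattern matters here.
struct FakeClient {
  explicit FakeClient(std::string ep) : endpoint(std::move(ep)) {}
  bool SendVariable(const std::string &var) {
    std::cout << "send " << var << " -> " << endpoint << "\n";
    return true;
  }
  void Wait() { std::cout << "wait on " << endpoint << "\n"; }
  bool GetVariable(const std::string &var) {
    std::cout << "fetch " << var << " <- " << endpoint << "\n";
    return true;
  }
  std::string endpoint;
};

int main() {
  // ins[i] is sent to epmap[i]; the transpiler must keep the two aligned.
  std::vector<std::string> ins = {"w@GRAD", "b@GRAD", "w2@GRAD"};
  std::vector<std::string> epmap = {"127.0.0.1:6174", "127.0.0.1:6174",
                                    "127.0.0.1:6175"};
  std::map<std::string, std::unique_ptr<FakeClient>> client_map;
  for (const auto &ep : epmap)
    if (!client_map.count(ep)) client_map[ep] = std::make_unique<FakeClient>(ep);

  for (size_t i = 0; i < ins.size(); ++i)
    client_map[epmap[i]]->SendVariable(ins[i]);
  // Block until the parameter server finishes its optimize step.
  client_map[epmap[0]]->Wait();
  for (size_t i = 0; i < ins.size(); ++i)
    client_map[epmap[i]]->GetVariable(ins[i]);
  return 0;
}
```

The single blocking `Wait()` keeps the update synchronous; the TODO notes that asynchronous optimization is still unsupported.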
@@ -149,9 +149,8 @@ class DistributeTranspiler:
epmap = []
for ep, v in self.param_grad_map.iteritems():
send_op_ordered_inputs.extend(v["grads"])
-for i in v:
+for i in v["grads"]:
epmap.append(ep)
send_op = program.global_block().append_op(
type="send",
inputs={"X": send_op_ordered_inputs
......
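
The transpiler fix keeps `epmap` index-aligned with `send_op_ordered_inputs`: iterating `v` (a dict) walks its keys rather than the gradient list, so each endpoint would be appended the wrong number of times. A small sketch, assuming each `param_grad_map` entry roughly has the `{"params": [...], "grads": [...]}` shape implied by the diff (endpoints and variable names are invented):

```python
# Toy stand-in for self.param_grad_map; real entries are built by the
# DistributeTranspiler from the program's parameters and gradients.
param_grad_map = {
    "127.0.0.1:6174": {"params": ["w"], "grads": ["w@GRAD", "b@GRAD"]},
    "127.0.0.1:6175": {"params": ["b"], "grads": ["w2@GRAD"]},
}

send_op_ordered_inputs = []
epmap = []
for ep, v in param_grad_map.items():
    send_op_ordered_inputs.extend(v["grads"])
    for _ in v["grads"]:  # one endpoint entry per gradient (the fix)
        epmap.append(ep)
    # The buggy version looped "for i in v", i.e. over the dict keys,
    # appending ep once per key instead of once per gradient.

assert len(epmap) == len(send_op_ordered_inputs)
for var, ep in zip(send_op_ordered_inputs, epmap):
    print(var, "->", ep)
```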