提交 1e549563 编写于 作者: T typhoonzero

multi trainers

上级 e6079390
...@@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context, ...@@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
} }
// Serves one GetVariable RPC: looks up the variable named in the request
// inside the scope attached via SetScope(), serializes its LoDTensor on
// CPU, and returns it in out_var.
// Assumes SetScope() was called before the server began handling requests.
Status SendRecvServerImpl::GetVariable(ServerContext *context,
                                       const VariableMessage *in_var,
                                       VariableMessage *out_var) {
  std::string get_var_name = in_var->varname();
  auto *var = scope_->FindVar(get_var_name);
  // NOTE(review): FindVar returns nullptr for an unknown name, which would
  // crash below — presumably trainers only request existing parameters;
  // confirm and add an error Status if that invariant can break.
  // Bind by const reference: Get<>() returns a const LoDTensor&, so a plain
  // `auto` here would deep-copy the whole tensor for no benefit.
  const auto &tensor = var->Get<framework::LoDTensor>();

  std::ostringstream oss;
  framework::SerializeToStream(oss, tensor, platform::CPUDeviceContext());

  std::string *varname = out_var->mutable_varname();
  *varname = get_var_name;
  std::string *serialized = out_var->mutable_serialized();
  *serialized = oss.str();

  return Status::OK;
}
// Blocks the calling trainer until Done() signals that one full execution
// of the server-side sub-program has finished.
Status SendRecvServerImpl::Wait(ServerContext *context,
                                const VoidMessage *in_var,
                                VoidMessage *out_var) {
  std::unique_lock<std::mutex> lock(this->mutex_);
  // Explicit loop guards against spurious wakeups, equivalent to the
  // predicate overload of condition_variable::wait.
  while (!this->done_) {
    condition_.wait(lock);
  }
  return Status::OK;
}
// Opens a new mini-batch round: after this call, trainers invoking Wait()
// will block until Done() is called.
void SendRecvServerImpl::Start() {
  std::lock_guard<std::mutex> guard(this->mutex_);
  done_ = false;
}
// Closes the current round and wakes every trainer blocked in Wait().
void SendRecvServerImpl::Done() {
  std::lock_guard<std::mutex> guard(this->mutex_);
  done_ = true;
  condition_.notify_all();
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope, ...@@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
return true; return true;
} }
bool RPCClient::GetVariable(const framework::Scope& scope) { bool RPCClient::GetVariable(const framework::Scope& scope,
const std::string& outname) {
ClientContext context; ClientContext context;
VariableMessage msg; VariableMessage call_msg, ret_msg;
VoidMessage void_msg; call_msg.set_varname(outname);
auto ctx = platform::CPUDeviceContext(); auto ctx = platform::CPUDeviceContext();
Status status = stub_->GetVariable(&context, void_msg, &msg); Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
if (!status.ok()) { if (!status.ok()) {
LOG(ERROR) << "gRPC error: " << status.error_message(); LOG(ERROR) << "gRPC error: " << status.error_message();
return false; return false;
} }
std::istringstream iss(msg.serialized()); std::istringstream iss(ret_msg.serialized());
auto outname = msg.varname();
framework::LoDTensor ret_tensor; framework::LoDTensor ret_tensor;
framework::DeserializeFromStream(iss, &ret_tensor); framework::DeserializeFromStream(iss, &ret_tensor);
auto* outvar = scope.FindVar(outname); auto* outvar = scope.FindVar(outname);
......
...@@ -22,7 +22,9 @@ service SendRecvService { ...@@ -22,7 +22,9 @@ service SendRecvService {
// TODO(typhoonzero): add streaming API // TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {} rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname. // Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VoidMessage) returns (VariableMessage) {} rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// wait for one execution of the program
rpc Wait(VoidMessage) returns (VoidMessage) {}
} }
// VariableMessage is serialized paddle variable message. // VariableMessage is serialized paddle variable message.
......
...@@ -20,10 +20,6 @@ ...@@ -20,10 +20,6 @@
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/operators/detail/simple_block_queue.h" #include "paddle/operators/detail/simple_block_queue.h"
// #include <grpc++/channel.h>
// #include <grpc++/client_context.h>
// #include <grpc++/create_channel.h>
// #include <grpc++/security/credentials.h>
#include "paddle/operators/detail/send_recv.grpc.pb.h" #include "paddle/operators/detail/send_recv.grpc.pb.h"
#include "paddle/operators/detail/send_recv.pb.h" #include "paddle/operators/detail/send_recv.pb.h"
...@@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service { ...@@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service {
Status SendVariable(ServerContext *context, const VariableMessage *in_var, Status SendVariable(ServerContext *context, const VariableMessage *in_var,
VoidMessage *out_var) override; VoidMessage *out_var) override;
Status GetVariable(ServerContext *context, const VoidMessage *in_var, Status GetVariable(ServerContext *context, const VariableMessage *in_var,
VariableMessage *out_var) override; VariableMessage *out_var) override;
Status Wait(ServerContext *context, const VoidMessage *in_var,
VoidMessage *out_var) override;
void Start();
void Done();
void SetScope(framework::Scope *scope) { scope_ = scope; };
const TensorWithName Get() { return this->var_recv_queue_.Pop(); } const TensorWithName Get() { return this->var_recv_queue_.Pop(); }
void Push(const TensorWithName &var) { this->var_return_queue_.Push(var); }
private: private:
// received variable from RPC, operators fetch variable from this queue. // received variable from RPC, operators fetch variable from this queue.
SimpleBlockQueue<TensorWithName> var_recv_queue_; SimpleBlockQueue<TensorWithName> var_recv_queue_;
// calculated variable should push to this queue. framework::Scope *scope_;
SimpleBlockQueue<TensorWithName> var_return_queue_; // condition of the sub program
std::mutex mutex_;
bool done_;
std::condition_variable condition_;
}; };
// RPCClient is a class to send tensors to pserver sub-network // RPCClient is a class to send tensors to pserver sub-network
...@@ -78,7 +80,7 @@ class RPCClient { ...@@ -78,7 +80,7 @@ class RPCClient {
: stub_(SendRecvService::NewStub(channel)) {} : stub_(SendRecvService::NewStub(channel)) {}
bool SendVariable(const framework::Scope &scope, const std::string &inname); bool SendVariable(const framework::Scope &scope, const std::string &inname);
bool GetVariable(const framework::Scope &scope); bool GetVariable(const framework::Scope &scope, const std::string &outname);
private: private:
std::unique_ptr<SendRecvService::Stub> stub_; std::unique_ptr<SendRecvService::Stub> stub_;
......
...@@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase { ...@@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase {
const platform::DeviceContext &dev_ctx) const override { const platform::DeviceContext &dev_ctx) const override {
// FIXME(typhoonzero): no new scopes for every run. // FIXME(typhoonzero): no new scopes for every run.
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
rpc_service_.SetScope(&recv_scope);
auto param_list = Attr<std::vector<std::string>>("ParamList"); auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList"); auto grad_list = Attr<std::vector<std::string>>("GradList");
auto trainer_count = Attr<int>("Trainers"); auto trainer_count = Attr<int>("Trainers");
size_t param_count = param_list.size(); size_t param_count = param_list.size();
// TODO(typhoonzero): change this to a while_op for every cluster-batch. // TODO(typhoonzero): change this to a while_op for every cluster-batch.
while (true) { while (true) {
rpc_service_.Start();
// Get from multiple trainers, we don't care about order in which // Get from multiple trainers, we don't care about order in which
// the gradient arrives, just add suffix 0~n then average the gradient. // the gradient arrives, just add suffix 0~n then average the gradient.
for (size_t i = 0; i < param_count * trainer_count; ++i) { for (size_t i = 0; i < param_count * trainer_count; ++i) {
...@@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase { ...@@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase {
LOG(ERROR) << "run sub program error " << e.what(); LOG(ERROR) << "run sub program error " << e.what();
} }
for (size_t i = 0; i < param_count; ++i) { // for (size_t i = 0; i < param_count; ++i) {
auto *out_var = recv_scope.FindVar(param_list[i]); // auto *out_var = recv_scope.FindVar(param_list[i]);
detail::TensorWithName out; // detail::TensorWithName out;
out.first = param_list[i]; // out.first = param_list[i];
out.second = out_var->Get<framework::LoDTensor>(); // out.second = out_var->Get<framework::LoDTensor>();
rpc_service_->Push(out); // rpc_service_->Push(out);
} // }
} // while(true) } // while(true)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册