Commit 1e549563 authored by typhoonzero

multi trainers

Parent e6079390
@@ -33,21 +33,40 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
 }
 
 Status SendRecvServerImpl::GetVariable(ServerContext *context,
-                                       const VoidMessage *in_var,
+                                       const VariableMessage *in_var,
                                        VariableMessage *out_var) {
-  // Block util the sub graph is done.
-  auto out_tensor_with_name = var_return_queue_.Pop();
+  std::string get_var_name = in_var->varname();
+  auto *var = scope_->FindVar(get_var_name);
+  auto tensor = var->Get<framework::LoDTensor>();
   std::ostringstream oss;
-  framework::SerializeToStream(oss, out_tensor_with_name.second,
-                               platform::CPUDeviceContext());
+  framework::SerializeToStream(oss, tensor, platform::CPUDeviceContext());
   std::string *varname = out_var->mutable_varname();
-  *varname = out_tensor_with_name.first;
+  *varname = get_var_name;
   std::string *serialized = out_var->mutable_serialized();
   *serialized = oss.str();
   return Status::OK;
 }
 
+Status SendRecvServerImpl::Wait(ServerContext *context,
+                                const VoidMessage *in_var,
+                                VoidMessage *out_var) {
+  std::unique_lock<std::mutex> lock(this->mutex_);
+  condition_.wait(lock, [=] { return this->done_ == true; });
+  return Status::OK;
+}
+
+void SendRecvServerImpl::Start() {
+  std::unique_lock<std::mutex> lock(this->mutex_);
+  done_ = false;
+}
+
+void SendRecvServerImpl::Done() {
+  std::unique_lock<std::mutex> lock(this->mutex_);
+  done_ = true;
+  condition_.notify_all();
+}
+
 }  // namespace detail
 }  // namespace operators
 }  // namespace paddle
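The new Start/Wait/Done trio forms a barrier: trainer RPCs block in Wait() until the server has run one pass of the optimization sub-program and calls Done(), which wakes all waiting trainers at once via notify_all. A minimal standalone sketch of the same condition-variable pattern (the ServerBarrier class name is hypothetical, not from the commit):

```cpp
#include <condition_variable>
#include <mutex>

// Hypothetical class mirroring the Start/Wait/Done pattern above.
class ServerBarrier {
 public:
  // Called before one round of gradient collection: mark not-done.
  void Start() {
    std::lock_guard<std::mutex> lock(mutex_);
    done_ = false;
  }
  // Called by trainer RPC threads: block until the round finishes.
  void Wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [this] { return done_; });
  }
  // Called after the sub-program ran: wake all waiting trainers.
  void Done() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    condition_.notify_all();
  }

 private:
  std::mutex mutex_;
  bool done_ = false;
  std::condition_variable condition_;
};
```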
@@ -43,19 +43,20 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
   return true;
 }
 
-bool RPCClient::GetVariable(const framework::Scope& scope) {
+bool RPCClient::GetVariable(const framework::Scope& scope,
+                            const std::string& outname) {
   ClientContext context;
-  VariableMessage msg;
-  VoidMessage void_msg;
+  VariableMessage call_msg, ret_msg;
+  call_msg.set_varname(outname);
   auto ctx = platform::CPUDeviceContext();
-  Status status = stub_->GetVariable(&context, void_msg, &msg);
+  Status status = stub_->GetVariable(&context, call_msg, &ret_msg);
   if (!status.ok()) {
     LOG(ERROR) << "gRPC error: " << status.error_message();
     return false;
   }
 
-  std::istringstream iss(msg.serialized());
-  auto outname = msg.varname();
+  std::istringstream iss(ret_msg.serialized());
   framework::LoDTensor ret_tensor;
   framework::DeserializeFromStream(iss, &ret_tensor);
 
   auto* outvar = scope.FindVar(outname);
...
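The round trip relies on Paddle's stream serialization: the server writes a LoDTensor into the message's serialized bytes, and the client rebuilds the tensor on its side. A self-contained sketch of that round trip, assuming a CPU-only build; the `sendrecv` proto namespace, include paths, and the variable name "fc_0.w_0" are assumptions:

```cpp
#include <sstream>
#include <string>

#include "paddle/framework/lod_tensor.h"
#include "paddle/operators/detail/send_recv.pb.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"

void RoundTripExample() {
  paddle::framework::LoDTensor send_tensor;
  send_tensor.Resize(paddle::framework::make_ddim({2, 3}));
  send_tensor.mutable_data<float>(paddle::platform::CPUPlace());

  // Server side: tensor -> bytes inside the reply message.
  std::ostringstream oss;
  paddle::framework::SerializeToStream(oss, send_tensor,
                                       paddle::platform::CPUDeviceContext());
  sendrecv::VariableMessage msg;
  msg.set_varname("fc_0.w_0");   // illustrative variable name
  msg.set_serialized(oss.str()); // tensor bytes

  // Client side: bytes -> tensor.
  std::istringstream iss(msg.serialized());
  paddle::framework::LoDTensor recv_tensor;
  paddle::framework::DeserializeFromStream(iss, &recv_tensor);
}
```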
@@ -22,7 +22,9 @@ service SendRecvService {
   // TODO(typhoonzero): add streaming API
   rpc SendVariable(VariableMessage) returns (VoidMessage) {}
   // Argument VariableMessage for GetVariable should only contain varname.
-  rpc GetVariable(VoidMessage) returns (VariableMessage) {}
+  rpc GetVariable(VariableMessage) returns (VariableMessage) {}
+  // wait for one execution of the program
+  rpc Wait(VoidMessage) returns (VoidMessage) {}
 }
 
 // VariableMessage is serialized paddle variable message.
...
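For reference, the service after this change looks roughly as follows. The message layout is inferred from the varname/serialized accessors used in the C++ code; the package name and field numbers are assumptions:

```proto
syntax = "proto3";

package sendrecv;

service SendRecvService {
  rpc SendVariable(VariableMessage) returns (VoidMessage) {}
  // The request should only carry varname; the reply carries the tensor.
  rpc GetVariable(VariableMessage) returns (VariableMessage) {}
  // Block until one execution of the server-side program finishes.
  rpc Wait(VoidMessage) returns (VoidMessage) {}
}

message VariableMessage {
  string varname = 1;
  bytes serialized = 2;  // bytes written by framework::SerializeToStream
}

message VoidMessage {}
```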
@@ -20,10 +20,6 @@
 #include "paddle/framework/selected_rows.h"
 #include "paddle/operators/detail/simple_block_queue.h"
 
-// #include <grpc++/channel.h>
-// #include <grpc++/client_context.h>
-// #include <grpc++/create_channel.h>
-// #include <grpc++/security/credentials.h>
 #include "paddle/operators/detail/send_recv.grpc.pb.h"
 #include "paddle/operators/detail/send_recv.pb.h"
@@ -56,18 +52,24 @@ class SendRecvServerImpl final : public SendRecvService::Service {
   Status SendVariable(ServerContext *context, const VariableMessage *in_var,
                       VoidMessage *out_var) override;
-  Status GetVariable(ServerContext *context, const VoidMessage *in_var,
+  Status GetVariable(ServerContext *context, const VariableMessage *in_var,
                      VariableMessage *out_var) override;
+  Status Wait(ServerContext *context, const VoidMessage *in_var,
+              VoidMessage *out_var) override;
+  void Start();
+  void Done();
+  void SetScope(framework::Scope *scope) { scope_ = scope; };
 
   const TensorWithName Get() { return this->var_recv_queue_.Pop(); }
 
-  void Push(const TensorWithName &var) { this->var_return_queue_.Push(var); }
-
  private:
   // received variable from RPC, operators fetch variable from this queue.
   SimpleBlockQueue<TensorWithName> var_recv_queue_;
-  // calculated variable should push to this queue.
-  SimpleBlockQueue<TensorWithName> var_return_queue_;
+  framework::Scope *scope_;
+  // condition of the sub program
+  std::mutex mutex_;
+  bool done_;
+  std::condition_variable condition_;
 };
 
 // RPCClient is a class to send tensors to pserver sub-network
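var_recv_queue_ hands received variables from the gRPC threads to the operator thread; with this commit, returning results through a second queue is dropped in favor of serving GetVariable straight from the scope. Paddle's simple_block_queue.h itself is not shown in this diff; a generic blocking queue along these lines would provide the Push/Pop used here (a sketch, not the actual implementation):

```cpp
#include <condition_variable>
#include <deque>
#include <mutex>

// Generic blocking queue sketch: Push never blocks, Pop blocks until an
// element is available.
template <typename T>
class SimpleBlockQueue {
 public:
  void Push(const T &item) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push_back(item);
    }
    condition_.notify_one();
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    condition_.wait(lock, [this] { return !queue_.empty(); });
    T front = queue_.front();
    queue_.pop_front();
    return front;
  }

 private:
  std::mutex mutex_;
  std::deque<T> queue_;
  std::condition_variable condition_;
};
```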
@@ -78,7 +80,7 @@ class RPCClient {
       : stub_(SendRecvService::NewStub(channel)) {}
 
   bool SendVariable(const framework::Scope &scope, const std::string &inname);
-  bool GetVariable(const framework::Scope &scope);
+  bool GetVariable(const framework::Scope &scope, const std::string &outname);
 
  private:
   std::unique_ptr<SendRecvService::Stub> stub_;
...
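With the signature change, callers now name the variable they want to fetch instead of taking whatever the server returns. A hypothetical usage sketch; the endpoint, variable name, header path, and the assumption that Scope::Var creates the destination variable are all illustrative:

```cpp
#include <grpc++/grpc++.h>

#include <iostream>

#include "paddle/framework/scope.h"
#include "paddle/operators/detail/send_recv_impl.h"  // header path assumed

int main() {
  // Destination variable must exist in the scope before fetching into it.
  paddle::framework::Scope scope;
  scope.Var("fc_0.w_0");

  // Endpoint and variable name are illustrative only.
  paddle::operators::detail::RPCClient client(grpc::CreateChannel(
      "127.0.0.1:6174", grpc::InsecureChannelCredentials()));
  if (!client.GetVariable(scope, "fc_0.w_0")) {
    std::cerr << "failed to fetch fc_0.w_0 from the pserver\n";
  }
  return 0;
}
```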
@@ -76,12 +76,14 @@ class RecvOp : public framework::OperatorBase {
                const platform::DeviceContext &dev_ctx) const override {
     // FIXME(typhoonzero): no new scopes for every run.
     framework::Scope &recv_scope = scope.NewScope();
+    rpc_service_->SetScope(&recv_scope);
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
+    auto trainer_count = Attr<int>("Trainers");
     size_t param_count = param_list.size();
     // TODO(typhoonzero): change this to a while_op for every cluster-batch.
     while (true) {
+      rpc_service_->Start();
       // Get from multiple trainers, we don't care about order in which
       // the gradient arrives, just add suffix 0~n then average the gradient.
+      for (size_t i = 0; i < param_count * trainer_count; ++i) {
@@ -125,13 +127,13 @@ class RecvOp : public framework::OperatorBase {
         LOG(ERROR) << "run sub program error " << e.what();
       }
 
-      for (size_t i = 0; i < param_count; ++i) {
-        auto *out_var = recv_scope.FindVar(param_list[i]);
-        detail::TensorWithName out;
-        out.first = param_list[i];
-        out.second = out_var->Get<framework::LoDTensor>();
-        rpc_service_->Push(out);
-      }
+      // for (size_t i = 0; i < param_count; ++i) {
+      //   auto *out_var = recv_scope.FindVar(param_list[i]);
+      //   detail::TensorWithName out;
+      //   out.first = param_list[i];
+      //   out.second = out_var->Get<framework::LoDTensor>();
+      //   rpc_service_->Push(out);
+      // }
     }  // while(true)
   }
...
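The receive loop now collects param_count * trainer_count gradients per round; per the comment in the diff, each arrival is stored under a suffixed name (0~n) and the copies are averaged before the optimizer sub-program runs. The rule is an element-wise arithmetic mean; a runnable sketch with plain vectors standing in for LoDTensors (AverageGradients is a hypothetical helper, not part of the commit):

```cpp
#include <cstddef>
#include <vector>

// Average one parameter's gradient copies, one copy per trainer.
std::vector<float> AverageGradients(
    const std::vector<std::vector<float>> &per_trainer_grads) {
  std::vector<float> avg(per_trainer_grads[0].size(), 0.0f);
  for (const auto &grad : per_trainer_grads) {
    for (size_t i = 0; i < grad.size(); ++i) {
      avg[i] += grad[i];  // sum the copy from each trainer
    }
  }
  for (float &v : avg) {
    v /= static_cast<float>(per_trainer_grads.size());  // arithmetic mean
  }
  return avg;
}
```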