未验证 提交 4fb7cc7f 编写于 作者: G gongweibao 提交者: GitHub

Move sync_mode device ctx from grpc server (#10881)

上级 5870a6b4
...@@ -49,7 +49,7 @@ def parse_args(): ...@@ -49,7 +49,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--fluid', default=1, type=int, help='whether is fluid job') '--fluid', default=1, type=int, help='whether is fluid job')
parser.add_argument( parser.add_argument(
'--rdma', action='store_ture', help='whether mount rdma libs') '--rdma', action='store_true', help='whether mount rdma libs')
parser.add_argument( parser.add_argument(
'--disttype', '--disttype',
default="pserver", default="pserver",
......
...@@ -21,7 +21,10 @@ limitations under the License. */ ...@@ -21,7 +21,10 @@ limitations under the License. */
#include <deque> #include <deque>
#include <stack> #include <stack>
#include <string>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/graph_traits.h" #include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
......
...@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) { ...@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
LOG(INFO) << graph.nodes.size(); LOG(INFO) << graph.nodes.size();
} }
} // analysis }; // namespace analysis
} // inference }; // namespace inference
} // paddle }; // namespace paddle
...@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h" #include "paddle/fluid/inference/analysis/pass.h"
......
...@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) { ...@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
LOG(INFO) << '\n' << graph.DotString(); LOG(INFO) << '\n' << graph.DotString();
} }
} // analysis } // namespace analysis
} // inference } // namespace inference
} // paddle } // namespace paddle
...@@ -50,7 +50,7 @@ struct DataTypeNamer { ...@@ -50,7 +50,7 @@ struct DataTypeNamer {
return dic_.at(x); return dic_.at(x);
} }
const std::string &repr(size_t &hash) const { const std::string &repr(size_t &hash) const { // NOLINT
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation"); PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
return dic_.at(hash); return dic_.at(hash);
} }
...@@ -62,7 +62,9 @@ struct DataTypeNamer { ...@@ -62,7 +62,9 @@ struct DataTypeNamer {
SET_TYPE(float); SET_TYPE(float);
} }
std::unordered_map<decltype(typeid(int).hash_code()), std::string> dic_; std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
std::string>
dic_;
}; };
#undef SET_TYPE #undef SET_TYPE
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h> #include <glog/logging.h>
#include <iosfwd> #include <iosfwd>
#include <string>
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
......
...@@ -18,6 +18,8 @@ limitations under the License. */ ...@@ -18,6 +18,8 @@ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h" #include "paddle/fluid/inference/analysis/node.h"
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h" #include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h" #include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
......
...@@ -19,6 +19,9 @@ limitations under the License. */ ...@@ -19,6 +19,9 @@ limitations under the License. */
#pragma once #pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/helper.h"
...@@ -58,7 +61,7 @@ class TRTConvertValidation { ...@@ -58,7 +61,7 @@ class TRTConvertValidation {
public: public:
TRTConvertValidation() = delete; TRTConvertValidation() = delete;
TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) { explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) {
// create engine. // create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_)); engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
engine_->InitNetwork(); engine_->InitNetwork();
......
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows) request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
......
...@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { ...@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
} }
bool RPCClient::Wait() { bool RPCClient::Wait() {
VLOG(3) << "RPCClient begin Wait()"
<< " req_count_:" << req_count_;
if (req_count_ <= 0) { if (req_count_ <= 0) {
return true; return true;
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /*Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,19 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <limits> #include <limits>
#include <string> #include <string>
using ::grpc::ServerAsyncResponseWriter; #include "paddle/fluid/operators/detail/grpc_server.h"
DEFINE_int32(rpc_server_handle_send_threads, 20, using ::grpc::ServerAsyncResponseWriter;
"Number of threads used to handle send at rpc server.");
DEFINE_int32(rpc_server_handle_get_threads, 20,
"Number of threads used to handle get at rpc server.");
DEFINE_int32(rpc_server_handle_prefetch_threads, 1,
"Number of threads used to handle prefetch at rpc server.");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH }; ...@@ -36,49 +29,40 @@ enum CallStatus { PROCESS = 0, FINISH };
class RequestBase { class RequestBase {
public: public:
explicit RequestBase(GrpcService::AsyncService* service, explicit RequestBase(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
const platform::DeviceContext* dev_ctx) RequestHandler* request_handler, int req_id)
: service_(service), : service_(service),
cq_(cq), cq_(cq),
sync_mode_(sync_mode),
status_(PROCESS), status_(PROCESS),
dev_ctx_(dev_ctx) { request_handler_(request_handler),
req_id_(req_id) {
PADDLE_ENFORCE(cq_); PADDLE_ENFORCE(cq_);
} }
virtual ~RequestBase() {} virtual ~RequestBase() {}
virtual void Process() { assert(false); } virtual void Process() = 0;
CallStatus Status() { return status_; } CallStatus Status() { return status_; }
void SetStatus(CallStatus status) { status_ = status; } void SetStatus(CallStatus status) { status_ = status; }
virtual std::string GetReqName() { virtual std::string GetReqName() = 0;
assert(false);
return "";
}
protected: protected:
::grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_; GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_; ::grpc::ServerCompletionQueue* cq_;
const bool sync_mode_;
CallStatus status_; CallStatus status_;
const platform::DeviceContext* dev_ctx_; RequestHandler* request_handler_;
int req_id_;
}; };
class RequestSend final : public RequestBase { class RequestSend final : public RequestBase {
public: public:
explicit RequestSend(GrpcService::AsyncService* service, explicit RequestSend(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, ReceivedQueue* queue, RequestHandler* request_handler, int req_id)
const platform::DeviceContext* dev_ctx, int req_id) : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
: RequestBase(service, cq, sync_mode, dev_ctx), request_.reset(new VariableResponse(request_handler->scope(),
queue_(queue), request_handler->dev_ctx(),
responder_(&ctx_), !request_handler->sync_mode()));
req_id_(req_id) {
if (sync_mode_) {
request_.reset(new VariableResponse(scope, dev_ctx_, false));
} else {
request_.reset(new VariableResponse(scope, dev_ctx_, true));
}
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable); int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
...@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase { ...@@ -87,12 +71,17 @@ class RequestSend final : public RequestBase {
virtual ~RequestSend() {} virtual ~RequestSend() {}
virtual std::string GetReqName() { return request_->Varname(); } std::string GetReqName() override { return request_->Varname(); }
void Process() override {
std::string varname = GetReqName();
VLOG(3) << "RequestSend var_name:" << varname;
virtual void Process() { auto scope = request_->GetMutableLocalScope();
std::string var_name = GetReqName(); auto invar = request_->GetVar();
VLOG(3) << "RequestSend " << var_name; framework::Variable* outvar = nullptr;
queue_->Push(std::make_pair(var_name, request_));
request_handler_->Handle(varname, scope, invar, &outvar);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
...@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase { ...@@ -102,105 +91,85 @@ class RequestSend final : public RequestBase {
protected: protected:
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
ReceivedQueue* queue_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
int req_id_;
}; };
class RequestGet final : public RequestBase { class RequestGet final : public RequestBase {
public: public:
explicit RequestGet(GrpcService::AsyncService* service, explicit RequestGet(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, RequestHandler* request_handler, int req_id)
const platform::DeviceContext* dev_ctx, : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
framework::BlockingQueue<MessageWithName>* queue,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_),
scope_(scope),
queue_(queue),
req_id_(req_id) {
auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable); auto method_id = static_cast<int>(detail::GrpcMethod::kGetVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_, method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestGet() {} virtual ~RequestGet() {}
virtual std::string GetReqName() { return request_.varname(); } std::string GetReqName() override { return request_.varname(); }
virtual void Process() { void Process() override {
// proc request. // proc request.
std::string var_name = request_.varname(); std::string varname = request_.varname();
VLOG(3) << "RequestGet " << var_name; VLOG(3) << "RequestGet " << varname;
auto* var = scope_->FindVar(var_name);
auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname);
framework::Variable* outvar = nullptr;
if (var_name != FETCH_BARRIER_MESSAGE) { request_handler_->Handle(varname, scope, invar, &outvar);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_);
if (outvar) {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_);
} }
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
if (var_name == FETCH_BARRIER_MESSAGE) {
sendrecv::VariableMessage msg;
MessageWithName msg_with_name = std::make_pair(var_name, msg);
queue_->Push(msg_with_name);
}
} }
protected: protected:
sendrecv::VariableMessage request_; sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_;
framework::BlockingQueue<MessageWithName>* queue_;
int req_id_;
}; };
class RequestPrefetch final : public RequestBase { class RequestPrefetch final : public RequestBase {
public: public:
explicit RequestPrefetch(GrpcService::AsyncService* service, explicit RequestPrefetch(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq, bool sync_mode, ::grpc::ServerCompletionQueue* cq,
framework::Scope* scope, RequestHandler* request_handler, int req_id)
const platform::DeviceContext* dev_ctx, : RequestBase(service, cq, request_handler, req_id),
framework::Executor* executor,
framework::ProgramDesc* program,
framework::ExecutorPrepareContext* prefetch_ctx,
int req_id)
: RequestBase(service, cq, sync_mode, dev_ctx),
responder_(&ctx_), responder_(&ctx_),
scope_(scope), local_scope_(nullptr) {
executor_(executor), request_.reset(new VariableResponse(request_handler->scope(),
program_(program), request_handler->dev_ctx(), true));
prefetch_ctx_(prefetch_ctx),
req_id_(req_id) {
// prefetch always create a new sub scope
request_.reset(new VariableResponse(scope, dev_ctx_, true));
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable); int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestPrefetch() {} virtual ~RequestPrefetch() {}
virtual std::string GetReqName() { return request_->Varname(); } std::string GetReqName() override { return request_->Varname(); }
virtual void Process() { void Process() override {
// prefetch process... // prefetch process...
std::string varname = request_->OutVarname();
VLOG(3) << "RequestPrefetch " << varname;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(varname);
framework::Variable* outvar = nullptr;
std::string var_name = request_->OutVarname(); request_handler_->Handle(varname, scope, invar, &outvar);
VLOG(3) << "RequestPrefetch " << var_name;
auto var_desc = program_->Block(0).FindVar(var_name);
framework::Scope* local_scope = request_->GetMutableLocalScope();
auto* var = local_scope->FindVar(var_name);
InitializeVariable(var, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_, local_scope);
SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_);
status_ = FINISH; status_ = FINISH;
responder_.Finish(reply_, ::grpc::Status::OK, responder_.Finish(reply_, ::grpc::Status::OK,
...@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase { ...@@ -211,202 +180,169 @@ class RequestPrefetch final : public RequestBase {
std::shared_ptr<VariableResponse> request_; std::shared_ptr<VariableResponse> request_;
::grpc::ByteBuffer reply_; ::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
framework::Scope* scope_; framework::Scope* local_scope_;
framework::Executor* executor_;
framework::ProgramDesc* program_;
framework::ExecutorPrepareContext* prefetch_ctx_;
int req_id_;
}; };
void AsyncGRPCServer::WaitClientGet(int count) {
int fetch_barriers = 0;
while (fetch_barriers < count) {
auto msg = var_get_queue_.Pop();
if (msg.first == FETCH_BARRIER_MESSAGE) {
fetch_barriers++;
}
}
}
void AsyncGRPCServer::WaitServerReady() { void AsyncGRPCServer::WaitServerReady() {
VLOG(3) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_); std::unique_lock<std::mutex> lock(this->mutex_ready_);
condition_ready_.wait(lock, [=] { return this->ready_ == 1; }); condition_ready_.wait(lock, [=] { return this->ready_ == 1; });
VLOG(3) << "AsyncGRPCServer WaitSeverReady";
} }
void AsyncGRPCServer::RunSyncUpdate() { void AsyncGRPCServer::StartServer() {
::grpc::ServerBuilder builder; ::grpc::ServerBuilder builder;
builder.AddListeningPort(address_, ::grpc::InsecureServerCredentials(), builder.AddListeningPort(bind_address_, ::grpc::InsecureServerCredentials(),
&selected_port_); &selected_port_);
builder.SetMaxSendMessageSize(std::numeric_limits<int>::max()); builder.SetMaxSendMessageSize(std::numeric_limits<int>::max());
builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max()); builder.SetMaxReceiveMessageSize(std::numeric_limits<int>::max());
builder.RegisterService(&service_); builder.RegisterService(&service_);
cq_send_ = builder.AddCompletionQueue(); for (auto t : rpc_call_map_) {
cq_get_ = builder.AddCompletionQueue(); rpc_cq_[t.first].reset(builder.AddCompletionQueue().release());
cq_prefetch_ = builder.AddCompletionQueue(); }
server_ = builder.BuildAndStart(); server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ LOG(INFO) << "Server listening on " << bind_address_
<< " selected port: " << selected_port_; << " selected port: " << selected_port_;
std::function<void(int)> send_register = std::bind( std::function<void(const std::string&, int)> f =
&AsyncGRPCServer::TryToRegisterNewSendOne, this, std::placeholders::_1); std::bind(&AsyncGRPCServer::TryToRegisterNewOne, this,
std::function<void(int)> get_register = std::bind( std::placeholders::_1, std::placeholders::_2);
&AsyncGRPCServer::TryToRegisterNewGetOne, this, std::placeholders::_1);
std::function<void(int)> prefetch_register =
std::bind(&AsyncGRPCServer::TryToRegisterNewPrefetchOne, this,
std::placeholders::_1);
for (int i = 0; i < kSendReqsBufSize; ++i) { for (auto& t : rpc_call_map_) {
TryToRegisterNewSendOne(i); auto& rpc_name = t.first;
} auto& cq = rpc_cq_[rpc_name];
for (int i = 0; i < kGetReqsBufSize; ++i) { auto threadnum = rpc_thread_num_[rpc_name];
TryToRegisterNewGetOne(i); auto& reqs = rpc_reqs_[rpc_name];
}
for (int i = 0; i < kPrefetchReqsBufSize; ++i) {
TryToRegisterNewPrefetchOne(i);
}
for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) { reqs.reserve(kRequestBufSize);
t_sends_.emplace_back(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, for (int i = 0; i < kRequestBufSize; i++) {
cq_send_.get(), "cq_send", send_register))); TryToRegisterNewOne(rpc_name, i);
} }
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) {
t_gets_.emplace_back( for (int i = 0; i < threadnum; i++) {
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, rpc_threads_[rpc_name].emplace_back(new std::thread(std::bind(
cq_get_.get(), "cq_get", get_register))); &AsyncGRPCServer::HandleRequest, this, cq.get(), rpc_name, f)));
} VLOG(3) << t.first << " creates threads!";
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { }
t_prefetchs_.emplace_back(new std::thread(
std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(),
"cq_prefetch", prefetch_register)));
} }
{ {
std::lock_guard<std::mutex> lock(this->mutex_ready_); std::lock_guard<std::mutex> lock(this->mutex_ready_);
ready_ = 1; ready_ = 1;
} }
condition_ready_.notify_all(); condition_ready_.notify_all();
// wait server // wait server
server_->Wait(); server_->Wait();
for (int i = 0; i < FLAGS_rpc_server_handle_send_threads; ++i) {
t_sends_[i]->join(); for (auto& t : rpc_threads_) {
} auto& threads = t.second;
for (int i = 0; i < FLAGS_rpc_server_handle_get_threads; ++i) { for (size_t i = 0; i < threads.size(); ++i) {
t_gets_[i]->join(); threads[i]->join();
} VLOG(3) << t.first << " threads ends!";
for (int i = 0; i < FLAGS_rpc_server_handle_prefetch_threads; ++i) { }
t_prefetchs_[i]->join();
} }
} }
void AsyncGRPCServer::ShutdownQueue() { void AsyncGRPCServer::ShutdownQueue() {
std::unique_lock<std::mutex> lock(cq_mutex_); for (auto& t : rpc_cq_) {
cq_send_->Shutdown(); t.second->Shutdown();
cq_get_->Shutdown(); VLOG(3) << t.first << " shutdown!";
cq_prefetch_->Shutdown(); }
} }
// This URL explains why shutdown is complicate: void AsyncGRPCServer::ShutDownImpl() {
void AsyncGRPCServer::ShutDown() { std::unique_lock<std::mutex> lock(cq_mutex_);
is_shut_down_ = true; is_shut_down_ = true;
ShutdownQueue(); ShutdownQueue();
VLOG(3) << "server_ shutdown!";
server_->Shutdown(); server_->Shutdown();
} }
void AsyncGRPCServer::TryToRegisterNewSendOne(int i) { void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
int req_id) {
std::unique_lock<std::mutex> lock(cq_mutex_); std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) { if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; VLOG(3) << "shutdown, do not TryToRegisterNewSendOne";
return; return;
} }
RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_,
scope_, &var_recv_queue_, dev_ctx_, i);
send_reqs_[i] = static_cast<RequestBase*>(send);
VLOG(4) << "Create RequestSend status:" << send->Status();
}
void AsyncGRPCServer::TryToRegisterNewGetOne(int req_id) { VLOG(4) << "register send rpc_name:" << rpc_name
std::unique_lock<std::mutex> lock(cq_mutex_); << ", handler:" << rpc_call_map_[kRequestSend];
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; auto& reqs = rpc_reqs_[rpc_name];
return; auto& handler = rpc_call_map_[rpc_name];
auto& cq = rpc_cq_[rpc_name];
RequestBase* b = nullptr;
if (rpc_name == kRequestSend) {
b = new RequestSend(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not surpported rpc");
} }
RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_,
dev_ctx_, &var_get_queue_, req_id);
get_reqs_[req_id] = static_cast<RequestBase*>(get);
VLOG(4) << "Create RequestGet status:" << get->Status();
}
void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { reqs[req_id] = b;
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
VLOG(3) << "shutdown, do not TryToRegisterNewPrefetchOne";
return;
}
RequestPrefetch* prefetch = new RequestPrefetch(
&service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_,
program_, prefetch_ctx_.get(), req_id);
prefetch_reqs_[req_id] = static_cast<RequestBase*>(prefetch);
VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); VLOG(4) << "Create RequestSend status:" << b->Status();
} }
// FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest( void AsyncGRPCServer::HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& cq_name, ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(int)> TryToRegisterNewOne) { std::function<void(const std::string&, int)> TryToRegisterNewOne) {
void* tag = NULL; void* tag = NULL;
bool ok = false; bool ok = false;
while (true) { while (true) {
VLOG(3) << "HandleRequest for " << cq_name << " wait Next"; VLOG(3) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) { if (!cq->Next(&tag, &ok)) {
LOG(INFO) << cq_name << " CompletionQueue shutdown!"; LOG(INFO) << "CompletionQueue " << rpc_name << " shutdown!";
break; break;
} }
VLOG(3) << "HandleRequest for " << cq_name << " get Next";
int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
if (sync_mode_) { int req_id = static_cast<int>(reinterpret_cast<intptr_t>(tag));
// FIXME(typhoonzero): de-couple the barriers with recv_op VLOG(3) << "HandleRequest " << rpc_name << ", req_id:" << req_id
if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); << " get next";
if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0);
VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond";
}
auto& reqs = rpc_reqs_[rpc_name];
RequestBase* base = nullptr; RequestBase* base = nullptr;
{ {
std::lock_guard<std::mutex> l(cq_mutex_); PADDLE_ENFORCE(req_id >= 0 && req_id < kRequestBufSize);
if (cq_name == "cq_get") { std::unique_lock<std::mutex> lock(cq_mutex_);
base = get_reqs_[req_id]; base = reqs[req_id];
} else if (cq_name == "cq_send") {
base = send_reqs_[req_id];
} else if (cq_name == "cq_prefetch") {
base = prefetch_reqs_[req_id];
}
} }
// reference: // reference:
// https://github.com/tensorflow/tensorflow/issues/5596 // https://github.com/tensorflow/tensorflow/issues/5596
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM // https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I // https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) { if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name[" LOG(WARNING) << "completion queue:" << rpc_name
<< " recv no regular event:argument name["
<< base->GetReqName() << "]"; << base->GetReqName() << "]";
TryToRegisterNewOne(req_id); TryToRegisterNewOne(rpc_name, req_id);
delete base; delete base;
continue; continue;
} }
VLOG(3) << "queue id:" << rpc_name << ", req_id:" << req_id
<< ", status:" << base->Status();
switch (base->Status()) { switch (base->Status()) {
case PROCESS: { case PROCESS: {
base->Process(); base->Process();
VLOG(4) << cq_name << " PROCESS status:" << base->Status();
break; break;
} }
case FINISH: { case FINISH: {
TryToRegisterNewOne(req_id); TryToRegisterNewOne(rpc_name, req_id);
VLOG(4) << cq_name << " FINISH status:" << base->Status();
delete base; delete base;
break; break;
} }
...@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest( ...@@ -415,20 +351,6 @@ void AsyncGRPCServer::HandleRequest(
} }
} }
void AsyncGRPCServer::WaitCond(int cond) {
std::unique_lock<std::mutex> lock(this->barrier_mutex_);
barrier_condition_.wait(lock,
[=] { return this->barrier_cond_step_ == cond; });
}
void AsyncGRPCServer::SetCond(int cond) {
{
std::lock_guard<std::mutex> lock(this->barrier_mutex_);
barrier_cond_step_ = cond;
}
barrier_condition_.notify_all();
}
} // namespace detail } // namespace detail
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,8 @@ limitations under the License. */ ...@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once #pragma once
#include <map>
#include <set>
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <utility> #include <utility>
...@@ -28,6 +30,8 @@ limitations under the License. */ ...@@ -28,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h" #include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h" #include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h" #include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
...@@ -37,106 +41,48 @@ namespace paddle { ...@@ -37,106 +41,48 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef framework::BlockingQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase; class RequestBase;
class AsyncGRPCServer final { class AsyncGRPCServer final : public RPCServer {
public: public:
explicit AsyncGRPCServer(const std::string &address, bool sync_mode) explicit AsyncGRPCServer(const std::string& address, int client_num)
: address_(address), sync_mode_(sync_mode), ready_(0) {} : RPCServer(address, client_num), ready_(0) {}
~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
// functions to sync server barrier status.
void WaitCond(int cond);
void SetCond(int cond);
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc *program) { program_ = program; }
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
}
int GetSelectedPort() const { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
void Push(const std::string &msg_name) { virtual ~AsyncGRPCServer() {}
this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr)); void WaitServerReady() override;
} void StartServer() override;
void ShutDown(); private:
void HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne);
protected: void TryToRegisterNewOne(const std::string& rpc_name, int req_id);
void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name,
std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(int req_id);
void TryToRegisterNewGetOne(int req_id);
void TryToRegisterNewPrefetchOne(int req_id);
void ShutdownQueue(); void ShutdownQueue();
void ShutDownImpl() override;
private: private:
static const int kSendReqsBufSize = 100; static const int kRequestBufSize = 100;
static const int kGetReqsBufSize = 100;
static const int kPrefetchReqsBufSize = 10;
std::mutex cq_mutex_; std::mutex cq_mutex_;
volatile bool is_shut_down_ = false; volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize];
RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
GrpcService::AsyncService service_; GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_; std::unique_ptr<::grpc::Server> server_;
std::string address_;
const bool sync_mode_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
framework::BlockingQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
// condition of the sub program // condition of the sub program
std::mutex barrier_mutex_; std::mutex barrier_mutex_;
mutable int barrier_cond_step_; mutable int barrier_cond_step_;
std::condition_variable barrier_condition_; std::condition_variable barrier_condition_;
std::vector<std::unique_ptr<std::thread>> t_sends_;
std::vector<std::unique_ptr<std::thread>> t_gets_;
std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
std::unique_ptr<std::thread> t_prefetch_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
std::mutex mutex_ready_; std::mutex mutex_ready_;
std::condition_variable condition_ready_; std::condition_variable condition_ready_;
int ready_; int ready_;
std::map<std::string, std::unique_ptr<::grpc::ServerCompletionQueue>> rpc_cq_;
std::map<std::string, std::vector<std::unique_ptr<std::thread>>> rpc_threads_;
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
}; };
}; // namespace detail }; // namespace detail
......
...@@ -24,13 +24,16 @@ limitations under the License. */ ...@@ -24,13 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace detail = paddle::operators::detail; namespace detail = paddle::operators::detail;
USE_OP(lookup_table); USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_; std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) { framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0); auto root_block = program->MutableBlock(0);
...@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, ...@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
} }
} }
void StartServer(const std::string& endpoint) { void StartServer() {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
framework::ProgramDesc program; framework::ProgramDesc program;
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
...@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) { ...@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
auto prepared = exe.Prepare(program, block->ID()); auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10); InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program); g_req_handler->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(std::move(prepared)); g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
rpc_service_->SetDevCtx(&ctx); g_req_handler->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope); g_req_handler->SetScope(&scope);
rpc_service_->SetExecutor(&exe); g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
rpc_service_->RunSyncUpdate(); // FIXME(gongwb): don't use hard time.
sleep(10);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
server_thread.join();
} }
TEST(PREFETCH, DISABLED_CPU) { TEST(PREFETCH, CPU) {
// start up a server instance backend g_req_handler.reset(new detail::RequestPrefetchHandler(true));
std::thread server_thread(StartServer, "127.0.0.1:8889"); g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
sleep(2);
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
detail::RPCClient client;
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope; framework::Scope scope;
platform::CPUPlace place; platform::CPUPlace place;
platform::CPUDeviceContext ctx(place); platform::CPUDeviceContext ctx(place);
// create var on local scope {
int64_t rows_numel = 5; // create var on local scope
InitTensorsOnClient(&scope, &place, rows_numel); int64_t rows_numel = 5;
std::string in_var_name("ids"); InitTensorsOnClient(&scope, &place, rows_numel);
std::string out_var_name("out"); std::string in_var_name("ids");
std::string out_var_name("out");
auto client = detail::RPCClient::GetInstance();
client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
out_var_name); client.Wait();
client->Wait(); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto var = scope.Var(out_var_name); auto ptr = value.mutable_data<float>(place);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place); for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
rpc_service_->ShutDown(); }
server_thread.join();
rpc_service_.reset(nullptr);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
} }
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
} }
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
class RPCServer;
// Base class of the server-side handlers for the RPC methods named by
// kRequestSend / kRequestGet / kRequestPrefetch.
//
// It stores the execution context (scope, device context, program, executor,
// RPC server) as raw pointers that are neither allocated nor freed here; the
// only object it owns is the prefetch ExecutorPrepareContext.
class RequestHandler {
 public:
  // Every pointer member starts out null, including grad_to_prepared_ctx_,
  // so a handler that is used before its setters ran holds a null pointer
  // rather than an indeterminate one.
  explicit RequestHandler(bool sync_mode)
      : sync_mode_(sync_mode),
        dev_ctx_(nullptr),
        executor_(nullptr),
        scope_(nullptr),
        program_(nullptr),
        grad_to_prepared_ctx_(nullptr),
        rpc_server_(nullptr) {}

  virtual ~RequestHandler() {}

  // Set attributes.
  void SetScope(framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
  void SetExecutor(framework::Executor* executor) { executor_ = executor; }

  // Takes ownership of the prepared context used by the prefetch handler.
  void SetPrefetchPreparedCtx(
      std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
    prefetch_ctx_ = std::move(prepared);
  }

  // Used for async.
  // Stores the pointer only; the map itself stays owned by the caller.
  void SetGradToPreparedCtx(
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    grad_to_prepared_ctx_ = g;
  }

  void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

  // Get attributes.
  bool sync_mode() { return sync_mode_; }
  framework::Scope* scope() { return scope_; }
  const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
  framework::ExecutorPrepareContext* prefetch_ctx() {
    return prefetch_ctx_.get();
  }
  framework::ProgramDesc* program() { return program_; }
  framework::Executor* executor() { return executor_; }
  // NOTE(review): returns the vector itself without taking
  // sparse_var_mutex_ -- confirm callers do not race with Handle().
  std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }

  // This function processes user's rpc request.
  // The implementation is in request_handler_impl.
  // example:
  //    std::string varname = request_.varname();
  //
  //    auto scope = request_handler_->scope();
  //    auto invar = scope->FindVar(varname);
  //    framework::Variable* outvar = nullptr;
  //
  //    request_handler_->Handle(varname, scope, invar, &outvar);
  //    if (outvar) {
  //        SerializeToByteBuffer(varname, outvar,
  //           *request_handler_->dev_ctx(), &reply_);
  //    }
  virtual bool Handle(const std::string& varname, framework::Scope* scope,
                      framework::Variable* var,
                      framework::Variable** outvar) = 0;

 protected:
  const bool sync_mode_;

  const platform::DeviceContext* dev_ctx_;
  framework::Executor* executor_;
  framework::Scope* scope_;
  framework::ProgramDesc* program_;
  std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;

  // Used for async.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      grad_to_prepared_ctx_;

  // Record received sparse variables, so that
  // we could reset those after execute optimize program
  std::vector<framework::Variable*> sparse_vars_;

  RPCServer* rpc_server_;

  // Guards sparse_vars_ while handlers append to it.
  std::mutex sparse_var_mutex_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace paddle {
namespace operators {
namespace detail {
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestSendHandler:" << varname;
// Async
if (!sync_mode_) {
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true;
}
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else {
VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) {
rpc_server_->WaitCond(kRequestSend);
}
if (invar == nullptr) {
LOG(ERROR) << "sync: Can not find server side var: " << varname;
PADDLE_THROW("sync: Can not find server side var");
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
sparse_vars_.push_back(invar);
}
}
return true;
}
// Handles one "get" request for the variable `varname`.
//
// A FETCH_BARRIER_MESSAGE only bumps the kRequestGet barrier counter (and
// only in sync mode). Any other name is looked up in the handler's own
// scope_ and stored in `*outvar`; in sync mode the lookup first waits for
// the kRequestGet condition. Always returns true.
bool RequestGetHandler::Handle(const std::string& varname,
                               framework::Scope* scope,
                               framework::Variable* invar,
                               framework::Variable** outvar) {
  VLOG(4) << "RequestGetHandler:" << varname;

  const bool is_fetch_barrier = (varname == FETCH_BARRIER_MESSAGE);
  if (is_fetch_barrier) {
    // FETCH_BARRIER_MESSAGE
    if (sync_mode_) {
      VLOG(3) << "sync: recv fetch barrier message";
      rpc_server_->IncreaseBatchBarrier(kRequestGet);
    }
    return true;
  }

  // Regular variable request.
  if (sync_mode_) {
    rpc_server_->WaitCond(kRequestGet);
  }
  *outvar = scope_->FindVar(varname);
  return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname);
*outvar = scope->FindVar(varname);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
return true;
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
// Handler bound to the "send" RPC. Handle() is defined out of line.
class RequestSendHandler final : public RequestHandler {
 public:
  explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
  virtual ~RequestSendHandler() = default;

  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* invar,
              framework::Variable** outvar) override;
};
// Handler bound to the "get" RPC. Handle() is defined out of line.
class RequestGetHandler final : public RequestHandler {
 public:
  explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
  virtual ~RequestGetHandler() = default;

  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* invar,
              framework::Variable** outvar) override;
};
// Handler bound to the "prefetch" RPC. Handle() is defined out of line.
class RequestPrefetchHandler final : public RequestHandler {
 public:
  explicit RequestPrefetchHandler(bool sync_mode)
      : RequestHandler(sync_mode) {}
  virtual ~RequestPrefetchHandler() = default;

  bool Handle(const std::string& varname, framework::Scope* scope,
              framework::Variable* invar,
              framework::Variable** outvar) override;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle {
namespace operators {
namespace detail {
void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown ";
ShutDownImpl();
exit_flag_ = true;
barrier_cond_.notify_all();
rpc_cond_.notify_all();
}
void RPCServer::SavePort() const {
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_;
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
// Blocks the caller until the barrier counter of `rpc_name` has reached
// client_num_ or the exit flag has been raised.
void RPCServer::WaitBarrier(const std::string& rpc_name) {
  std::unique_lock<std::mutex> guard(this->mutex_);
  // Explicit form of the predicate wait: re-check after every wake-up.
  while (barrier_counter_[rpc_name] < client_num_ && !exit_flag_.load()) {
    barrier_cond_.wait(guard);
  }

  VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
}
// Increments the barrier counter of `rpc_name` under mutex_ and, once the
// new value has reached client_num_, wakes the threads blocked in
// WaitBarrier().
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
  VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;

  int reached = 0;
  {
    std::unique_lock<std::mutex> guard(mutex_);
    reached = ++barrier_counter_[rpc_name];
  }

  VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
          << ", barrier_count:" << reached << ", fan_in" << client_num_;

  // Not every client has arrived yet: nothing to wake.
  if (reached < client_num_) {
    return;
  }
  barrier_cond_.notify_all();
}
// Sets the barrier counter of every known rpc name back to zero while
// holding mutex_.
void RPCServer::ResetBarrierCounter() {
  VLOG(3) << "RPCServer ResetBarrierCounter ";

  std::unique_lock<std::mutex> guard(mutex_);
  for (auto it = barrier_counter_.begin(); it != barrier_counter_.end();
       ++it) {
    it->second = 0;
  }
}
// Binds `rpc_name` to `handler`, remembers how many threads should serve it
// and assigns the name a fresh condition id for SetCond()/WaitCond().
//
// NOTE(review): the id counter is a function-local static, so it is shared
// by every RPCServer in the process and is incremented without holding
// mutex_ -- confirm that registration is single-threaded and that ids are
// not expected to restart at 0 for each server instance.
void RPCServer::RegisterRPC(const std::string& rpc_name,
                            RequestHandler* handler, int thread_num) {
  rpc_call_map_[rpc_name] = handler;
  rpc_thread_num_[rpc_name] = thread_num;

  static int cond = -1;
  const int assigned_cond = ++cond;
  rpc_cond_map_[rpc_name] = assigned_cond;

  VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
          << ", cond:" << assigned_cond;
}
// Makes the condition id of `rpc_name` the current one and wakes every
// thread blocked in WaitCond().
void RPCServer::SetCond(const std::string& rpc_name) {
  VLOG(3) << "RPCServer SetCond " << rpc_name;

  std::unique_lock<std::mutex> guard(mutex_);
  cur_cond_ = rpc_cond_map_[rpc_name];
  // Release before notifying so woken waiters can take the mutex at once.
  guard.unlock();

  rpc_cond_.notify_all();
}
// Blocks the caller until the current condition equals the id registered
// for `rpc_name`, or until the exit flag has been raised.
void RPCServer::WaitCond(const std::string& rpc_name) {
  VLOG(3) << "RPCServer WaitCond " << rpc_name;

  // One acquisition covers both the id lookup and the wait; there is no
  // need to release and re-take mutex_ in between.
  std::unique_lock<std::mutex> lock(mutex_);
  const int cond = rpc_cond_map_[rpc_name];
  rpc_cond_.wait(
      lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
namespace paddle {
namespace operators {
namespace detail {
// Abstract base class of the RPC server. A concrete server implements
// StartServer(), WaitServerReady() and ShutDownImpl(); this class keeps the
// rpc-name -> handler registry and the barrier / condition bookkeeping the
// request handlers use to synchronize with the server loop.
class RPCServer {
 public:
  explicit RPCServer(const std::string& address, int client_num)
      : cur_cond_(0),
        bind_address_(address),
        exit_flag_(0),
        selected_port_(0),
        client_num_(client_num) {}

  virtual ~RPCServer() {}

  virtual void StartServer() = 0;
  virtual void WaitServerReady() = 0;

  // Calls ShutDownImpl(), raises the exit flag and wakes all waiters.
  void ShutDown();

  bool IsExit() { return exit_flag_.load() != 0; }

  int GetSelectedPort() const { return selected_port_; }
  // Writes the selected port to /tmp/paddle.<pid>.port.
  void SavePort() const;

  // RegisterRPC, register the rpc method name to a handler
  // class, and auto generate a condition id for this call
  // to be used for the barrier.
  void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
                   int thread_num = 5);

  // Wait until all the clients have reached the barrier for one
  // rpc method. This function should be called in the
  // RequestHandler if you want to run the server/client in a
  // synchronous mode.
  void WaitBarrier(const std::string& rpc_name);

  // Select / wait for the condition id registered for `rpc_name`.
  void SetCond(const std::string& rpc_name);
  void WaitCond(const std::string& rpc_name);

  // Bump the barrier counter of one rpc name / zero all counters.
  void IncreaseBatchBarrier(const std::string rpc_name);
  void ResetBarrierCounter();

 protected:
  // Transport-specific part of ShutDown().
  virtual void ShutDownImpl() = 0;

 private:
  // Guards barrier_counter_ and rpc_cond_map_, and is the mutex both
  // condition variables wait on.
  std::mutex mutex_;
  std::unordered_map<std::string, int> barrier_counter_;
  std::condition_variable barrier_cond_;

  std::unordered_map<std::string, int> rpc_cond_map_;
  std::atomic<int> cur_cond_;
  std::condition_variable rpc_cond_;

 protected:
  std::string bind_address_;
  // Non-zero once ShutDown() has been called.
  std::atomic<int> exit_flag_;
  int selected_port_;

  const int client_num_;

  std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
  std::unordered_map<std::string, int> rpc_thread_num_;
  friend class RequestHandler;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle
...@@ -67,8 +67,8 @@ class VariableResponse { ...@@ -67,8 +67,8 @@ class VariableResponse {
framework::Scope* GetMutableLocalScope() const { return local_scope_; } framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); } inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); } inline std::string OutVarname() const { return meta_.out_varname(); }
// should call parse first. // should call parse first.
framework::Variable* GetVar() { framework::Variable* GetVar() {
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
namespace paddle { namespace paddle {
...@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default // NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
detail::AsyncGRPCServer rpc_service(endpoint, true); detail::RequestSendHandler rpc_h(true);
detail::AsyncGRPCServer rpc_service(endpoint, 1);
rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(&rpc_service);
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace()); framework::Executor executor(dev_ctx.GetPlace());
rpc_service.SetScope(scope); rpc_h.SetScope(scope);
rpc_service.SetDevCtx(&dev_ctx); rpc_h.SetDevCtx(&dev_ctx);
rpc_service.SetProgram(&empty_program); rpc_h.SetProgram(&empty_program);
rpc_service.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service)); std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service));
rpc_service.SetCond(0); rpc_service.SetCond(detail::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0..."; VLOG(3) << "start getting nccl id from trainer 0...";
auto recv = rpc_service.Get(); rpc_service.WaitBarrier(detail::kRequestSend);
VLOG(3) << "got nccl id and stop server..."; VLOG(3) << "got nccl id and stop server...";
rpc_service.ShutDown(); rpc_service.ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
......
...@@ -19,14 +19,16 @@ limitations under the License. */ ...@@ -19,14 +19,16 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <vector> #include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) { void RunServer(std::shared_ptr<detail::RPCServer> service) {
service->RunSyncUpdate(); service->StartServer();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
static void split(const std::string &str, char sep, static void split(const std::string &str, char sep,
...@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks( ...@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait(); for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
} }
std::atomic_int ListenAndServOp::selected_port_{0};
ListenAndServOp::ListenAndServOp(const std::string &type, ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs, const framework::VariableNameMap &outputs,
...@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type, ...@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
ListenAndServOp::~ListenAndServOp() { Stop(); } ListenAndServOp::~ListenAndServOp() { Stop(); }
void ListenAndServOp::Stop() { void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
rpc_service_->ShutDown(); rpc_service_->ShutDown();
server_thread_->join(); server_thread_->join();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
...@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() { ...@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
void ListenAndServOp::SavePort() const { void ListenAndServOp::SavePort() const {
// NOTE: default write file to /tmp/paddle.selected_port // NOTE: default write file to /tmp/paddle.selected_port
selected_port_ = rpc_service_->GetSelectedPort(); rpc_service_->SavePort();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_.load();
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void ListenAndServOp::WaitServerReady() {
while (selected_port_.load() == 0) {
}
} }
void ListenAndServOp::RunSyncLoop(framework::Executor *executor, void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program, framework::ProgramDesc *program,
framework::Scope *recv_scope, framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const { framework::BlockDesc *prefetch_block) const {
auto fan_in = Attr<int>("Fanin");
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2, PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks"); "server program should have at least 2 blocks");
...@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
optimize_prepared.begin(), optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr)); std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
bool exit_flag = false; rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that // Record received sparse variables, so that
// we could reset those after execute optimize program // we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars; std::vector<framework::Variable *> sparse_vars;
while (!exit_flag && !SignalHandler::IsProgramExit()) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0); rpc_service_->SetCond(detail::kRequestSend);
size_t recv_var_cnt = 0; rpc_service_->WaitBarrier(detail::kRequestSend);
int batch_barrier = 0;
while (batch_barrier != fan_in) { if (rpc_service_->IsExit()) {
const detail::ReceivedMessage v = rpc_service_->Get(); LOG(WARNING) << "get exit!rpc_processor break!";
auto recv_var_name = v.first; rpc_service_->SetCond(detail::kRequestGet);
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
}
}
if (exit_flag) {
rpc_service_->SetCond(1);
rpc_service_->ShutDown();
break; break;
} }
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work. // and this will still work.
// The optimize blocks which have the same parent ID would run parallel // The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future // TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent(); int32_t last_parent_blkid = program->Block(1).Parent();
...@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, ...@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear(); var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
} }
rpc_service_->SetCond(1); rpc_service_->SetCond(detail::kRequestGet);
// FIXME(typhoonzero): use another condition to sync wait clients get. rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->WaitClientGet(fan_in); rpc_service_->ResetBarrierCounter();
sparse_vars.clear();
} // while(true) } // while(true)
} }
// Worker loop that applies asynchronous updates for a single gradient.
//
// Pops received messages from `queue` one at a time and runs the prepared
// context `prepared` on `executor` against each message's local scope,
// waiting for every run to finish before popping the next message.
//
//   var_name   gradient name; used only in log output.
//   exit_flag  polled at the top of each iteration; the loop ends once it is
//              true or SignalHandler::IsProgramExit() returns true.
//   queue      source of received messages. Pop() presumably blocks until a
//              message is available -- TODO confirm in detail::ReceivedQueue.
//   executor   executor used to run the prepared context.
//   prepared   prepared context executed for every received message.
//
// Invokes PADDLE_THROW when a message carries no server-side variable.
// Exceptions thrown while running the prepared context are logged and
// swallowed so that one failed update does not stop the loop.
static void AsyncUpdateThread(
    const std::string &var_name, const bool &exit_flag,
    const std::shared_ptr<detail::ReceivedQueue> &queue,
    framework::Executor *executor,
    framework::ExecutorPrepareContext *prepared) {
  VLOG(3) << "update thread for " << var_name << " started";
  while (!exit_flag && !SignalHandler::IsProgramExit()) {
    const detail::ReceivedMessage v = queue->Pop();
    // A wake-up message pushed during shutdown carries a null payload. Treat
    // a null payload as an exit request even if the exit flag is not yet
    // observed as set, instead of dereferencing it below.
    if (SignalHandler::IsProgramExit() || v.second == nullptr) {
      VLOG(3) << "update thread for " << var_name << " exit";
      break;
    }
    // Bind by reference: the name is only read, no copy is needed.
    const auto &recv_var_name = v.first;
    VLOG(4) << "async update " << recv_var_name;
    auto var = v.second->GetVar();
    if (var == nullptr) {
      LOG(ERROR) << "Can not find server side var: " << recv_var_name;
      PADDLE_THROW("Can not find server side var");
    }
    // `v` may be captured by reference because the task is waited on before
    // `v` goes out of scope; the raw pointers are captured by value.
    auto fs = framework::Async([executor, &v, prepared] {
      try {
        executor->RunPreparedContext(prepared,
                                     v.second->GetMutableLocalScope());
      } catch (const std::exception &e) {
        LOG(ERROR) << "run sub program error " << e.what();
      }
    });
    fs.wait();
  }
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program) const { framework::ProgramDesc *program) const {
VLOG(3) << "RunAsyncLoop in"; VLOG(3) << "RunAsyncLoop in";
// grad name to block id // grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id; std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad; std::unordered_map<int32_t, std::string> id_to_grad;
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
grad_to_queue;
auto grad_to_block_id_str = auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id"); Attr<std::vector<std::string>>("grad_to_block_id");
...@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1]; VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0); PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]); int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id; grad_to_block_id[pieces[0]] = block_id;
std::shared_ptr<detail::ReceivedQueue> queue =
std::make_shared<detail::ReceivedQueue>();
grad_to_queue[pieces[0]] = queue;
// record blocking queue in SignalHandler
SignalHandler::RegisterBlockingQueue(queue);
id_to_grad[block_id] = pieces[0]; id_to_grad[block_id] = pieces[0];
} }
size_t num_blocks = program->Size(); size_t num_blocks = program->Size();
...@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i]; grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
} }
bool exit_flag = false; request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
VLOG(3) << "start async optimize threads";
std::vector<std::future<void>> fs;
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
std::string grad_name = iter->first;
VLOG(3) << "create async update thread for " << grad_name;
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
&grad_to_queue, &grad_to_prepared_ctx]() {
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
executor, grad_to_prepared_ctx[grad_name].get());
}));
}
VLOG(3) << "RunAsyncLoop into while"; VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag && !SignalHandler::IsProgramExit()) { while (true) {
const detail::ReceivedMessage v = rpc_service_->Get(); if (rpc_service_->IsExit()) {
auto recv_var_name = v.first; LOG(INFO) << "get exit!rpc_processor break!";
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break; break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
grad_to_queue[recv_var_name]->Push(v);
} }
if (exit_flag) { sleep(1);
rpc_service_->ShutDown();
break;
}
} // while(true) } // while(true)
} }
// Wires the shared serving context into one request handler: scope, device
// context, executor, program, prefetch prepared context and RPC server.
//
//   h             handler to configure; must not be null (dereferenced).
//   scope         scope handed to the handler.
//   dev_ctx       device context handed to the handler.
//   executor      executor handed to the handler.
//   program       program handed to the handler.
//   prefetch_ctx  raw prepared context; wrapped in a std::unique_ptr before
//                 being passed on.
//   rpc_server    RPC server the handler is attached to.
//
// NOTE(review): because `prefetch_ctx` is wrapped in a std::unique_ptr, the
// handler presumably takes ownership of it. RunImpl binds one released
// pointer and invokes this function for three handlers, which would leave
// several owners for a single object -- confirm how
// SetPrefetchPreparedCtx stores the pointer and whether a double delete is
// possible.
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
                           platform::DeviceContext *dev_ctx,
                           framework::Executor *executor,
                           framework::ProgramDesc *program,
                           framework::ExecutorPrepareContext *prefetch_ctx,
                           detail::RPCServer *rpc_server) {
  h->SetScope(scope);
  h->SetDevCtx(dev_ctx);
  h->SetExecutor(executor);
  h->SetProgram(program);
  // The temporary unique_ptr is already an rvalue; wrapping it in std::move
  // is redundant.
  h->SetPrefetchPreparedCtx(
      std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx));
  h->SetRPCServer(rpc_server);
}
void ListenAndServOp::RunImpl(const framework::Scope &scope, void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const { const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer. // Mark this as PS that it should decide profiling by listening from trainer.
...@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Scope &recv_scope = scope.NewScope(); framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode"); bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
PADDLE_ENFORCE(!rpc_service_); PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint"); std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode)); LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock); auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock); auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program(); auto *program = optimize_block->Program();
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
// prepare rpc_service
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
rpc_service_->SetProgram(program);
rpc_service_->SetExecutor(&executor);
// prepare for prefetch // prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID(); VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID()); auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(),
rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
// start the server listening after all members are initialized. // start the server listening after all members are initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
...@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGTERM, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit);
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
static_cast<int>(::getpid()));
SavePort(); SavePort();
if (sync_mode) { if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block); RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
...@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
bool SignalHandler::program_exit_flag_ = false;
SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
void SignalHandler::StopAndExit(int signal_num) { void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit"; VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
exit(0);
program_exit_flag_ = true;
// awake all blocking queues
for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
iter != blocking_queue_set_.end(); iter++) {
iter->get()->Push(
std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
}
exit(EXIT_SUCCESS);
}
void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
blocking_queue_set_.insert(queue);
} }
} // namespace operators } // namespace operators
......
...@@ -23,7 +23,8 @@ limitations under the License. */ ...@@ -23,7 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h" #include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -31,7 +32,7 @@ namespace operators { ...@@ -31,7 +32,7 @@ namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock"; constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock"; constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service); void RunServer(std::shared_ptr<detail::RPCServer> service);
class ListenAndServOp : public framework::OperatorBase { class ListenAndServOp : public framework::OperatorBase {
public: public:
...@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase { ...@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
void SavePort() const; void SavePort() const;
void WaitServerReady(); int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
int GetSelectedPort() { return selected_port_; }
void Stop() override; void Stop() override;
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override; const platform::Place& dev_place) const override;
static void ResetPort() { selected_port_ = 0; }
protected: protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_; mutable std::shared_ptr<detail::RPCServer> rpc_service_;
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
mutable std::shared_ptr<std::thread> server_thread_; mutable std::shared_ptr<std::thread> server_thread_;
// FIXME(wuyi): it's static so that the operator can be cloned.
static std::atomic_int selected_port_;
}; };
class SignalHandler { class SignalHandler {
public:
typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
public: public:
static void StopAndExit(int signal_num); static void StopAndExit(int signal_num);
static void RegisterBlockingQueue(BlockingQueue&);
static inline bool IsProgramExit() { return program_exit_flag_; }
private: private:
static bool program_exit_flag_;
static BlockingQueueSet blocking_queue_set_;
DISABLE_COPY_AND_ASSIGN(SignalHandler); DISABLE_COPY_AND_ASSIGN(SignalHandler);
}; };
......
...@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto rpc_client = detail::RPCClient::GetInstance(); auto rpc_client = detail::RPCClient::GetInstance();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) { if (sync_mode) {
......
...@@ -21,6 +21,8 @@ limitations under the License. */ ...@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h" #include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
...@@ -35,42 +37,44 @@ namespace m = paddle::operators::math; ...@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail; namespace detail = paddle::operators::detail;
namespace string = paddle::string; namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> rpc_service; std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
void StartServer(std::atomic<bool>* initialized) { void StartServer() {
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
scope.Var(NCCL_ID_VARNAME); scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance(); p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace()); auto& dev_ctx = *pool.Get(p::CPUPlace());
rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
f::ProgramDesc empty_program; f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace()); f::Executor executor(dev_ctx.GetPlace());
rpc_service->SetScope(&scope); g_req_handler->SetScope(&scope);
rpc_service->SetDevCtx(&dev_ctx); g_req_handler->SetDevCtx(&dev_ctx);
rpc_service->SetProgram(&empty_program); g_req_handler->SetProgram(&empty_program);
rpc_service->SetExecutor(&executor); g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread( std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get())); std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
*initialized = true;
rpc_service->SetCond(0); g_rpc_service->SetCond(detail::kRequestSend);
auto recv = rpc_service->Get(); std::cout << "before WaitFanInOfSend" << std::endl;
g_rpc_service->WaitBarrier(detail::kRequestSend);
LOG(INFO) << "got nccl id and stop server..."; LOG(INFO) << "got nccl id and stop server...";
rpc_service->ShutDown(); g_rpc_service->ShutDown();
server_thread.join(); server_thread.join();
} }
TEST(SendNcclId, DISABLED_Normal) { TEST(SendNcclId, GrpcServer) {
std::atomic<bool> initialized{false}; g_req_handler.reset(new detail::RequestSendHandler(true));
std::thread server_thread(StartServer, &initialized); g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
while (!initialized) {
} std::thread server_thread(StartServer);
// wait server to start g_rpc_service->WaitServerReady();
// sleep(2);
rpc_service->WaitServerReady();
f::Scope scope; f::Scope scope;
p::CPUPlace place; p::CPUPlace place;
...@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) { ...@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
auto& dev_ctx = *pool.Get(p::CPUPlace()); auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var(NCCL_ID_VARNAME); auto var = scope.Var(NCCL_ID_VARNAME);
// var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>(); auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id); p::dynload::ncclGetUniqueId(id);
int port = rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port); std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client; detail::RPCClient client;
LOG(INFO) << "connect to server" << ep;
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME); client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait(); client.Wait();
client.AsyncSendBatchBarrier(ep);
client.Wait();
server_thread.join(); server_thread.join();
auto* ptr = rpc_service.release(); g_rpc_service.reset(nullptr);
delete ptr; g_req_handler.reset(nullptr);
} }
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <stdio.h> #include <stdio.h>
#include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册