未验证 提交 4fb7cc7f 编写于 作者: G gongweibao 提交者: GitHub

Move sync_mode device ctx from grpc server (#10881)

上级 5870a6b4
......@@ -49,7 +49,7 @@ def parse_args():
parser.add_argument(
'--fluid', default=1, type=int, help='whether is fluid job')
parser.add_argument(
'--rdma', action='store_ture', help='whether mount rdma libs')
'--rdma', action='store_true', help='whether mount rdma libs')
parser.add_argument(
'--disttype',
default="pserver",
......
......@@ -21,7 +21,10 @@ limitations under the License. */
#include <deque>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/analysis/graph_traits.h"
#include "paddle/fluid/inference/analysis/node.h"
......
......@@ -44,6 +44,6 @@ TEST_F(DFG_Tester, Test) {
LOG(INFO) << graph.nodes.size();
}
} // analysis
} // inference
} // paddle
}; // namespace analysis
}; // namespace inference
}; // namespace paddle
......@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
namespace paddle {
namespace inference {
namespace analysis {
......
......@@ -19,6 +19,8 @@
#pragma once
#include <string>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/pass.h"
......
......@@ -32,6 +32,6 @@ TEST_F(DFG_Tester, Init) {
LOG(INFO) << '\n' << graph.DotString();
}
} // analysis
} // inference
} // paddle
} // namespace analysis
} // namespace inference
} // namespace paddle
......@@ -50,7 +50,7 @@ struct DataTypeNamer {
return dic_.at(x);
}
const std::string &repr(size_t &hash) const {
const std::string &repr(size_t &hash) const { // NOLINT
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
return dic_.at(hash);
}
......@@ -62,7 +62,9 @@ struct DataTypeNamer {
SET_TYPE(float);
}
std::unordered_map<decltype(typeid(int).hash_code()), std::string> dic_;
std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
std::string>
dic_;
};
#undef SET_TYPE
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <glog/logging.h>
#include <iosfwd>
#include <string>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
......
......@@ -18,6 +18,8 @@ limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/node.h"
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <string>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
#include "paddle/fluid/inference/analysis/fluid_to_data_flow_graph_pass.h"
......
......@@ -19,6 +19,9 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
......@@ -58,7 +61,7 @@ class TRTConvertValidation {
public:
TRTConvertValidation() = delete;
TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) {
explicit TRTConvertValidation(int batch_size, int workspace_size = 1024) {
// create engine.
engine_.reset(new TensorRTEngine(10, 1 << 10, &stream_));
engine_->InitNetwork();
......
if(WITH_DISTRIBUTE)
grpc_library(sendrecvop_grpc SRCS bytebuffer_stream.cc sendrecvop_utils.cc grpc_client.cc
grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor selected_rows)
request_handler_impl.cc rpc_server.cc grpc_server.cc variable_response.cc PROTO send_recv.proto DEPS lod_tensor
selected_rows memory)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(serde_test.cc grpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(serde_test SRCS serde_test.cc variable_response.cc DEPS grpc++_unsecure grpc_unsecure gpr
......
......@@ -205,6 +205,8 @@ void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
}
bool RPCClient::Wait() {
VLOG(3) << "RPCClient begin Wait()"
<< " req_count_:" << req_count_;
if (req_count_ <= 0) {
return true;
}
......
......@@ -14,6 +14,8 @@ limitations under the License. */
#pragma once
#include <map>
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
......@@ -28,6 +30,8 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/grpc_service.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/send_recv.grpc.pb.h"
#include "paddle/fluid/operators/detail/send_recv.pb.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
......@@ -37,106 +41,48 @@ namespace paddle {
namespace operators {
namespace detail {
typedef std::pair<std::string, std::shared_ptr<VariableResponse>>
ReceivedMessage;
typedef framework::BlockingQueue<ReceivedMessage> ReceivedQueue;
typedef std::pair<std::string, sendrecv::VariableMessage> MessageWithName;
class RequestBase;
class AsyncGRPCServer final {
class AsyncGRPCServer final : public RPCServer {
public:
explicit AsyncGRPCServer(const std::string &address, bool sync_mode)
: address_(address), sync_mode_(sync_mode), ready_(0) {}
~AsyncGRPCServer() {}
void WaitServerReady();
void RunSyncUpdate();
// functions to sync server barrier status.
void WaitCond(int cond);
void SetCond(int cond);
void WaitClientGet(int count);
void SetScope(framework::Scope *scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc *program) { program_ = program; }
void SetExecutor(framework::Executor *executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
}
int GetSelectedPort() const { return selected_port_; }
const ReceivedMessage Get() { return this->var_recv_queue_.Pop(); }
explicit AsyncGRPCServer(const std::string& address, int client_num)
: RPCServer(address, client_num), ready_(0) {}
void Push(const std::string &msg_name) {
this->var_recv_queue_.Push(std::make_pair(msg_name, nullptr));
}
virtual ~AsyncGRPCServer() {}
void WaitServerReady() override;
void StartServer() override;
void ShutDown();
private:
void HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne);
protected:
void HandleRequest(::grpc::ServerCompletionQueue *cq,
const std::string &cq_name,
std::function<void(int)> TryToRegisterNewOne);
void TryToRegisterNewSendOne(int req_id);
void TryToRegisterNewGetOne(int req_id);
void TryToRegisterNewPrefetchOne(int req_id);
void TryToRegisterNewOne(const std::string& rpc_name, int req_id);
void ShutdownQueue();
void ShutDownImpl() override;
private:
static const int kSendReqsBufSize = 100;
static const int kGetReqsBufSize = 100;
static const int kPrefetchReqsBufSize = 10;
static const int kRequestBufSize = 100;
std::mutex cq_mutex_;
volatile bool is_shut_down_ = false;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_send_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_get_;
std::unique_ptr<::grpc::ServerCompletionQueue> cq_prefetch_;
RequestBase *send_reqs_[kSendReqsBufSize];
RequestBase *get_reqs_[kGetReqsBufSize];
RequestBase *prefetch_reqs_[kPrefetchReqsBufSize];
GrpcService::AsyncService service_;
std::unique_ptr<::grpc::Server> server_;
std::string address_;
const bool sync_mode_;
framework::Scope *scope_;
const platform::DeviceContext *dev_ctx_;
// received variable from RPC, operators fetch variable from this queue.
framework::BlockingQueue<MessageWithName> var_get_queue_;
// client send variable to this queue.
ReceivedQueue var_recv_queue_;
// condition of the sub program
std::mutex barrier_mutex_;
mutable int barrier_cond_step_;
std::condition_variable barrier_condition_;
std::vector<std::unique_ptr<std::thread>> t_sends_;
std::vector<std::unique_ptr<std::thread>> t_gets_;
std::vector<std::unique_ptr<std::thread>> t_prefetchs_;
std::unique_ptr<std::thread> t_prefetch_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
framework::ProgramDesc *program_;
framework::Executor *executor_;
int selected_port_;
std::mutex mutex_ready_;
std::condition_variable condition_ready_;
int ready_;
std::map<std::string, std::unique_ptr<::grpc::ServerCompletionQueue>> rpc_cq_;
std::map<std::string, std::vector<std::unique_ptr<std::thread>>> rpc_threads_;
std::map<std::string, std::vector<RequestBase*>> rpc_reqs_;
};
}; // namespace detail
......
......@@ -24,13 +24,16 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace detail = paddle::operators::detail;
USE_OP(lookup_table);
std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
......@@ -88,8 +91,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
}
}
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true));
void StartServer() {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
......@@ -99,42 +101,59 @@ void StartServer(const std::string& endpoint) {
auto prepared = exe.Prepare(program, block->ID());
InitTensorsOnServer(&scope, &place, 10);
rpc_service_->SetProgram(&program);
rpc_service_->SetPrefetchPreparedCtx(std::move(prepared));
rpc_service_->SetDevCtx(&ctx);
rpc_service_->SetScope(&scope);
rpc_service_->SetExecutor(&exe);
g_req_handler->SetProgram(&program);
g_req_handler->SetPrefetchPreparedCtx(std::move(prepared));
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(detail::kRequestPrefetch, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
rpc_service_->RunSyncUpdate();
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
// FIXME(gongwb): don't use hard time.
sleep(10);
LOG(INFO) << "got nccl id and stop server...";
g_rpc_service->ShutDown();
server_thread.join();
}
TEST(PREFETCH, DISABLED_CPU) {
// start up a server instance backend
std::thread server_thread(StartServer, "127.0.0.1:8889");
sleep(2);
TEST(PREFETCH, CPU) {
g_req_handler.reset(new detail::RequestPrefetchHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
detail::RPCClient client;
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
{
// create var on local scope
int64_t rows_numel = 5;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("ids");
std::string out_var_name("out");
auto client = detail::RPCClient::GetInstance();
client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client->Wait();
client.AsyncPrefetchVariable(ep, ctx, scope, in_var_name, out_var_name);
client.Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value();
auto ptr = value.mutable_data<float>(place);
rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[0 + i * value.dims()[1]], static_cast<float>(i * 2));
}
}
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
class RPCServer;
class RequestHandler {
public:
explicit RequestHandler(bool sync_mode)
: sync_mode_(sync_mode),
dev_ctx_(nullptr),
executor_(nullptr),
scope_(nullptr),
program_(nullptr),
rpc_server_(nullptr) {}
virtual ~RequestHandler() {}
// Set attributes.
void SetScope(framework::Scope* scope) { scope_ = scope; }
void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
void SetProgram(framework::ProgramDesc* program) { program_ = program; }
void SetExecutor(framework::Executor* executor) { executor_ = executor; }
void SetPrefetchPreparedCtx(
std::unique_ptr<framework::ExecutorPrepareContext> prepared) {
prefetch_ctx_.reset(prepared.release());
}
// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
grad_to_prepared_ctx_ = g;
}
void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }
// Get attributes.
bool sync_mode() { return sync_mode_; }
framework::Scope* scope() { return scope_; }
const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
framework::ExecutorPrepareContext* prefetch_ctx() {
return prefetch_ctx_.get();
}
framework::ProgramDesc* program() { return program_; }
framework::Executor* executor() { return executor_; }
std::vector<framework::Variable*>& sparse_vars() { return sparse_vars_; }
// This function processes user's rpc request.
// The implemention is in request_handler_impl.
// example:
// std::string varname = request_.varname();
//
// auto scope = request_handler_->scope();
// auto invar = scope->FindVar(varname);
// framework::Variable* outvar = nullptr;
//
// request_handler_->Handle(varname, scope, invar, &outvar);
// if (outvar) {
// SerializeToByteBuffer(varname, outvar,
// *request_handler_->dev_ctx(), &reply_);
// }
virtual bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var,
framework::Variable** outvar) = 0;
protected:
const bool sync_mode_;
const platform::DeviceContext* dev_ctx_;
framework::Executor* executor_;
framework::Scope* scope_;
framework::ProgramDesc* program_;
std::unique_ptr<framework::ExecutorPrepareContext> prefetch_ctx_;
// Used for async.
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
grad_to_prepared_ctx_;
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable*> sparse_vars_;
RPCServer* rpc_server_;
std::mutex sparse_var_mutex_;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
namespace paddle {
namespace operators {
namespace detail {
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestSendHandler:" << varname;
// Async
if (!sync_mode_) {
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true;
}
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv batch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else {
VLOG(3) << "sync: received var_name: " << varname;
if (sync_mode_) {
rpc_server_->WaitCond(kRequestSend);
}
if (invar == nullptr) {
LOG(ERROR) << "sync: Can not find server side var: " << varname;
PADDLE_THROW("sync: Can not find server side var");
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(sparse_var_mutex_);
sparse_vars_.push_back(invar);
}
}
return true;
}
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestGetHandler:" << varname;
if (varname != FETCH_BARRIER_MESSAGE) {
if (sync_mode_) {
rpc_server_->WaitCond(kRequestGet);
}
*outvar = scope_->FindVar(varname);
return true;
}
// FETCH_BARRIER_MESSAGE
if (sync_mode_) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar) {
VLOG(4) << "RequestPrefetchHandler " << varname;
auto var_desc = program_->Block(0).FindVar(varname);
*outvar = scope->FindVar(varname);
InitializeVariable(*outvar, var_desc->GetType());
executor_->RunPreparedContext(prefetch_ctx_.get(), scope);
return true;
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <time.h>
#include <functional>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
namespace paddle {
namespace operators {
namespace detail {
class RequestSendHandler final : public RequestHandler {
public:
explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestSendHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
class RequestGetHandler final : public RequestHandler {
public:
explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestGetHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
class RequestPrefetchHandler final : public RequestHandler {
public:
explicit RequestPrefetchHandler(bool sync_mode) : RequestHandler(sync_mode) {}
virtual ~RequestPrefetchHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar) override;
};
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include <iostream>
#include <limits>
#include <string>
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle {
namespace operators {
namespace detail {
void RPCServer::ShutDown() {
LOG(INFO) << "RPCServer ShutDown ";
ShutDownImpl();
exit_flag_ = true;
barrier_cond_.notify_all();
rpc_cond_.notify_all();
}
void RPCServer::SavePort() const {
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_;
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void RPCServer::WaitBarrier(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [=] {
return (barrier_counter_[rpc_name] >= client_num_ || exit_flag_.load());
});
VLOG(3) << "batch_barrier_:" << barrier_counter_[rpc_name];
}
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
int b = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name];
}
VLOG(3) << "RPCServer IncreaseBatchBarrier " << rpc_name
<< ", barrier_count:" << b << ", fan_in" << client_num_;
if (b >= client_num_) {
barrier_cond_.notify_all();
}
}
void RPCServer::ResetBarrierCounter() {
VLOG(3) << "RPCServer ResetBarrierCounter ";
std::unique_lock<std::mutex> lock(mutex_);
for (auto& t : barrier_counter_) {
t.second = 0;
}
}
void RPCServer::RegisterRPC(const std::string& rpc_name,
RequestHandler* handler, int thread_num) {
rpc_call_map_[rpc_name] = handler;
rpc_thread_num_[rpc_name] = thread_num;
static int cond = -1;
rpc_cond_map_[rpc_name] = ++cond;
VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
<< ", cond:" << rpc_cond_map_[rpc_name];
}
void RPCServer::SetCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer SetCond " << rpc_name;
{
std::unique_lock<std::mutex> lock(mutex_);
cur_cond_ = rpc_cond_map_[rpc_name];
}
rpc_cond_.notify_all();
}
void RPCServer::WaitCond(const std::string& rpc_name) {
VLOG(3) << "RPCServer WaitCond " << rpc_name;
int cond = 0;
{
std::unique_lock<std::mutex> lock(mutex_);
cond = rpc_cond_map_[rpc_name];
}
std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
}
} // namespace detail
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <set>
#include <string>
#include <thread> // NOLINT
#include <utility>
#include <vector>
#include "paddle/fluid/operators/detail/request_handler.h"
namespace paddle {
namespace operators {
namespace detail {
class RPCServer {
public:
explicit RPCServer(const std::string& address, int client_num)
: cur_cond_(0),
bind_address_(address),
exit_flag_(false),
selected_port_(0),
client_num_(client_num) {}
virtual ~RPCServer() {}
virtual void StartServer() = 0;
virtual void WaitServerReady() = 0;
void ShutDown();
bool IsExit() { return exit_flag_.load(); }
int GetSelectedPort() const { return selected_port_; }
void SavePort() const;
// RegisterRPC, register the rpc method name to a handler
// class, and auto generate a condition id for this call
// to be used for the barrier.
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5);
// Wait util all the clients have reached the barrier for one
// rpc method. This function should be called in the
// RequestHandler if you want to run the server/client in a
// synchronous mode.
void WaitBarrier(const std::string& rpc_name);
void SetCond(const std::string& rpc_name);
void WaitCond(const std::string& rpc_name);
void IncreaseBatchBarrier(const std::string rpc_name);
void ResetBarrierCounter();
protected:
virtual void ShutDownImpl() = 0;
private:
std::mutex mutex_;
std::unordered_map<std::string, int> barrier_counter_;
std::condition_variable barrier_cond_;
std::unordered_map<std::string, int> rpc_cond_map_;
std::atomic<int> cur_cond_;
std::condition_variable rpc_cond_;
protected:
std::string bind_address_;
std::atomic<int> exit_flag_;
int selected_port_;
const int client_num_;
std::unordered_map<std::string, RequestHandler*> rpc_call_map_;
std::unordered_map<std::string, int> rpc_thread_num_;
friend class RequestHandler;
};
}; // namespace detail
}; // namespace operators
}; // namespace paddle
......@@ -67,8 +67,8 @@ class VariableResponse {
framework::Scope* GetMutableLocalScope() const { return local_scope_; }
inline std::string Varname() { return meta_.varname(); }
inline std::string OutVarname() { return meta_.out_varname(); }
inline std::string Varname() const { return meta_.varname(); }
inline std::string OutVarname() const { return meta_.out_varname(); }
// should call parse first.
framework::Variable* GetVar() {
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
......@@ -75,19 +76,23 @@ class GenNCCLIdOp : public framework::OperatorBase {
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
detail::AsyncGRPCServer rpc_service(endpoint, true);
detail::RequestSendHandler rpc_h(true);
detail::AsyncGRPCServer rpc_service(endpoint, 1);
rpc_service.RegisterRPC(detail::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(&rpc_service);
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_service.SetScope(scope);
rpc_service.SetDevCtx(&dev_ctx);
rpc_service.SetProgram(&empty_program);
rpc_service.SetExecutor(&executor);
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, &rpc_service));
rpc_service.SetCond(0);
std::bind(&detail::AsyncGRPCServer::StartServer, &rpc_service));
rpc_service.SetCond(detail::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
auto recv = rpc_service.Get();
rpc_service.WaitBarrier(detail::kRequestSend);
VLOG(3) << "got nccl id and stop server...";
rpc_service.ShutDown();
VLOG(3) << "rpc server stopped";
......
......@@ -19,14 +19,16 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
service->RunSyncUpdate();
void RunServer(std::shared_ptr<detail::RPCServer> service) {
service->StartServer();
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
......@@ -67,8 +69,6 @@ static void ParallelExecuteBlocks(
for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}
std::atomic_int ListenAndServOp::selected_port_{0};
ListenAndServOp::ListenAndServOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
......@@ -78,7 +78,6 @@ ListenAndServOp::ListenAndServOp(const std::string &type,
ListenAndServOp::~ListenAndServOp() { Stop(); }
void ListenAndServOp::Stop() {
rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
rpc_service_->ShutDown();
server_thread_->join();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
......@@ -87,26 +86,13 @@ void ListenAndServOp::Stop() {
void ListenAndServOp::SavePort() const {
// NOTE: default write file to /tmp/paddle.selected_port
selected_port_ = rpc_service_->GetSelectedPort();
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
std::ofstream port_file;
port_file.open(file_path);
port_file << selected_port_.load();
port_file.close();
VLOG(4) << "selected port written to " << file_path;
}
void ListenAndServOp::WaitServerReady() {
while (selected_port_.load() == 0) {
}
rpc_service_->SavePort();
}
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope,
framework::BlockDesc *prefetch_block) const {
auto fan_in = Attr<int>("Fanin");
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
......@@ -121,49 +107,24 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
optimize_prepared.begin(),
std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));
bool exit_flag = false;
rpc_service_->ResetBarrierCounter();
// Record received sparse variables, so that
// we could reset those after execute optimize program
std::vector<framework::Variable *> sparse_vars;
while (!exit_flag && !SignalHandler::IsProgramExit()) {
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
size_t recv_var_cnt = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
} else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
} else {
VLOG(3) << "received grad: " << recv_var_name;
recv_var_cnt++;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
if (var->IsType<framework::SelectedRows>()) {
sparse_vars.push_back(var);
}
}
}
if (exit_flag) {
rpc_service_->SetCond(1);
rpc_service_->ShutDown();
rpc_service_->SetCond(detail::kRequestSend);
rpc_service_->WaitBarrier(detail::kRequestSend);
if (rpc_service_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(detail::kRequestGet);
break;
}
// NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
// and this will still work.
// The optimize blocks which have the same parent ID would run parallel
// TODO(Yancey1989): need to use ParallelExecutor for future
int32_t last_parent_blkid = program->Block(1).Parent();
......@@ -194,52 +155,18 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
}
rpc_service_->SetCond(1);
// FIXME(typhoonzero): use another condition to sync wait clients get.
rpc_service_->WaitClientGet(fan_in);
sparse_vars.clear();
rpc_service_->SetCond(detail::kRequestGet);
rpc_service_->WaitBarrier(detail::kRequestGet);
rpc_service_->ResetBarrierCounter();
} // while(true)
}
static void AsyncUpdateThread(
const std::string &var_name, const bool &exit_flag,
const std::shared_ptr<detail::ReceivedQueue> &queue,
framework::Executor *executor,
framework::ExecutorPrepareContext *prepared) {
VLOG(3) << "update thread for " << var_name << " started";
while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = queue->Pop();
if (SignalHandler::IsProgramExit()) {
VLOG(3) << "update thread for " << var_name << " exit";
break;
}
auto recv_var_name = v.first;
VLOG(4) << "async update " << recv_var_name;
auto var = v.second->GetVar();
if (var == nullptr) {
LOG(ERROR) << "Can not find server side var: " << recv_var_name;
PADDLE_THROW("Can not find server side var");
}
auto fs = framework::Async([var_name, &executor, &v, prepared] {
try {
executor->RunPreparedContext(prepared,
v.second->GetMutableLocalScope());
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
});
fs.wait();
}
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program) const {
VLOG(3) << "RunAsyncLoop in";
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
grad_to_queue;
auto grad_to_block_id_str =
Attr<std::vector<std::string>>("grad_to_block_id");
......@@ -249,13 +176,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
PADDLE_ENFORCE_EQ(pieces.size(), 2);
PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
int block_id = std::stoi(pieces[1]);
grad_to_block_id[pieces[0]] = block_id;
std::shared_ptr<detail::ReceivedQueue> queue =
std::make_shared<detail::ReceivedQueue>();
grad_to_queue[pieces[0]] = queue;
// record blocking queue in SignalHandler
SignalHandler::RegisterBlockingQueue(queue);
id_to_grad[block_id] = pieces[0];
}
size_t num_blocks = program->Size();
......@@ -274,39 +197,36 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
}
bool exit_flag = false;
request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
VLOG(3) << "start async optimize threads";
std::vector<std::future<void>> fs;
for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
std::string grad_name = iter->first;
VLOG(3) << "create async update thread for " << grad_name;
fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
&grad_to_queue, &grad_to_prepared_ctx]() {
AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
executor, grad_to_prepared_ctx[grad_name].get());
}));
}
VLOG(3) << "RunAsyncLoop into while";
while (!exit_flag && !SignalHandler::IsProgramExit()) {
const detail::ReceivedMessage v = rpc_service_->Get();
auto recv_var_name = v.first;
if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
while (true) {
if (rpc_service_->IsExit()) {
LOG(INFO) << "get exit!rpc_processor break!";
break;
} else {
VLOG(3) << "received grad: " << recv_var_name;
grad_to_queue[recv_var_name]->Push(v);
}
if (exit_flag) {
rpc_service_->ShutDown();
break;
}
sleep(1);
} // while(true)
}
static void FillRequestCtx(detail::RequestHandler *h, framework::Scope *scope,
platform::DeviceContext *dev_ctx,
framework::Executor *executor,
framework::ProgramDesc *program,
framework::ExecutorPrepareContext *prefetch_ctx,
detail::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
h->SetExecutor(executor);
h->SetProgram(program);
h->SetPrefetchPreparedCtx(std::move(
std::unique_ptr<framework::ExecutorPrepareContext>(prefetch_ctx)));
h->SetRPCServer(rpc_server);
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const {
// Mark this as PS that it should decide profiling by listening from trainer.
......@@ -316,27 +236,42 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
framework::Scope &recv_scope = scope.NewScope();
bool sync_mode = Attr<bool>("sync_mode");
auto fan_in = Attr<int>("Fanin");
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));
LOG(INFO) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint;
// request_handler_.reset(new detail::GRPCRequestSendHandler(sync_mode));
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, fan_in));
request_send_handler_.reset(new detail::RequestSendHandler(sync_mode));
request_get_handler_.reset(new detail::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new detail::RequestPrefetchHandler(sync_mode));
rpc_service_->RegisterRPC(detail::kRequestSend, request_send_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestGet, request_get_handler_.get());
rpc_service_->RegisterRPC(detail::kRequestPrefetch,
request_prefetch_handler_.get());
auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
auto *program = optimize_block->Program();
framework::Executor executor(dev_place);
// prepare rpc_service
rpc_service_->SetScope(&recv_scope);
rpc_service_->SetDevCtx(&dev_ctx);
rpc_service_->SetProgram(program);
rpc_service_->SetExecutor(&executor);
// prepare for prefetch
VLOG(3) << "prefetch block id is " << prefetch_block->ID();
auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));
auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope,
&dev_ctx, &executor, program, prefetch_prepared.release(),
rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......@@ -348,8 +283,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGTERM, SignalHandler::StopAndExit);
// Write to a file of server selected port for python use.
std::string file_path = string::Sprintf("/tmp/paddle.%d.selected_port",
static_cast<int>(::getpid()));
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
......@@ -385,27 +318,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
bool SignalHandler::program_exit_flag_ = false;
SignalHandler::BlockingQueueSet SignalHandler::blocking_queue_set_{};
void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
program_exit_flag_ = true;
// awake all blocking queues
for (BlockingQueueSet::iterator iter = blocking_queue_set_.begin();
iter != blocking_queue_set_.end(); iter++) {
iter->get()->Push(
std::make_pair(std::string(LISTEN_TERMINATE_MESSAGE), nullptr));
}
exit(EXIT_SUCCESS);
}
void SignalHandler::RegisterBlockingQueue(BlockingQueue &queue) {
blocking_queue_set_.insert(queue);
exit(0);
}
} // namespace operators
......
......@@ -23,7 +23,8 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler.h"
#include "paddle/fluid/operators/detail/rpc_server.h"
namespace paddle {
namespace operators {
......@@ -31,7 +32,7 @@ namespace operators {
constexpr char kOptimizeBlock[] = "OptimizeBlock";
constexpr char kPrefetchBlock[] = "PrefetchBlock";
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service);
void RunServer(std::shared_ptr<detail::RPCServer> service);
class ListenAndServOp : public framework::OperatorBase {
public:
......@@ -52,41 +53,27 @@ class ListenAndServOp : public framework::OperatorBase {
void SavePort() const;
void WaitServerReady();
int GetSelectedPort() { return selected_port_; }
int GetSelectedPort() { return rpc_service_->GetSelectedPort(); }
void Stop() override;
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override;
static void ResetPort() { selected_port_ = 0; }
protected:
mutable std::shared_ptr<detail::AsyncGRPCServer> rpc_service_;
mutable std::shared_ptr<detail::RPCServer> rpc_service_;
mutable std::shared_ptr<detail::RequestHandler> request_send_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_get_handler_;
mutable std::shared_ptr<detail::RequestHandler> request_prefetch_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
// FIXME(wuyi): it's static so that the operator can be cloned.
static std::atomic_int selected_port_;
};
class SignalHandler {
public:
typedef std::shared_ptr<detail::ReceivedQueue> BlockingQueue;
typedef std::unordered_set<BlockingQueue> BlockingQueueSet;
public:
static void StopAndExit(int signal_num);
static void RegisterBlockingQueue(BlockingQueue&);
static inline bool IsProgramExit() { return program_exit_flag_; }
private:
static bool program_exit_flag_;
static BlockingQueueSet blocking_queue_set_;
DISABLE_COPY_AND_ASSIGN(SignalHandler);
};
......
......@@ -46,6 +46,8 @@ class SendBarrierOp : public framework::OperatorBase {
auto rpc_client = detail::RPCClient::GetInstance();
VLOG(3) << "SendBarrierOp sync_mode:" << sync_mode;
// need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
......
......@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/detail/grpc_server.h"
#include "paddle/fluid/operators/detail/request_handler_impl.h"
#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
......@@ -35,42 +37,44 @@ namespace m = paddle::operators::math;
namespace detail = paddle::operators::detail;
namespace string = paddle::string;
std::unique_ptr<detail::AsyncGRPCServer> rpc_service;
std::unique_ptr<detail::AsyncGRPCServer> g_rpc_service;
std::unique_ptr<detail::RequestHandler> g_req_handler;
void StartServer(std::atomic<bool>* initialized) {
void StartServer() {
f::Scope scope;
p::CPUPlace place;
scope.Var(NCCL_ID_VARNAME);
p::DeviceContextPool& pool = p::DeviceContextPool::Instance();
auto& dev_ctx = *pool.Get(p::CPUPlace());
rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", true));
f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace());
rpc_service->SetScope(&scope);
rpc_service->SetDevCtx(&dev_ctx);
rpc_service->SetProgram(&empty_program);
rpc_service->SetExecutor(&executor);
g_req_handler->SetScope(&scope);
g_req_handler->SetDevCtx(&dev_ctx);
g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor);
g_rpc_service->RegisterRPC(detail::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&detail::AsyncGRPCServer::RunSyncUpdate, rpc_service.get()));
*initialized = true;
rpc_service->SetCond(0);
auto recv = rpc_service->Get();
std::bind(&detail::AsyncGRPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(detail::kRequestSend);
std::cout << "before WaitFanInOfSend" << std::endl;
g_rpc_service->WaitBarrier(detail::kRequestSend);
LOG(INFO) << "got nccl id and stop server...";
rpc_service->ShutDown();
g_rpc_service->ShutDown();
server_thread.join();
}
TEST(SendNcclId, DISABLED_Normal) {
std::atomic<bool> initialized{false};
std::thread server_thread(StartServer, &initialized);
while (!initialized) {
}
// wait server to start
// sleep(2);
rpc_service->WaitServerReady();
TEST(SendNcclId, GrpcServer) {
g_req_handler.reset(new detail::RequestSendHandler(true));
g_rpc_service.reset(new detail::AsyncGRPCServer("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
g_rpc_service->WaitServerReady();
f::Scope scope;
p::CPUPlace place;
......@@ -78,17 +82,20 @@ TEST(SendNcclId, DISABLED_Normal) {
auto& dev_ctx = *pool.Get(p::CPUPlace());
auto var = scope.Var(NCCL_ID_VARNAME);
// var->SetType(f::proto::VarType_Type_RAW);
auto id = var->GetMutable<ncclUniqueId>();
p::dynload::ncclGetUniqueId(id);
int port = rpc_service->GetSelectedPort();
int port = g_rpc_service->GetSelectedPort();
std::string ep = string::Sprintf("127.0.0.1:%d", port);
detail::RPCClient client;
LOG(INFO) << "connect to server" << ep;
client.AsyncSendVariable(ep, dev_ctx, scope, NCCL_ID_VARNAME);
client.Wait();
client.AsyncSendBatchBarrier(ep);
client.Wait();
server_thread.join();
auto* ptr = rpc_service.release();
delete ptr;
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
......@@ -15,6 +15,7 @@
#pragma once
#include <stdio.h>
#include <string>
#include <thread> // NOLINT
#include <typeindex>
#include <vector>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册