未验证 提交 be6a315f 编写于 作者: T tangwei12 提交者: GitHub

Fix/sync barrier (#25016)

* fix sync barrier with barrier monitor, test=develop
上级 8db66fc3
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -24,6 +21,7 @@ limitations under the License. */ ...@@ -24,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
...@@ -95,22 +93,39 @@ class CGenNCCLIdOp : public framework::OperatorBase { ...@@ -95,22 +93,39 @@ class CGenNCCLIdOp : public framework::OperatorBase {
new RPCSERVER_T(endpoint, 1)); new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get()); distributed::RequestNotifyHandler notify_h(
distributed::DistributedMode::kSync, -1);
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_service->RegisterRPC(distributed::kRequestNotify, &notify_h);
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace()); framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetRPCServer(rpc_service.get());
rpc_h.SetScope(scope); rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx); rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program); rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
notify_h.SetRPCServer(rpc_service.get());
notify_h.SetScope(scope);
notify_h.SetDevCtx(&dev_ctx);
notify_h.SetProgram(&empty_program);
notify_h.SetExecutor(&executor);
distributed::BarrierMonitor::Init(1);
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread( std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0..."; VLOG(3) << "start getting nccl id from trainer 0...";
rpc_service->WaitBarrier(distributed::kRequestSend); barrier->WaitServerWeakup();
barrier->ServerWeakup();
VLOG(3) << "got nccl id and stop server..."; VLOG(3) << "got nccl id and stop server...";
barrier->Stop();
rpc_service->ShutDown(); rpc_service->ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
server_thread.join(); server_thread.join();
...@@ -123,7 +138,6 @@ class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -123,7 +138,6 @@ class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces."); AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC( AddComment(R"DOC(
CGenNCCLId operator CGenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers. For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC"); )DOC");
......
...@@ -15,6 +15,8 @@ cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_r ...@@ -15,6 +15,8 @@ cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_r
cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool) cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor) cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
cc_library(barrier_monitor SRCS barrier_monitor.cc DEPS enforce simple_threadpool trainer_desc_proto)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files # FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC) if(WITH_GRPC)
...@@ -26,7 +28,7 @@ if(WITH_GRPC) ...@@ -26,7 +28,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc collective_client.cc collective_server.cc
${GRPC_SRCS} ${GRPC_SRCS}
PROTO send_recv.proto PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor) DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor barrier_monitor)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS}) set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
bool BarrierMonitor::IncreaseBarrier(const int worker_id,
const std::string &barrier) {
release_ = false;
if (barrier == BATCH_BARRIER_MESSAGE) {
VLOG(4) << "BarrierMonitor send queue recv trainer: " << worker_id;
send_barrier_queue->Push(worker_id);
} else if (barrier == FETCH_BARRIER_MESSAGE) {
VLOG(4) << "BarrierMonitor recv queue recv trainer: " << worker_id;
recv_barrier_queue->Push(worker_id);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"unknown Message status %s, only "
"BATCH_BARRIER_MESSAGE/FETCH_BARRIER_MESSAGE",
barrier));
}
return Wait();
}
void BarrierMonitor::DecreaseWorker() {
std::unique_lock<std::mutex> lck(mutex_);
workers_--;
VLOG(1) << "decrement worker num to " << workers_;
}
void BarrierMonitor::Reset(int workers, BarrierType type) {
std::unique_lock<std::mutex> lk(server_mutex_);
workers_ = workers;
barrier_type = type;
send_barrier_queue->Clear();
recv_barrier_queue->Clear();
VLOG(2) << "reset monitor workers: " << workers_ << " type: " << barrier_type;
}
void BarrierMonitor::Monitor() {
while (!IsReady() && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
VLOG(3) << "sync at first time, wait all trainer ready";
}
while (running_) {
int timer = 0;
if (IsReady()) {
Swap(true);
} else {
VLOG(4) << "running timer: " << timer << " barrier: " << barrier_type
<< " sendQ:" << send_barrier_queue->Size()
<< " recvQ: " << recv_barrier_queue->Size();
timer++;
if (max_wait_ms == -1 || timer < max_wait_ms) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
} else {
VLOG(1) << "time out of " << max_wait_ms
<< ", need barreir: " << barrier_type << " retry";
Swap(false);
}
}
}
}
bool BarrierMonitor::IsReady() {
if (barrier_type == BarrierType::kSendBarrier) {
return static_cast<int>(send_barrier_queue->Size()) == workers_;
} else {
return static_cast<int>(recv_barrier_queue->Size()) == workers_;
}
}
void BarrierMonitor::Swap(bool is_valid) {
std::unique_lock<std::mutex> lck(mutex_);
valid_ = is_valid;
release_ = true;
if (barrier_type == BarrierType::kSendBarrier) {
barrier_type = BarrierType::kRecvBarrier;
send_barrier_queue->Clear();
VLOG(4) << "barrier monitor server clean up queue and barrier";
ServerWeakup();
VLOG(4) << "barrier monitor server weak up sync to do";
WaitServerWeakup();
VLOG(4) << "barrier monitor server weak up sync done";
} else {
barrier_type = BarrierType::kSendBarrier;
recv_barrier_queue->Clear();
VLOG(4) << "barrier monitor server switch to send barrier";
}
worker_cv_.notify_all();
}
void BarrierMonitor::Stop() {
valid_ = true;
release_ = true;
running_ = false;
barrier_type = BarrierType::kRecvBarrier;
send_barrier_queue->Clear();
recv_barrier_queue->Clear();
worker_cv_.notify_all();
server_cv_.notify_all();
if (monitor_thread_) monitor_thread_->join();
monitor_thread_ = nullptr;
}
bool BarrierMonitor::Wait() {
std::unique_lock<std::mutex> lk(mutex_);
worker_cv_.wait(lk, [this] { return (release_); });
return valid_;
}
void BarrierMonitor::WaitServerWeakup() {
std::unique_lock<std::mutex> lk(server_mutex_);
server_cv_.wait(lk);
}
void BarrierMonitor::ServerWeakup() { server_cv_.notify_all(); }
std::once_flag BarrierMonitor::init_flag_;
std::unique_ptr<BarrierMonitor> BarrierMonitor::monitor_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <chrono> // NOLINT
#include <deque>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
enum BarrierType { kSendBarrier, kRecvBarrier };
constexpr int64_t kMaxWaitMS = 120000;
template <typename T>
class BlockingQueueForBarrier {
public:
explicit BlockingQueueForBarrier(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
bool Push(const T &elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.push_back(elem);
}
worker_cv_.notify_one();
return true;
}
bool Push(T &&elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.emplace_back(std::move(elem));
}
worker_cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
worker_cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
std::deque<T>().swap(queue_);
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable worker_cv_;
};
class BarrierMonitor {
public:
explicit BarrierMonitor(int workers)
: BarrierMonitor(workers, BarrierType::kRecvBarrier, kMaxWaitMS) {}
explicit BarrierMonitor(int workers, BarrierType type, int64_t max_wait_times)
: workers_(workers), barrier_type(type), max_wait_ms(max_wait_times) {
PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument(
"trainers must have one or more"));
send_barrier_queue =
std::make_shared<BlockingQueueForBarrier<int>>(workers);
recv_barrier_queue =
std::make_shared<BlockingQueueForBarrier<int>>(workers);
running_ = true;
monitor_thread_.reset(
new std::thread(std::bind(&BarrierMonitor::Monitor, this)));
}
static BarrierMonitor *Init(int workers) {
InitImpl(workers);
return GetInstance();
}
static BarrierMonitor *GetInstance() { return monitor_.get(); }
bool IncreaseBarrier(const int worker_id, const std::string &barrier);
void DecreaseWorker();
int GetWorkerNum() { return workers_; }
void Monitor();
void Swap(bool is_valid);
void Stop();
bool IsReady();
bool Wait();
void WaitServerWeakup();
void ServerWeakup();
void WorkerWeakup();
void Reset(int workers, BarrierType type);
private:
// Init is called by GetInstance.
static void InitImpl(int workers) {
monitor_.reset(new BarrierMonitor(workers));
}
static std::once_flag init_flag_;
static std::unique_ptr<BarrierMonitor> monitor_;
int workers_;
bool running_ = false;
bool valid_ = false;
bool release_ = false;
std::condition_variable worker_cv_;
std::condition_variable server_cv_;
std::mutex server_mutex_;
std::mutex mutex_;
BarrierType barrier_type;
int64_t max_wait_ms;
std::unique_ptr<std::thread> monitor_thread_{nullptr};
std::shared_ptr<BlockingQueueForBarrier<int>> send_barrier_queue;
std::shared_ptr<BlockingQueueForBarrier<int>> recv_barrier_queue;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -306,52 +306,19 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -306,52 +306,19 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); auto h = AsyncDistributeNotify(ep, ctx, *scope, BATCH_BARRIER_MESSAGE);
const std::string method = kBatchBarrierRPC; delete scope;
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h; return h;
} }
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); platform::CPUDeviceContext ctx;
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); auto* scope = new framework::Scope();
const std::string method = kFetchBarrierRPC; auto h = AsyncDistributeNotify(ep, ctx, *scope, FETCH_BARRIER_MESSAGE);
VarHandlePtr h( delete scope;
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h; return h;
} }
...@@ -384,27 +351,10 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep, ...@@ -384,27 +351,10 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); auto h = AsyncDistributeNotify(ep, ctx, *scope, COMPLETE_MESSAGE);
const std::string method = kSendCompleteRPC; delete scope;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_trainer_id(trainer_id_);
req.set_varname(COMPLETE_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h; return h;
} }
...@@ -454,10 +404,21 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify( ...@@ -454,10 +404,21 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
s->Prepare(h, time_out); s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] { framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val); ::grpc::ByteBuffer buf;
::grpc::ByteBuffer req; if (var_name_val == BATCH_BARRIER_MESSAGE ||
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_); var_name_val == FETCH_BARRIER_MESSAGE ||
var_name_val == COMPLETE_MESSAGE) {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(var_name_val);
req.set_trainer_id(trainer_id_);
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
} else {
auto* var = p_scope->FindVar(var_name_val);
SerializeToByteBuffer(var_name_val, var, *p_ctx, &buf, "", trainer_id_);
}
VLOG(3) << s->GetVarHandlePtr()->String() << " begin"; VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
...@@ -467,7 +428,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify( ...@@ -467,7 +428,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
platform::RecordRPCEvent record_event(method); platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req, s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", buf,
&cq_); &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "paddle/fluid/string/split.h" #include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
namespace paddle { namespace paddle {
...@@ -38,161 +39,130 @@ namespace distributed { ...@@ -38,161 +39,130 @@ namespace distributed {
// to directory specified. // to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath"; constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname, bool RequestSendHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestSendHandler:" << varname; VLOG(4) << "RequestSendHandler:" << varname;
// Sync if (invar == nullptr) {
if (varname == BATCH_BARRIER_MESSAGE) { PADDLE_THROW(platform::errors::NotFound(
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; "sync: Can not find server side var: %s", varname));
rpc_server_->IncreaseBatchBarrier(kRequestSend); return false;
} else if (varname == COMPLETE_MESSAGE) { }
VLOG(3) << "sync: recv complete message";
if (HeartBeatMonitor::GetInstance() != nullptr) { if (distributed_mode_ == DistributedMode::kSync) {
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED); return true;
} }
rpc_server_->Complete(); HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
} else {
// Async
if (distributed_mode_ != DistributedMode::kSync) {
VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
std::string run_varname = varname; std::string run_varname = varname;
string::Piece part_piece("@PIECE");
string::Piece var_name_piece = string::Piece(varname);
string::Piece part_piece("@PIECE"); if (string::Contains(var_name_piece, part_piece)) {
string::Piece var_name_piece = string::Piece(varname); auto varname_splits = paddle::string::Split(varname, '@');
run_varname = varname_splits[0];
scope->Rename(varname, run_varname);
}
if (string::Contains(var_name_piece, part_piece)) { if (distributed_mode_ == DistributedMode::kGeo &&
auto varname_splits = paddle::string::Split(varname, '@'); AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
PADDLE_ENFORCE_EQ(varname_splits.size(), 3); auto &grad_slr =
run_varname = varname_splits[0]; scope->FindVar(run_varname)->Get<framework::SelectedRows>();
scope->Rename(varname, run_varname); AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
} grad_slr.rows());
}
if (distributed_mode_ == DistributedMode::kGeo && executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) { scope);
auto& grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
grad_slr.rows());
}
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
VLOG(3) << "sync: processing received var: " << varname;
PADDLE_ENFORCE_NOT_NULL(
invar, platform::errors::NotFound(
"sync: Can not find server side var %s.", varname));
}
}
return true; return true;
} }
bool RequestGetHandler::Handle(const std::string& varname, bool RequestGetHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(3) << "RequestGetHandler:" << varname VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name; << " table_name: " << table_name;
if (distributed_mode_ == DistributedMode::kSync) { if (distributed_mode_ == DistributedMode::kSync) {
if (varname == FETCH_BARRIER_MESSAGE) { *outvar = scope_->FindVar(varname);
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else {
rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname);
}
} else { } else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) { if (enable_dc_asgd_) {
if (enable_dc_asgd_) { // NOTE: the format is determined by distribute_transpiler.py
// NOTE: the format is determined by distribute_transpiler.py std::string param_bak_name =
std::string param_bak_name = string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
string::Sprintf("%s.trainer_%d_bak", varname, trainer_id); VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id; auto var = scope_->FindVar(varname);
auto var = scope_->FindVar(varname); auto t_orig = var->Get<framework::LoDTensor>();
auto t_orig = var->Get<framework::LoDTensor>(); auto param_bak = scope_->Var(param_bak_name);
auto param_bak = scope_->Var(param_bak_name); auto t = param_bak->GetMutable<framework::LoDTensor>();
auto t = param_bak->GetMutable<framework::LoDTensor>(); t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
t->mutable_data(dev_ctx_->GetPlace(), t_orig.type()); VLOG(3) << "copying " << varname << " to " << param_bak_name;
VLOG(3) << "copying " << varname << " to " << param_bak_name; framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t); }
}
VLOG(1) << "Table name empty? " << table_name.empty(); if (distributed_mode_ == DistributedMode::kGeo &&
if (distributed_mode_ == DistributedMode::kGeo) { AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist " !table_name.empty()) {
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam( std::vector<int64_t> updated_rows;
varname); AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
} varname, trainer_id, &updated_rows);
if (distributed_mode_ == DistributedMode::kGeo && if (VLOG_IS_ON(3)) {
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) && std::ostringstream sstream;
!table_name.empty()) { sstream << "[";
std::vector<int64_t> updated_rows; for (auto &row_id : updated_rows) {
AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear( sstream << row_id << ", ";
varname, trainer_id, &updated_rows);
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
*outvar = scope->Var();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto* data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
} }
} else { sstream << "]";
*outvar = scope_->FindVar(varname); VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto &origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto *origin_tensor_data = origin_tensor.data<float>();
auto &dims = origin_tensor.dims();
*outvar = scope->Var();
auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto *data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0],
platform::errors::OutOfRange(
"expected >= 0 and < %ld, but got %ld.", dims[0],
updated_rows[i]));
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
} }
} else {
*outvar = scope_->FindVar(varname);
} }
} }
return true; return true;
} }
bool RequestGetNoBarrierHandler::Handle(const std::string& varname, bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name; << " out_var_name: " << out_var_name;
...@@ -207,18 +177,19 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname, ...@@ -207,18 +177,19 @@ bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
*outvar = scope_->FindVar(var_name_piece.ToString()); *outvar = scope_->FindVar(var_name_piece.ToString());
return true; return true;
} else { } else {
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE); PADDLE_THROW(platform::errors::InvalidArgument(
"GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE));
} }
return true; return true;
} }
bool RequestPrefetchHandler::Handle(const std::string& varname, bool RequestPrefetchHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname; VLOG(4) << "RequestPrefetchHandler " << varname;
if (table_name.empty()) { if (table_name.empty()) {
...@@ -236,19 +207,20 @@ bool RequestPrefetchHandler::Handle(const std::string& varname, ...@@ -236,19 +207,20 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true; return true;
} }
bool RequestCheckpointHandler::Handle(const std::string& varname, bool RequestCheckpointHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
PADDLE_ENFORCE( PADDLE_ENFORCE_NE(
checkpoint_notify_id != -1, checkpoint_notify_id, -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke."); platform::errors::Unavailable(
"when checkpoint_notify_id = -1, there should be no RPC invoke."));
// TODO(tangwei12): find out why scope will be error. // TODO(tangwei12): find out why scope will be error.
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>(); auto *lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear(); lt_var->clear();
lt_var->append(out_var_name); lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: " VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
...@@ -257,33 +229,56 @@ bool RequestCheckpointHandler::Handle(const std::string& varname, ...@@ -257,33 +229,56 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
return true; return true;
} }
bool RequestNotifyHandler::Handle(const std::string& varname, bool RequestNotifyHandler::Handle(const std::string &varname,
framework::Scope* scope, framework::Scope *scope,
framework::Variable* invar, framework::Variable *invar,
framework::Variable** outvar, framework::Variable **outvar,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string &out_var_name,
const std::string& table_name) { const std::string &table_name) {
VLOG(4) << "RequestNotifyHandler: " << varname;
VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id; VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER); string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
string::Piece batch_piece(BATCH_BARRIER_MESSAGE);
string::Piece fetch_piece(FETCH_BARRIER_MESSAGE);
string::Piece complete_piece(COMPLETE_MESSAGE);
string::Piece var_name_piece = string::Piece(varname); string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) {
if (string::Contains(var_name_piece, batch_piece)) {
return BarrierMonitor::GetInstance()->IncreaseBarrier(
trainer_id, BATCH_BARRIER_MESSAGE);
} else if (string::Contains(var_name_piece, fetch_piece)) {
return BarrierMonitor::GetInstance()->IncreaseBarrier(
trainer_id, FETCH_BARRIER_MESSAGE);
} else if (string::Contains(var_name_piece, complete_piece)) {
if (HeartBeatMonitor::GetInstance() != nullptr) {
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
}
rpc_server_->Complete();
BarrierMonitor::GetInstance()->DecreaseWorker();
return true;
} else if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update"; VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
lr_decay_block_id, -1, lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke."); platform::errors::InvalidArgument(
auto* origin_var = scope_->FindVar(varname); "when lr_decay_block_id = -1, there should be no RPC invoke."));
auto *origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>(); auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto* send_var = scope->FindVar(varname); auto *send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>(); auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t* origin_value = int64_t *origin_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place()); origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t* send_value = int64_t *send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place()); send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0]; origin_value[0] += send_value[0];
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
return true;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unkown varname %s with RequestNotifyHandler", varname));
} }
return true; return true;
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -24,6 +21,7 @@ limitations under the License. */ ...@@ -24,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
...@@ -119,6 +117,7 @@ void StartServer(const std::string& rpc_name) { ...@@ -119,6 +117,7 @@ void StartServer(const std::string& rpc_name) {
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get()); g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
distributed::HeartBeatMonitor::Init(2, true, "w@grad"); distributed::HeartBeatMonitor::Init(2, true, "w@grad");
distributed::BarrierMonitor::Init(2);
g_req_handler->SetRPCServer(g_rpc_service.get()); g_req_handler->SetRPCServer(g_rpc_service.get());
...@@ -164,6 +163,9 @@ TEST(PREFETCH, CPU) { ...@@ -164,6 +163,9 @@ TEST(PREFETCH, CPU) {
} }
} }
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Stop();
g_rpc_service->ShutDown(); g_rpc_service->ShutDown();
server_thread.join(); server_thread.join();
LOG(INFO) << "begin reset"; LOG(INFO) << "begin reset";
...@@ -174,20 +176,24 @@ TEST(PREFETCH, CPU) { ...@@ -174,20 +176,24 @@ TEST(PREFETCH, CPU) {
TEST(COMPLETE, CPU) { TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1); setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1); setenv("https_proxy", "", 1);
g_req_handler.reset( g_req_handler.reset(new distributed::RequestNotifyHandler(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync)); distributed::DistributedMode::kSync, -1));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client = distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0); distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE(client != nullptr); PADDLE_ENFORCE(client != nullptr);
std::thread server_thread(StartServer, distributed::kRequestSend); std::thread server_thread(StartServer, distributed::kRequestNotify);
g_rpc_service->WaitServerReady(); g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort(); int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port); std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
client->AsyncSendComplete(ep); client->AsyncSendComplete(ep);
client->Wait(); client->Wait();
EXPECT_EQ(g_rpc_service->GetClientNum(), 1); auto* barrier = distributed::BarrierMonitor::GetInstance();
EXPECT_EQ(barrier->GetWorkerNum(), 1);
barrier->Stop();
g_rpc_service->ShutDown(); g_rpc_service->ShutDown();
server_thread.join(); server_thread.join();
......
...@@ -2,9 +2,9 @@ include(operators) ...@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
if(WITH_GRPC) if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node)
else() else()
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node) set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA) if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs) find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL) ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -21,6 +18,7 @@ limitations under the License. */ ...@@ -21,6 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h" #include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
...@@ -30,16 +28,16 @@ namespace operators { ...@@ -30,16 +28,16 @@ namespace operators {
class GenNCCLIdOp : public framework::OperatorBase { class GenNCCLIdOp : public framework::OperatorBase {
public: public:
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, GenNCCLIdOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap& outputs, const framework::VariableNameMap &outputs,
const framework::AttributeMap& attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope &scope,
const platform::Place& dev_place) const override { const platform::Place &dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace // put nccl id in CPUPlace
auto& dev_ctx = *pool.Get(platform::CPUPlace()); auto &dev_ctx = *pool.Get(platform::CPUPlace());
int trainer_id = Attr<int>("trainer_id"); int trainer_id = Attr<int>("trainer_id");
std::vector<std::string> trainers = std::vector<std::string> trainers =
...@@ -55,7 +53,7 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -55,7 +53,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::string endpoint = trainers[trainer_id]; std::string endpoint = trainers[trainer_id];
framework::Scope& local_scope = scope.NewScope(); framework::Scope &local_scope = scope.NewScope();
int nccl_comm_num = Attr<int>("nccl_comm_num"); int nccl_comm_num = Attr<int>("nccl_comm_num");
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce"); int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
...@@ -171,10 +169,10 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -171,10 +169,10 @@ class GenNCCLIdOp : public framework::OperatorBase {
} }
private: private:
void GenerateAndSend(framework::Scope* scope, void GenerateAndSend(framework::Scope *scope,
const platform::DeviceContext& dev_ctx, const platform::DeviceContext &dev_ctx,
const std::string& nccl_id_name, const std::string &nccl_id_name,
const std::vector<std::string>& endpoint_list) const { const std::vector<std::string> &endpoint_list) const {
auto var = scope->FindVar(nccl_id_name); auto var = scope->FindVar(nccl_id_name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found", var, platform::errors::NotFound("Variable with name %s is not found",
...@@ -182,76 +180,96 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -182,76 +180,96 @@ class GenNCCLIdOp : public framework::OperatorBase {
auto id = var->GetMutable<ncclUniqueId>(); auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id)); PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id));
distributed::RPCClient* client = distributed::RPCClient *client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0); distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto& ep : endpoint_list) { for (auto &ep : endpoint_list) {
VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep; VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name); client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name);
} }
client->Wait(); client->Wait();
for (auto& ep : endpoint_list) { for (auto &ep : endpoint_list) {
client->AsyncSendBatchBarrier(ep); client->AsyncSendBatchBarrier(ep);
} }
client->Wait(); client->Wait();
VLOG(3) << "sending completed..."; VLOG(3) << "sending completed...";
} }
void GetIdByServer(const std::string& endpoint, framework::Scope* scope, void GetIdByServer(const std::string &endpoint, framework::Scope *scope,
const platform::DeviceContext& dev_ctx, int nccl_comm_num, const platform::DeviceContext &dev_ctx, int nccl_comm_num,
bool use_hierarchical_allreduce, int trainer_id, bool use_hierarchical_allreduce, int trainer_id,
int inter_trainer_id, int exter_trainer_id) const { int inter_trainer_id, int exter_trainer_id) const {
// std::string endpoint = Attr<std::string>("endpoint"); // std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default // NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and // deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash. // that will cause a wired crash.
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service( std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1)); new RPCSERVER_T(endpoint, 1));
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
distributed::RequestNotifyHandler notify_h(
distributed::DistributedMode::kSync, -1);
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h); rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_h.SetRPCServer(rpc_service.get()); rpc_service->RegisterRPC(distributed::kRequestNotify, &notify_h);
framework::ProgramDesc empty_program; framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace()); framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetRPCServer(rpc_service.get());
rpc_h.SetScope(scope); rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx); rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program); rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor); rpc_h.SetExecutor(&executor);
notify_h.SetRPCServer(rpc_service.get());
notify_h.SetScope(scope);
notify_h.SetDevCtx(&dev_ctx);
notify_h.SetProgram(&empty_program);
notify_h.SetExecutor(&executor);
distributed::BarrierMonitor::Init(1);
auto *barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread( std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
for (int i = 0; i < nccl_comm_num; i++) { for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend); barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
VLOG(3) << "trainer_id:" << trainer_id VLOG(3) << "trainer_id:" << trainer_id
<< " start getting nccl id from trainer 0, nccl_comm_no:" << i; << " start getting nccl id from trainer 0, nccl_comm_no:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
} }
if (use_hierarchical_allreduce) { if (use_hierarchical_allreduce) {
if (inter_trainer_id > 0) { if (inter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) { for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend); barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
VLOG(3) << "trainer_id:" << trainer_id VLOG(3) << "trainer_id:" << trainer_id
<< ", inter_trainer_id:" << inter_trainer_id << ", inter_trainer_id:" << inter_trainer_id
<< " start getting nccl id from inter_trainer:" << i; << " start getting nccl id from inter_trainer:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
} }
} }
if (exter_trainer_id > 0) { if (exter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) { for (int i = 0; i < nccl_comm_num; i++) {
rpc_service->SetCond(distributed::kRequestSend); barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
VLOG(3) VLOG(3)
<< "trainer_id:" << trainer_id << "trainer_id:" << trainer_id
<< ", exter_trainer_id:" << exter_trainer_id << ", exter_trainer_id:" << exter_trainer_id
<< " start getting nccl id from exter_trainer 0, nccl_comm_no:" << " start getting nccl id from exter_trainer 0, nccl_comm_no:"
<< i; << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
} }
} }
} }
...@@ -260,6 +278,7 @@ class GenNCCLIdOp : public framework::OperatorBase { ...@@ -260,6 +278,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
<< ", inter_trainer_id:" << inter_trainer_id << ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id << ", exter_trainer_id:" << exter_trainer_id
<< " got nccl id and stop server..."; << " got nccl id and stop server...";
barrier->Stop();
rpc_service->ShutDown(); rpc_service->ShutDown();
VLOG(3) << "rpc server stopped"; VLOG(3) << "rpc server stopped";
server_thread.join(); server_thread.join();
...@@ -272,7 +291,6 @@ class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -272,7 +291,6 @@ class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces."); AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC( AddComment(R"DOC(
GenNCCLId operator GenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers. For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC"); )DOC");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -25,6 +22,7 @@ limitations under the License. */ ...@@ -25,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h" #include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h" #include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
...@@ -38,10 +36,13 @@ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch"); ...@@ -38,10 +36,13 @@ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
volatile sig_atomic_t gSignalStatus;
void RunServer(std::shared_ptr<distributed::RPCServer> service) { void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer(); service->StartServer();
VLOG(4) << "RunServer thread end"; VLOG(4) << "RunServer thread end";
} }
static void split(const std::string &str, char sep, static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) { std::vector<std::string> *pieces) {
pieces->clear(); pieces->clear();
...@@ -126,6 +127,7 @@ void ListenAndServOp::RunSyncLoop( ...@@ -126,6 +127,7 @@ void ListenAndServOp::RunSyncLoop(
for (size_t i = 1; i < program->Size(); ++i) { for (size_t i = 1; i < program->Size(); ++i) {
optimize_blocks_list.push_back(i); optimize_blocks_list.push_back(i);
} }
auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list); auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list);
// Insert placeholder for block0 which holds current op itself, // Insert placeholder for block0 which holds current op itself,
// NOTE the first block in `optimize_prepared` should never be ran. // NOTE the first block in `optimize_prepared` should never be ran.
...@@ -135,21 +137,15 @@ void ListenAndServOp::RunSyncLoop( ...@@ -135,21 +137,15 @@ void ListenAndServOp::RunSyncLoop(
// Trainers will get all parameters from pserver in the // Trainers will get all parameters from pserver in the
// startup program, so we will wait RequestGet first // startup program, so we will wait RequestGet first
rpc_service_->SetCond(distributed::kRequestGet); auto *barrier = distributed::BarrierMonitor::GetInstance();
rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter();
while (true) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG(3) << "wait all clients to send gradient"; barrier->WaitServerWeakup();
rpc_service_->SetCond(distributed::kRequestSend);
VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend);
if (rpc_service_->IsExit()) { if (gSignalStatus != 0) {
LOG(WARNING) << "get exit!rpc_processor break!"; LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(distributed::kRequestGet);
break; break;
} }
...@@ -180,12 +176,8 @@ void ListenAndServOp::RunSyncLoop( ...@@ -180,12 +176,8 @@ void ListenAndServOp::RunSyncLoop(
VLOG(3) << "ResetReceivedVars"; VLOG(3) << "ResetReceivedVars";
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars()); ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
VLOG(3) << "wait all clients to get parameters back"; barrier->ServerWeakup();
rpc_service_->SetCond(distributed::kRequestGet); VLOG(3) << "kRecvBarrier to push params to trainers";
VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet);
VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter();
} // while(true) } // while(true)
} }
...@@ -281,7 +273,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, ...@@ -281,7 +273,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx); request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
while (true) { while (true) {
if (rpc_service_->IsExit()) { if (gSignalStatus != 0) {
VLOG(4) << "get exit!rpc_processor break!"; VLOG(4) << "get exit!rpc_processor break!";
break; break;
} }
...@@ -391,7 +383,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -391,7 +383,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get()); request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify, rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), rpc_send_thread_num); request_notify_handler_.get(), fan_in * 2);
auto optimize_blocks = auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
...@@ -440,6 +432,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -440,6 +432,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>> std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx; prefetch_var_name_to_prepared_ctx;
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) { for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i]; auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id]; auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
...@@ -448,8 +441,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -448,8 +441,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// parse attr of kSparseGradToParam sparse_grad_name -> param_name // parse attr of kSparseGradToParam sparse_grad_name -> param_name
std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name; std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
auto sparse_grad_name_to_param_name_str = auto sparse_grad_name_to_param_name_str =
Attr<std::vector<std::string>>(kSparseGradToParam); Attr<std::vector<std::string>>(kSparseGradToParam);
for (const auto &sparse_grad_name_and_param_name : for (const auto &sparse_grad_name_and_param_name :
sparse_grad_name_to_param_name_str) { sparse_grad_name_to_param_name_str) {
std::vector<std::string> pieces; std::vector<std::string> pieces;
...@@ -477,17 +472,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -477,17 +472,18 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit); signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit); signal(SIGTERM, SignalHandler::StopAndExit);
distributed::BarrierMonitor::Init(fan_in);
if (distributed_mode == distributed::DistributedMode::kSync) { if (distributed_mode == distributed::DistributedMode::kSync) {
// start the server listening after all member initialized. // start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_)); server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready..."; VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady(); rpc_service_->WaitServerReady();
CacheVarsType(inputs, recv_scope);
// Write to a file of server selected port for python use. // Write to a file of server selected port for python use.
SavePort(); SavePort();
CacheVarsType(inputs, recv_scope);
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx, RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id); prefetch_block_id_list, checkpoint_block_id);
} else { } else {
...@@ -574,9 +570,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -574,9 +570,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
void SignalHandler::StopAndExit(int signal_num) { void SignalHandler::StopAndExit(int signal_num) {
// Do not use VLOG here for the device for printing maybe already released. // Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces. // exit will release interal allocated resoureces.
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid()); distributed::BarrierMonitor::GetInstance()->Stop();
remove(file_path.c_str()); gSignalStatus = signal_num;
exit(0);
} }
} // namespace operators } // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -20,6 +17,7 @@ limitations under the License. */ ...@@ -20,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h" #include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
...@@ -42,6 +40,7 @@ namespace string = paddle::string; ...@@ -42,6 +40,7 @@ namespace string = paddle::string;
std::unique_ptr<distributed::RPCServer> g_rpc_service; std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler; std::unique_ptr<distributed::RequestHandler> g_req_handler;
std::unique_ptr<distributed::RequestNotifyHandler> g_notify_handler;
void StartServer() { void StartServer() {
f::Scope scope; f::Scope scope;
...@@ -52,21 +51,35 @@ void StartServer() { ...@@ -52,21 +51,35 @@ void StartServer() {
f::ProgramDesc empty_program; f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace()); f::Executor executor(dev_ctx.GetPlace());
g_req_handler->SetScope(&scope); g_req_handler->SetScope(&scope);
g_req_handler->SetDevCtx(&dev_ctx); g_req_handler->SetDevCtx(&dev_ctx);
g_req_handler->SetProgram(&empty_program); g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor); g_req_handler->SetExecutor(&executor);
g_req_handler->SetRPCServer(g_rpc_service.get());
g_notify_handler.SetRPCServer(rpc_service.get());
g_notify_handler.SetScope(scope);
g_notify_handler.SetDevCtx(&dev_ctx);
g_notify_handler.SetProgram(&empty_program);
g_notify_handler.SetExecutor(&executor);
g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get()); g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get()); g_rpc_service->RegisterRPC(distributed::RequestNotifyHandler,
g_notify_handler.get());
distributed::BarrierMonitor::Init(1);
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread( std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get())); std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
g_rpc_service->SetCond(distributed::kRequestSend); barrier->WaitServerWeakup();
g_rpc_service->WaitBarrier(distributed::kRequestSend); barrier->ServerWeakup();
LOG(INFO) << "got nccl id and stop server..."; LOG(INFO) << "got nccl id and stop server...";
barrier->Stop();
g_rpc_service->ShutDown(); g_rpc_service->ShutDown();
server_thread.join(); server_thread.join();
} }
...@@ -74,6 +87,10 @@ void StartServer() { ...@@ -74,6 +87,10 @@ void StartServer() {
TEST(SendNcclId, RPCServer) { TEST(SendNcclId, RPCServer) {
g_req_handler.reset( g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync)); new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_notify_handler.reset(new distributed::RequestNotifyHandler(
distributed::DistributedMode::kSync, -1));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1)); g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer); std::thread server_thread(StartServer);
...@@ -104,4 +121,5 @@ TEST(SendNcclId, RPCServer) { ...@@ -104,4 +121,5 @@ TEST(SendNcclId, RPCServer) {
server_thread.join(); server_thread.join();
g_rpc_service.reset(nullptr); g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr); g_req_handler.reset(nullptr);
g_notify_handler.reset(nullptr);
} }
...@@ -94,7 +94,7 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase): ...@@ -94,7 +94,7 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
current_id=0, current_id=0,
role=role_maker.Role.WORKER role=role_maker.Role.WORKER
if training_role == "TRAINER" else role_maker.Role.SERVER, if training_role == "TRAINER" else role_maker.Role.SERVER,
worker_num=2, worker_num=1,
server_endpoints=["127.0.0.1:6002"]) server_endpoints=["127.0.0.1:6002"])
if training_role == "TRAINER": if training_role == "TRAINER":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册