未验证 提交 4b3778a3 编写于 作者: T tangwei12 提交者: GitHub

Revert/barrier for sync (#25417)

* add retry for prefetch

* Revert "Fix/sync barrier (#25016)"

This reverts commit be6a315f.

* reopen dist UT, test=develop

* remove fl UT, test=develop
上级 dd9d146e
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -21,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
......@@ -93,39 +95,22 @@ class CGenNCCLIdOp : public framework::OperatorBase {
new RPCSERVER_T(endpoint, 1));
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
distributed::RequestNotifyHandler notify_h(
distributed::DistributedMode::kSync, -1);
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_service->RegisterRPC(distributed::kRequestNotify, &notify_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetRPCServer(rpc_service.get());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
notify_h.SetRPCServer(rpc_service.get());
notify_h.SetScope(scope);
notify_h.SetDevCtx(&dev_ctx);
notify_h.SetProgram(&empty_program);
notify_h.SetExecutor(&executor);
distributed::BarrierMonitor::Init(1);
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "start getting nccl id from trainer 0...";
barrier->WaitServerWeakup();
barrier->ServerWeakup();
rpc_service->WaitBarrier(distributed::kRequestSend);
VLOG(3) << "got nccl id and stop server...";
barrier->Stop();
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
......@@ -138,6 +123,7 @@ class CGenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
CGenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
......
......@@ -15,8 +15,6 @@ cc_test(async_sparse_param_update_recorder_test SRCS async_sparse_param_update_r
cc_library(heart_beat_monitor SRCS heart_beat_monitor.cc DEPS enforce simple_threadpool)
cc_test(heart_beat_monitor_test SRCS heart_beat_monitor_test.cc DEPS heart_beat_monitor)
cc_library(barrier_monitor SRCS barrier_monitor.cc DEPS enforce simple_threadpool trainer_desc_proto device_context)
# FIXME(typhoonzero): use add_subdirectory once we clean the dependency of these files
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
if(WITH_GRPC)
......@@ -28,7 +26,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${GRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor barrier_monitor)
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS} async_sparse_param_update_recorder heart_beat_monitor)
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include <gflags/gflags.h>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
bool BarrierMonitor::IncreaseBarrier(const int worker_id,
const std::string &barrier) {
release_ = false;
if (barrier == BATCH_BARRIER_MESSAGE) {
VLOG(4) << "BarrierMonitor send queue recv trainer: " << worker_id;
send_barrier_queue->Push(worker_id);
} else if (barrier == FETCH_BARRIER_MESSAGE) {
VLOG(4) << "BarrierMonitor recv queue recv trainer: " << worker_id;
recv_barrier_queue->Push(worker_id);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"unknown Message status %s, only "
"BATCH_BARRIER_MESSAGE/FETCH_BARRIER_MESSAGE",
barrier));
}
return Wait();
}
void BarrierMonitor::DecreaseWorker() {
std::unique_lock<std::mutex> lck(mutex_);
workers_--;
VLOG(1) << "decrement worker num to " << workers_;
}
void BarrierMonitor::Reset(int workers, BarrierType type) {
std::unique_lock<std::mutex> lk(server_mutex_);
workers_ = workers;
barrier_type = type;
send_barrier_queue->Clear();
recv_barrier_queue->Clear();
VLOG(2) << "reset monitor workers: " << workers_ << " type: " << barrier_type;
}
void BarrierMonitor::Monitor() {
while (!IsReady() && running_) {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
VLOG(3) << "sync at first time, wait all trainer ready";
}
while (running_) {
int timer = 0;
if (IsReady()) {
Swap(true);
} else {
VLOG(4) << "running timer: " << timer << " barrier: " << barrier_type
<< " sendQ:" << send_barrier_queue->Size()
<< " recvQ: " << recv_barrier_queue->Size();
timer++;
if (max_wait_ms == -1 || timer < max_wait_ms) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
} else {
VLOG(1) << "time out of " << max_wait_ms
<< ", need barreir: " << barrier_type << " retry";
Swap(false);
}
}
}
}
bool BarrierMonitor::IsReady() {
if (barrier_type == BarrierType::kSendBarrier) {
return static_cast<int>(send_barrier_queue->Size()) == workers_;
} else {
return static_cast<int>(recv_barrier_queue->Size()) == workers_;
}
}
void BarrierMonitor::Swap(bool is_valid) {
std::unique_lock<std::mutex> lck(mutex_);
valid_ = is_valid;
release_ = true;
if (barrier_type == BarrierType::kSendBarrier) {
barrier_type = BarrierType::kRecvBarrier;
send_barrier_queue->Clear();
VLOG(4) << "barrier monitor server clean up queue and barrier";
ServerWeakup();
VLOG(4) << "barrier monitor server weak up sync to do";
WaitServerWeakup();
VLOG(4) << "barrier monitor server weak up sync done";
} else {
barrier_type = BarrierType::kSendBarrier;
recv_barrier_queue->Clear();
VLOG(4) << "barrier monitor server switch to send barrier";
}
worker_cv_.notify_all();
}
void BarrierMonitor::Stop() {
valid_ = true;
release_ = true;
running_ = false;
barrier_type = BarrierType::kRecvBarrier;
send_barrier_queue->Clear();
recv_barrier_queue->Clear();
worker_cv_.notify_all();
server_cv_.notify_all();
if (monitor_thread_) monitor_thread_->join();
monitor_thread_ = nullptr;
}
bool BarrierMonitor::Wait() {
std::unique_lock<std::mutex> lk(mutex_);
worker_cv_.wait(lk, [this] { return (release_); });
return valid_;
}
void BarrierMonitor::WaitServerWeakup() {
std::unique_lock<std::mutex> lk(server_mutex_);
server_cv_.wait(lk);
}
void BarrierMonitor::ServerWeakup() { server_cv_.notify_all(); }
std::once_flag BarrierMonitor::init_flag_;
std::unique_ptr<BarrierMonitor> BarrierMonitor::monitor_(nullptr);
} // namespace distributed
} // namespace operators
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gflags/gflags.h>
#include <chrono> // NOLINT
#include <deque>
#include <functional>
#include <future> // NOLINT
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <thread> // NOLINT
#include <ThreadPool.h>
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace distributed {
enum BarrierType { kSendBarrier, kRecvBarrier };
constexpr int64_t kMaxWaitMS = 120000;
template <typename T>
class BlockingQueueForBarrier {
public:
explicit BlockingQueueForBarrier(size_t capacity) : capacity_(capacity) {
PADDLE_ENFORCE_GT(capacity_, 0,
platform::errors::InvalidArgument(
"The capacity must be greater than 0."));
}
bool Push(const T &elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.push_back(elem);
}
worker_cv_.notify_one();
return true;
}
bool Push(T &&elem) {
{
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [&] { return queue_.size() < capacity_; });
queue_.emplace_back(std::move(elem));
}
worker_cv_.notify_one();
return true;
}
T Pop() {
std::unique_lock<std::mutex> lock(mutex_);
worker_cv_.wait(lock, [=] { return !queue_.empty(); });
T rc(std::move(queue_.front()));
queue_.pop_front();
worker_cv_.notify_one();
return rc;
}
size_t Cap() const {
std::lock_guard<std::mutex> lock(mutex_);
return capacity_;
}
size_t Size() const {
std::lock_guard<std::mutex> lock(mutex_);
return queue_.size();
}
void Clear() {
std::lock_guard<std::mutex> lock(mutex_);
std::deque<T>().swap(queue_);
}
private:
const size_t capacity_;
std::deque<T> queue_;
mutable std::mutex mutex_;
std::condition_variable worker_cv_;
};
class BarrierMonitor {
public:
explicit BarrierMonitor(int workers)
: BarrierMonitor(workers, BarrierType::kRecvBarrier, kMaxWaitMS) {}
explicit BarrierMonitor(int workers, BarrierType type, int64_t max_wait_times)
: workers_(workers), barrier_type(type), max_wait_ms(max_wait_times) {
PADDLE_ENFORCE_GT(workers, 0, platform::errors::InvalidArgument(
"trainers must have one or more"));
send_barrier_queue =
std::make_shared<BlockingQueueForBarrier<int>>(workers);
recv_barrier_queue =
std::make_shared<BlockingQueueForBarrier<int>>(workers);
running_ = true;
monitor_thread_.reset(
new std::thread(std::bind(&BarrierMonitor::Monitor, this)));
}
static BarrierMonitor *Init(int workers) {
InitImpl(workers);
return GetInstance();
}
static BarrierMonitor *GetInstance() { return monitor_.get(); }
bool IncreaseBarrier(const int worker_id, const std::string &barrier);
void DecreaseWorker();
int GetWorkerNum() { return workers_; }
void Monitor();
void Swap(bool is_valid);
void Stop();
bool IsReady();
bool Wait();
void WaitServerWeakup();
void ServerWeakup();
void WorkerWeakup();
void Reset(int workers, BarrierType type);
private:
// Init is called by GetInstance.
static void InitImpl(int workers) {
monitor_.reset(new BarrierMonitor(workers));
}
static std::once_flag init_flag_;
static std::unique_ptr<BarrierMonitor> monitor_;
int workers_;
bool running_ = false;
bool valid_ = false;
bool release_ = false;
std::condition_variable worker_cv_;
std::condition_variable server_cv_;
std::mutex server_mutex_;
std::mutex mutex_;
BarrierType barrier_type;
int64_t max_wait_ms;
std::unique_ptr<std::thread> monitor_thread_{nullptr};
std::shared_ptr<BlockingQueueForBarrier<int>> send_barrier_queue;
std::shared_ptr<BlockingQueueForBarrier<int>> recv_barrier_queue;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
......@@ -260,7 +260,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
while (true) {
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
s->Prepare(h, kPrefetchTimeout);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
p_ctx, s, method, h, table_name_val, this] {
......@@ -306,19 +306,52 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
auto h = AsyncDistributeNotify(ep, ctx, *scope, BATCH_BARRIER_MESSAGE);
delete scope;
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = kBatchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
auto h = AsyncDistributeNotify(ep, ctx, *scope, FETCH_BARRIER_MESSAGE);
delete scope;
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
const std::string method = kFetchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
......@@ -351,10 +384,27 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
platform::CPUDeviceContext ctx;
auto* scope = new framework::Scope();
auto h = AsyncDistributeNotify(ep, ctx, *scope, COMPLETE_MESSAGE);
delete scope;
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = kSendCompleteRPC;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_trainer_id(trainer_id_);
req.set_varname(COMPLETE_MESSAGE);
platform::RecordRPCEvent record_event(method);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
......@@ -404,21 +454,10 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
::grpc::ByteBuffer buf;
if (var_name_val == BATCH_BARRIER_MESSAGE ||
var_name_val == FETCH_BARRIER_MESSAGE ||
var_name_val == COMPLETE_MESSAGE) {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(var_name_val);
req.set_trainer_id(trainer_id_);
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
} else {
auto* var = p_scope->FindVar(var_name_val);
SerializeToByteBuffer(var_name_val, var, *p_ctx, &buf, "", trainer_id_);
}
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
......@@ -428,7 +467,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", buf,
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
......@@ -448,6 +487,18 @@ bool GRPCClient::Wait() {
return ok_;
}
inline bool ShouldRetry(const std::string& method, int error_code) {
if (method == kPrefetchRPC) {
return true;
}
if (error_code == grpc::StatusCode::DEADLINE_EXCEEDED) {
return true;
}
return false;
}
void GRPCClient::Proceed() {
void* tag = nullptr;
bool ok = false;
......@@ -461,19 +512,9 @@ void GRPCClient::Proceed() {
if (c->status_.ok()) {
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process();
} else if (c->status_.error_code() == grpc::StatusCode::DEADLINE_EXCEEDED) {
PADDLE_THROW(platform::errors::External(
"%s meets grpc error, error_code is %d, error message is %s, error "
"details is %s.",
c->GetVarHandlePtr()->String(), c->status_.error_code(),
c->status_.error_message(), c->status_.error_details()));
{
std::lock_guard<std::mutex> lk(sync_mutex_);
ok_ = false;
}
c->Finish(false);
} else if (c->status_.error_code() == grpc::StatusCode::UNAVAILABLE) {
VLOG(3) << c->GetVarHandlePtr()->String()
} else if (ShouldRetry(c->GetVarHandlePtr()->method(),
c->status_.error_code())) {
VLOG(0) << c->GetVarHandlePtr()->String()
<< " meets grpc error, error_code:" << c->status_.error_code()
<< " error_message:" << c->status_.error_message()
<< " error_details:" << c->status_.error_details()
......
......@@ -57,6 +57,7 @@ constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
constexpr int64_t kPrefetchTimeout = 60000;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
......
......@@ -28,7 +28,6 @@
#include "paddle/fluid/string/split.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
namespace paddle {
......@@ -39,65 +38,93 @@ namespace distributed {
// to directory specified.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestSendHandler:" << varname;
if (invar == nullptr) {
PADDLE_THROW(platform::errors::NotFound(
"sync: Can not find server side var: %s", varname));
return false;
}
// Sync
if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
rpc_server_->IncreaseBatchBarrier(kRequestSend);
} else if (varname == COMPLETE_MESSAGE) {
VLOG(3) << "sync: recv complete message";
if (distributed_mode_ == DistributedMode::kSync) {
return true;
if (HeartBeatMonitor::GetInstance() != nullptr) {
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
}
rpc_server_->Complete();
} else {
// Async
if (distributed_mode_ != DistributedMode::kSync) {
VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);
std::string run_varname = varname;
string::Piece part_piece("@PIECE");
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, part_piece)) {
auto varname_splits = paddle::string::Split(varname, '@');
PADDLE_ENFORCE_EQ(varname_splits.size(), 3);
run_varname = varname_splits[0];
scope->Rename(varname, run_varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
auto &grad_slr =
auto& grad_slr =
scope->FindVar(run_varname)->Get<framework::SelectedRows>();
AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
grad_slr.rows());
}
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
scope);
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
VLOG(3) << "sync: processing received var: " << varname;
PADDLE_ENFORCE_NOT_NULL(
invar, platform::errors::NotFound(
"sync: Can not find server side var %s.", varname));
}
}
return true;
}
bool RequestGetHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
bool RequestGetHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(3) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
<< " table_name: " << table_name;
if (distributed_mode_ == DistributedMode::kSync) {
if (varname == FETCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv fetch barrier message";
rpc_server_->IncreaseBatchBarrier(kRequestGet);
} else {
rpc_server_->WaitCond(kRequestGet);
*outvar = scope_->FindVar(varname);
}
} else {
if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
if (enable_dc_asgd_) {
// NOTE: the format is determined by distribute_transpiler.py
std::string param_bak_name =
......@@ -111,7 +138,12 @@ bool RequestGetHandler::Handle(const std::string &varname,
VLOG(3) << "copying " << varname << " to " << param_bak_name;
framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
}
VLOG(1) << "Table name empty? " << table_name.empty();
if (distributed_mode_ == DistributedMode::kGeo) {
VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
<< AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
varname);
}
if (distributed_mode_ == DistributedMode::kGeo &&
AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
!table_name.empty()) {
......@@ -121,31 +153,28 @@ bool RequestGetHandler::Handle(const std::string &varname,
if (VLOG_IS_ON(3)) {
std::ostringstream sstream;
sstream << "[";
for (auto &row_id : updated_rows) {
for (auto& row_id : updated_rows) {
sstream << row_id << ", ";
}
sstream << "]";
VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
<< sstream.str();
}
auto &origin_tensor =
auto& origin_tensor =
scope_->FindVar(varname)->Get<framework::LoDTensor>();
auto *origin_tensor_data = origin_tensor.data<float>();
auto &dims = origin_tensor.dims();
auto* origin_tensor_data = origin_tensor.data<float>();
auto& dims = origin_tensor.dims();
*outvar = scope->Var();
auto *out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
out_slr->set_rows(updated_rows);
out_slr->set_height(dims[0]);
auto out_dims = framework::make_ddim(
{static_cast<int64_t>(updated_rows.size()), dims[1]});
auto *data = out_slr->mutable_value()->mutable_data<float>(
auto* data = out_slr->mutable_value()->mutable_data<float>(
out_dims, origin_tensor.place());
auto width = dims[1];
for (size_t i = 0; i < updated_rows.size(); ++i) {
PADDLE_ENFORCE_LT(updated_rows[i], dims[0],
platform::errors::OutOfRange(
"expected >= 0 and < %ld, but got %ld.", dims[0],
updated_rows[i]));
PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
sizeof(float) * width);
}
......@@ -153,16 +182,17 @@ bool RequestGetHandler::Handle(const std::string &varname,
*outvar = scope_->FindVar(varname);
}
}
}
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
......@@ -177,19 +207,18 @@ bool RequestGetNoBarrierHandler::Handle(const std::string &varname,
*outvar = scope_->FindVar(var_name_piece.ToString());
return true;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE));
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestPrefetchHandler " << varname;
if (table_name.empty()) {
......@@ -207,20 +236,19 @@ bool RequestPrefetchHandler::Handle(const std::string &varname,
return true;
}
bool RequestCheckpointHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
PADDLE_ENFORCE_NE(
checkpoint_notify_id, -1,
platform::errors::Unavailable(
"when checkpoint_notify_id = -1, there should be no RPC invoke."));
const std::string& out_var_name,
const std::string& table_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
// TODO(tangwei12): find out why scope will be error.
auto *lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
......@@ -229,56 +257,33 @@ bool RequestCheckpointHandler::Handle(const std::string &varname,
return true;
}
bool RequestNotifyHandler::Handle(const std::string &varname,
framework::Scope *scope,
framework::Variable *invar,
framework::Variable **outvar,
bool RequestNotifyHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler: " << varname;
VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
string::Piece batch_piece(BATCH_BARRIER_MESSAGE);
string::Piece fetch_piece(FETCH_BARRIER_MESSAGE);
string::Piece complete_piece(COMPLETE_MESSAGE);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, batch_piece)) {
return BarrierMonitor::GetInstance()->IncreaseBarrier(
trainer_id, BATCH_BARRIER_MESSAGE);
} else if (string::Contains(var_name_piece, fetch_piece)) {
return BarrierMonitor::GetInstance()->IncreaseBarrier(
trainer_id, FETCH_BARRIER_MESSAGE);
} else if (string::Contains(var_name_piece, complete_piece)) {
if (HeartBeatMonitor::GetInstance() != nullptr) {
HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
}
rpc_server_->Complete();
BarrierMonitor::GetInstance()->DecreaseWorker();
return true;
} else if (string::Contains(var_name_piece, decay_piece)) {
if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1,
platform::errors::InvalidArgument(
"when lr_decay_block_id = -1, there should be no RPC invoke."));
auto *origin_var = scope_->FindVar(varname);
"when lr_decay_block_id = -1, there should be no RPC invoke.");
auto* origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto *send_var = scope->FindVar(varname);
auto* send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t *origin_value =
int64_t* origin_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t *send_value =
int64_t* send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0];
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
return true;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"unkown varname %s with RequestNotifyHandler", varname));
}
return true;
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -21,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
......@@ -117,7 +119,6 @@ void StartServer(const std::string& rpc_name) {
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
distributed::HeartBeatMonitor::Init(2, true, "w@grad");
distributed::BarrierMonitor::Init(2);
g_req_handler->SetRPCServer(g_rpc_service.get());
......@@ -163,9 +164,6 @@ TEST(PREFETCH, CPU) {
}
}
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Stop();
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
......@@ -176,24 +174,20 @@ TEST(PREFETCH, CPU) {
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestNotifyHandler(
distributed::DistributedMode::kSync, -1));
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE(client != nullptr);
std::thread server_thread(StartServer, distributed::kRequestNotify);
std::thread server_thread(StartServer, distributed::kRequestSend);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
client->AsyncSendComplete(ep);
client->Wait();
auto* barrier = distributed::BarrierMonitor::GetInstance();
EXPECT_EQ(barrier->GetWorkerNum(), 1);
barrier->Stop();
EXPECT_EQ(g_rpc_service->GetClientNum(), 1);
g_rpc_service->ShutDown();
server_thread.join();
......
......@@ -2,9 +2,9 @@ include(operators)
set(DISTRIBUTE_DEPS "")
if(WITH_GRPC)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder grpc++_unsecure grpc_unsecure gpr zlib protobuf node)
else()
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv barrier_monitor communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node)
set(DISTRIBUTE_DEPS sendrecvop_rpc parameter_send parameter_recv communicator async_sparse_param_update_recorder brpc leveldb protobuf ssl crypto zlib node)
if(WITH_BRPC_RDMA)
find_library(IBVERBS_LIBRARY NAMES ibverbs)
ADD_LIBRARY(ibverbs SHARED IMPORTED GLOBAL)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -18,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/platform/nccl_helper.h"
......@@ -28,16 +30,16 @@ namespace operators {
class GenNCCLIdOp : public framework::OperatorBase {
public:
GenNCCLIdOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
// put nccl id in CPUPlace
auto &dev_ctx = *pool.Get(platform::CPUPlace());
auto& dev_ctx = *pool.Get(platform::CPUPlace());
int trainer_id = Attr<int>("trainer_id");
std::vector<std::string> trainers =
......@@ -53,7 +55,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
std::string endpoint = trainers[trainer_id];
framework::Scope &local_scope = scope.NewScope();
framework::Scope& local_scope = scope.NewScope();
int nccl_comm_num = Attr<int>("nccl_comm_num");
int use_hierarchical_allreduce = Attr<bool>("use_hierarchical_allreduce");
......@@ -169,10 +171,10 @@ class GenNCCLIdOp : public framework::OperatorBase {
}
private:
void GenerateAndSend(framework::Scope *scope,
const platform::DeviceContext &dev_ctx,
const std::string &nccl_id_name,
const std::vector<std::string> &endpoint_list) const {
void GenerateAndSend(framework::Scope* scope,
const platform::DeviceContext& dev_ctx,
const std::string& nccl_id_name,
const std::vector<std::string>& endpoint_list) const {
auto var = scope->FindVar(nccl_id_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Variable with name %s is not found",
......@@ -180,96 +182,76 @@ class GenNCCLIdOp : public framework::OperatorBase {
auto id = var->GetMutable<ncclUniqueId>();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id));
distributed::RPCClient *client =
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
for (auto &ep : endpoint_list) {
for (auto& ep : endpoint_list) {
VLOG(3) << "sending nccl_id_var:" << nccl_id_name << " to " << ep;
client->AsyncSendVar(ep, dev_ctx, *scope, nccl_id_name);
}
client->Wait();
for (auto &ep : endpoint_list) {
for (auto& ep : endpoint_list) {
client->AsyncSendBatchBarrier(ep);
}
client->Wait();
VLOG(3) << "sending completed...";
}
void GetIdByServer(const std::string &endpoint, framework::Scope *scope,
const platform::DeviceContext &dev_ctx, int nccl_comm_num,
void GetIdByServer(const std::string& endpoint, framework::Scope* scope,
const platform::DeviceContext& dev_ctx, int nccl_comm_num,
bool use_hierarchical_allreduce, int trainer_id,
int inter_trainer_id, int exter_trainer_id) const {
// std::string endpoint = Attr<std::string>("endpoint");
// NOTE: Can not use unique_ptr here because the default
// deleter will call GRPC Server's base class's dtor and
// that will cause a wired crash.
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
std::unique_ptr<distributed::RPCServer> rpc_service(
new RPCSERVER_T(endpoint, 1));
distributed::RequestSendHandler rpc_h(distributed::DistributedMode::kSync);
distributed::RequestNotifyHandler notify_h(
distributed::DistributedMode::kSync, -1);
rpc_service->RegisterRPC(distributed::kRequestSend, &rpc_h);
rpc_service->RegisterRPC(distributed::kRequestNotify, &notify_h);
rpc_h.SetRPCServer(rpc_service.get());
framework::ProgramDesc empty_program;
framework::Executor executor(dev_ctx.GetPlace());
rpc_h.SetRPCServer(rpc_service.get());
rpc_h.SetScope(scope);
rpc_h.SetDevCtx(&dev_ctx);
rpc_h.SetProgram(&empty_program);
rpc_h.SetExecutor(&executor);
notify_h.SetRPCServer(rpc_service.get());
notify_h.SetScope(scope);
notify_h.SetDevCtx(&dev_ctx);
notify_h.SetProgram(&empty_program);
notify_h.SetExecutor(&executor);
distributed::BarrierMonitor::Init(1);
auto *barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, rpc_service.get()));
for (int i = 0; i < nccl_comm_num; i++) {
barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "trainer_id:" << trainer_id
<< " start getting nccl id from trainer 0, nccl_comm_no:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
if (use_hierarchical_allreduce) {
if (inter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3) << "trainer_id:" << trainer_id
<< ", inter_trainer_id:" << inter_trainer_id
<< " start getting nccl id from inter_trainer:" << i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
}
if (exter_trainer_id > 0) {
for (int i = 0; i < nccl_comm_num; i++) {
barrier->WaitServerWeakup();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
barrier->ServerWeakup();
rpc_service->SetCond(distributed::kRequestSend);
VLOG(3)
<< "trainer_id:" << trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< " start getting nccl id from exter_trainer 0, nccl_comm_no:"
<< i;
rpc_service->WaitBarrier(distributed::kRequestSend);
rpc_service->ResetBarrierCounter();
}
}
}
......@@ -278,7 +260,6 @@ class GenNCCLIdOp : public framework::OperatorBase {
<< ", inter_trainer_id:" << inter_trainer_id
<< ", exter_trainer_id:" << exter_trainer_id
<< " got nccl id and stop server...";
barrier->Stop();
rpc_service->ShutDown();
VLOG(3) << "rpc server stopped";
server_thread.join();
......@@ -291,6 +272,7 @@ class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces.");
AddComment(R"DOC(
GenNCCLId operator
For trainer 0: generate a new UniqueId and send it to all the other trainers.
For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server.
)DOC");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -22,7 +25,6 @@ limitations under the License. */
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
......@@ -36,13 +38,10 @@ DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");
namespace paddle {
namespace operators {
volatile sig_atomic_t gSignalStatus;
void RunServer(std::shared_ptr<distributed::RPCServer> service) {
service->StartServer();
VLOG(4) << "RunServer thread end";
}
static void split(const std::string &str, char sep,
std::vector<std::string> *pieces) {
pieces->clear();
......@@ -127,7 +126,6 @@ void ListenAndServOp::RunSyncLoop(
for (size_t i = 1; i < program->Size(); ++i) {
optimize_blocks_list.push_back(i);
}
auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list);
// Insert placeholder for block0 which holds current op itself,
// NOTE the first block in `optimize_prepared` should never be ran.
......@@ -137,15 +135,21 @@ void ListenAndServOp::RunSyncLoop(
// Trainers will get all parameters from pserver in the
// startup program, so we will wait RequestGet first
auto *barrier = distributed::BarrierMonitor::GetInstance();
rpc_service_->SetCond(distributed::kRequestGet);
rpc_service_->WaitBarrier(distributed::kRequestGet);
rpc_service_->ResetBarrierCounter();
while (true) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
barrier->WaitServerWeakup();
VLOG(3) << "wait all clients to send gradient";
rpc_service_->SetCond(distributed::kRequestSend);
VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend);
if (gSignalStatus != 0) {
if (rpc_service_->IsExit()) {
LOG(WARNING) << "get exit!rpc_processor break!";
rpc_service_->SetCond(distributed::kRequestGet);
break;
}
......@@ -176,8 +180,12 @@ void ListenAndServOp::RunSyncLoop(
VLOG(3) << "ResetReceivedVars";
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
barrier->ServerWeakup();
VLOG(3) << "kRecvBarrier to push params to trainers";
VLOG(3) << "wait all clients to get parameters back";
rpc_service_->SetCond(distributed::kRequestGet);
VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet);
VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter();
} // while(true)
}
......@@ -273,7 +281,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
while (true) {
if (gSignalStatus != 0) {
if (rpc_service_->IsExit()) {
VLOG(4) << "get exit!rpc_processor break!";
break;
}
......@@ -383,7 +391,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), fan_in * 2);
request_notify_handler_.get(), rpc_send_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......@@ -432,7 +440,6 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
prefetch_var_name_to_prepared_ctx;
for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
auto block_id = prefetch_block_id_list[i];
auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
......@@ -441,10 +448,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// parse attr of kSparseGradToParam sparse_grad_name -> param_name
std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
auto sparse_grad_name_to_param_name_str =
Attr<std::vector<std::string>>(kSparseGradToParam);
for (const auto &sparse_grad_name_and_param_name :
sparse_grad_name_to_param_name_str) {
std::vector<std::string> pieces;
......@@ -472,18 +477,17 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
signal(SIGINT, SignalHandler::StopAndExit);
signal(SIGTERM, SignalHandler::StopAndExit);
distributed::BarrierMonitor::Init(fan_in);
if (distributed_mode == distributed::DistributedMode::kSync) {
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
VLOG(3) << "wait server thread to become ready...";
rpc_service_->WaitServerReady();
// Write to a file of server selected port for python use.
SavePort();
CacheVarsType(inputs, recv_scope);
// Write to a file of server selected port for python use.
SavePort();
RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
prefetch_block_id_list, checkpoint_block_id);
} else {
......@@ -570,8 +574,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
void SignalHandler::StopAndExit(int signal_num) {
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
distributed::BarrierMonitor::GetInstance()->Stop();
gSignalStatus = signal_num;
auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
remove(file_path.c_str());
exit(0);
}
} // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -17,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/distributed/barrier_monitor.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"
......@@ -40,7 +42,6 @@ namespace string = paddle::string;
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
std::unique_ptr<distributed::RequestNotifyHandler> g_notify_handler;
void StartServer() {
f::Scope scope;
......@@ -51,35 +52,21 @@ void StartServer() {
f::ProgramDesc empty_program;
f::Executor executor(dev_ctx.GetPlace());
g_req_handler->SetScope(&scope);
g_req_handler->SetDevCtx(&dev_ctx);
g_req_handler->SetProgram(&empty_program);
g_req_handler->SetExecutor(&executor);
g_req_handler->SetRPCServer(g_rpc_service.get());
g_notify_handler.SetRPCServer(rpc_service.get());
g_notify_handler.SetScope(scope);
g_notify_handler.SetDevCtx(&dev_ctx);
g_notify_handler.SetProgram(&empty_program);
g_notify_handler.SetExecutor(&executor);
g_rpc_service->RegisterRPC(distributed::kRequestSend, g_req_handler.get());
g_rpc_service->RegisterRPC(distributed::RequestNotifyHandler,
g_notify_handler.get());
distributed::BarrierMonitor::Init(1);
auto* barrier = distributed::BarrierMonitor::GetInstance();
barrier->Reset(1, distributed::BarrierType::kSendBarrier);
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
barrier->WaitServerWeakup();
barrier->ServerWeakup();
g_rpc_service->SetCond(distributed::kRequestSend);
g_rpc_service->WaitBarrier(distributed::kRequestSend);
LOG(INFO) << "got nccl id and stop server...";
barrier->Stop();
g_rpc_service->ShutDown();
server_thread.join();
}
......@@ -87,10 +74,6 @@ void StartServer() {
TEST(SendNcclId, RPCServer) {
g_req_handler.reset(
new distributed::RequestSendHandler(distributed::DistributedMode::kSync));
g_notify_handler.reset(new distributed::RequestNotifyHandler(
distributed::DistributedMode::kSync, -1));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
std::thread server_thread(StartServer);
......@@ -121,5 +104,4 @@ TEST(SendNcclId, RPCServer) {
server_thread.join();
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
g_notify_handler.reset(nullptr);
}
......@@ -94,7 +94,7 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
current_id=0,
role=role_maker.Role.WORKER
if training_role == "TRAINER" else role_maker.Role.SERVER,
worker_num=1,
worker_num=2,
server_endpoints=["127.0.0.1:6002"])
if training_role == "TRAINER":
......
......@@ -937,11 +937,6 @@ class TestDistBase(unittest.TestCase):
need_envs={},
log_name=""):
print(
"disable distributed unittests temporary, will enable it soon. (tangwei)"
)
return
required_envs = self._get_required_envs(check_error_log, need_envs)
local_losses \
......@@ -982,11 +977,6 @@ class TestDistBase(unittest.TestCase):
need_envs={},
log_name=""):
print(
"disable distributed unittests temporary, will enable it soon. (tangwei)"
)
return
# need open p2p or shm otherwise multi cards mode will hang
need_envs.update({"NCCL_P2P_DISABLE": "0", "NCCL_SHM_DISABLE": "0"})
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test f1 listen and serv_op."""
from __future__ import print_function
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program
import os
import signal
import subprocess
import time
import unittest
from multiprocessing import Process
from op_test import OpTest
import numpy
import urllib
import sys
from dist_test_utils import *
cache_path = os.path.expanduser('~/.cache/paddle/dataset')
def run_trainer(use_cuda, sync_mode, ip, port, trainers, trainer_id):
'''
This function is run trainer.
Args:
use_cuda (bool): whether use cuda.
sync_mode (nouse): specify sync mode.
ip (string): the ip address.
port (string): the port for listening.
trainers (int): the count of trainer.
trainer_id (int): the id of trainer.
Returns:
None
'''
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# loss function
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
# optimizer
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
with open("{}/trainer_recv_program.dms".format(cache_path), "rb") as f:
trainer_recv_program_desc_str = f.read()
with open("{}/trainer_main_program.dms".format(cache_path), "rb") as f:
trainer_main_program_desc_str = f.read()
with open("{}/trainer_send_program.dms".format(cache_path), "rb") as f:
trainer_send_program_desc_str = f.read()
recv_program = Program.parse_from_string(trainer_recv_program_desc_str)
main_program = Program.parse_from_string(trainer_main_program_desc_str)
send_program = Program.parse_from_string(trainer_send_program_desc_str)
trainer_startup_program = fluid.default_startup_program()
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(trainer_startup_program)
for i in range(5):
exe.run(recv_program)
exe.run(fluid.default_main_program(),
feed={
"x": numpy.array([1, 2]).astype('float32').reshape(2, 1),
"y": numpy.array([2, 3]).astype('float32').reshape(2, 1)
})
exe.run(send_program)
def run_pserver(use_cuda, sync_mode, ip, port, trainers, trainer_id):
'''
This function is run trainer.
Args:
use_cuda (bool): whether use cuda.
sync_mode (nouse): specify sync mode.
ip (string): the ip address.
port (string): the port for listening.
trainers (int): the count of trainer.
trainer_id (int): the id of trainer.
Returns:
None
'''
remove_ps_flag(os.getpid())
x = fluid.layers.data(name='x', shape=[1], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
# loss function
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
# optimizer
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(avg_cost)
with open("{}/pserver_startup_program.dms".format(cache_path), "rb") as f:
pserver_startup_program_desc_str = f.read()
with open("{}/pserver_main_program.dms".format(cache_path), "rb") as f:
pserver_main_program_desc_str = f.read()
startup_program = Program.parse_from_string(
pserver_startup_program_desc_str)
main_program = Program.parse_from_string(pserver_main_program_desc_str)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_program)
exe.run(main_program)
class TestFlListenAndServOp(unittest.TestCase):
"""This class is Test Fl Listen And ServOp."""
def setUp(self):
"""This function si set Up."""
self.ps_timeout = 5
self.ip = "127.0.0.1"
self.port = "6000"
self.trainers = 2
self.trainer_id = 0
def _start_pserver(self, use_cuda, sync_mode, pserver_func):
"""This function is start pserver."""
p = Process(
target=pserver_func,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
self.trainer_id))
p.daemon = True
p.start()
return p
def _start_trainer0(self, use_cuda, sync_mode, pserver_func):
"""This function is start trainer0."""
p = Process(
target=pserver_func,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, 0))
p.daemon = True
p.start()
return p
def _start_trainer1(self, use_cuda, sync_mode, pserver_func):
"""This function is start trainer1."""
p = Process(
target=pserver_func,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers, 1))
p.daemon = True
p.start()
return p
def _wait_ps_ready(self, pid):
"""This function is wait ps ready."""
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
assert start_left_time >= 0, "wait ps ready failed"
time.sleep(sleep_time)
try:
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
start_left_time -= sleep_time
def test_rpc_interfaces(self):
"""TODO(Yancey1989): need to make sure the rpc interface correctly."""
# TODO(Yancey1989): need to make sure the rpc interface correctly.
pass
def test_handle_signal_in_serv_op(self):
"""run pserver on CPU in sync mode."""
# run pserver on CPU in sync mode
if sys.platform == 'win32' or sys.platform == 'sys.platform':
pass
else:
print(sys.platform)
file_list = [
'pserver_startup_program.dms', 'pserver_main_program.dms',
'trainer_recv_program.dms', 'trainer_main_program.dms',
'trainer_send_program.dms'
]
if not os.path.exists(cache_path):
os.makedirs(cache_path)
prefix = 'wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/'
for f in file_list:
if not os.path.exists('{}/{}'.format(cache_path, f)):
cmd = "wget --no-check-certificate https://paddlefl.bj.bcebos.com/test_fl_listen_and_serv/{} -P {}/".format(
f, cache_path)
os.system(cmd)
p1 = self._start_pserver(False, True, run_pserver)
self._wait_ps_ready(p1.pid)
time.sleep(5)
t1 = self._start_trainer0(False, True, run_trainer)
time.sleep(2)
t2 = self._start_trainer1(False, True, run_trainer)
# raise SIGTERM to pserver
time.sleep(2)
cmd_del = "rm trainer*dms* pserver*dms*"
os.system(cmd_del)
os.kill(p1.pid, signal.SIGINT)
p1.join()
os.kill(t1.pid, signal.SIGINT)
t1.join()
os.kill(t2.pid, signal.SIGINT)
t2.join()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册