未验证 提交 781d2844 编写于 作者: 1 123malin 提交者: GitHub

Optimize decay (#20816) (#20952)

* update pserver decay blocks

* update distributed notify handler
上级 55c2329a
...@@ -62,8 +62,16 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -62,8 +62,16 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("sections")); node->Op()->GetNullableAttr("sections"));
auto trainer_id = auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id")); boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
auto merge_add =
boost::get<bool>(node->Op()->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler =
boost::get<bool>(node->Op()->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id); send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: " VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name]; << send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") { } else if (node->Name() == "recv") {
......
...@@ -130,8 +130,15 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program, ...@@ -130,8 +130,15 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
auto height_section = auto height_section =
boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections")); boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id")); auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
auto merge_add = boost::get<bool>(op->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler =
boost::get<bool>(op->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext( send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id); send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: " VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name]; << send_varname_to_ctx[send_var_name];
} else if (op->Type() == "recv") { } else if (op->Type() == "recv") {
...@@ -208,12 +215,17 @@ void AsyncCommunicator::SendThread() { ...@@ -208,12 +215,17 @@ void AsyncCommunicator::SendThread() {
} }
} }
auto before_merge = GetCurrentUS(); auto before_merge = GetCurrentUS();
MergeVars(var_name, vars, send_scope_.get()); auto &ctx = send_varname_to_ctx_.at(var_name);
if (ctx.use_send_handler) {
MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
} else {
MergeVars<int64_t>(var_name, vars, send_scope_.get(),
ctx.merge_add);
}
auto after_merge = GetCurrentUS(); auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << merged_var_num << " " << var_name VLOG(3) << "merge " << merged_var_num << " " << var_name
<< " use time " << after_merge - before_merge; << " use time " << after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>(); auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) { if (!FLAGS_communicator_fake_rpc) {
send_functor(ctx, *send_scope_, true, 1); send_functor(ctx, *send_scope_, true, 1);
} }
......
...@@ -107,21 +107,21 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -107,21 +107,21 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
inline void MergeVars(const std::string& var_name, inline void MergeVars(const std::string& var_name,
const std::vector<std::shared_ptr<Variable>>& vars, const std::vector<std::shared_ptr<Variable>>& vars,
Scope* scope) { Scope* scope, bool merge_add = true) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!"); PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace(); auto cpu_place = platform::CPUPlace();
auto& var0 = vars[0]; auto& var0 = vars[0];
auto* out_var = scope->Var(var_name); auto* out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) { if (var0->IsType<framework::LoDTensor>()) {
auto dims = var0->Get<framework::LoDTensor>().dims(); auto dims = var0->Get<framework::LoDTensor>().dims();
VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims; VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
<< "; merge add: " << merge_add;
// init output tensor // init output tensor
auto* out_t = out_var->GetMutable<framework::LoDTensor>(); auto* out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<float>(dims, cpu_place); out_t->mutable_data<T>(dims, cpu_place);
// check the input dims // check the input dims
for (auto& var : vars) { for (auto& var : vars) {
auto& var_t = var->Get<framework::LoDTensor>(); auto& var_t = var->Get<framework::LoDTensor>();
...@@ -130,44 +130,41 @@ inline void MergeVars(const std::string& var_name, ...@@ -130,44 +130,41 @@ inline void MergeVars(const std::string& var_name,
// set output tensor to 0. // set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext(); auto cpu_ctx = paddle::platform::CPUDeviceContext();
math::SetConstant<paddle::platform::CPUDeviceContext, float> math::SetConstant<paddle::platform::CPUDeviceContext, T> constant_functor;
constant_functor; constant_functor(cpu_ctx, out_t, static_cast<T>(0));
constant_functor(cpu_ctx, out_t, static_cast<float>(0));
// sum all vars to out // sum all vars to out
auto result = EigenVector<float>::Flatten(*out_t); auto result = EigenVector<T>::Flatten(*out_t);
for (auto& var : vars) { for (auto& var : vars) {
auto& in_t = var->Get<framework::LoDTensor>(); auto& in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<float>::Flatten(in_t); auto in = EigenVector<T>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in; result.device(*cpu_ctx.eigen_device()) = result + in;
} }
if (!FLAGS_communicator_is_sgd_optimizer) { if (!merge_add) {
result.device(*cpu_ctx.eigen_device()) = result.device(*cpu_ctx.eigen_device()) =
result / static_cast<float>(vars.size()); result / static_cast<T>(vars.size());
} }
} else if (var0->IsType<framework::SelectedRows>()) { } else if (var0->IsType<framework::SelectedRows>()) {
auto& slr0 = var0->Get<framework::SelectedRows>(); auto& slr0 = var0->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>(); auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear(); out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place); out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows*> inputs; std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.reserve(vars.size()); inputs.reserve(vars.size());
for (auto& var : vars) { for (auto& var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>()); inputs.push_back(&var->Get<framework::SelectedRows>());
} }
auto dev_ctx = paddle::platform::CPUDeviceContext(); auto dev_ctx = paddle::platform::CPUDeviceContext();
if (FLAGS_communicator_is_sgd_optimizer) { if (merge_add) {
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float> math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, T> merge_add;
merge_add;
merge_add(dev_ctx, inputs, out_slr); merge_add(dev_ctx, inputs, out_slr);
} else { } else {
math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float> math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, T>
merge_average; merge_average;
merge_average(dev_ctx, inputs, out_slr); merge_average(dev_ctx, inputs, out_slr);
} }
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height() VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims(); << " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
} else { } else {
PADDLE_THROW("unsupported var type!"); PADDLE_THROW("unsupported var type!");
} }
......
...@@ -47,7 +47,7 @@ TEST(communicator, merge_lod_tensors) { ...@@ -47,7 +47,7 @@ TEST(communicator, merge_lod_tensors) {
scope.reset(new framework::Scope()); scope.reset(new framework::Scope());
scope->Var(out_name); scope->Var(out_name);
for (auto i = 0; i < 10; ++i) { for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get()); MergeVars<float>(out_name, in_vars, scope.get());
} }
auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>(); auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>();
auto *out_data = out_tensor.data<float>(); auto *out_data = out_tensor.data<float>();
...@@ -86,7 +86,7 @@ TEST(communicator, merge_selected_rows) { ...@@ -86,7 +86,7 @@ TEST(communicator, merge_selected_rows) {
scope.reset(new framework::Scope()); scope.reset(new framework::Scope());
scope->Var(out_name); scope->Var(out_name);
for (auto i = 0; i < 10; ++i) { for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get()); MergeVars<float>(out_name, in_vars, scope.get());
} }
auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>(); auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>();
auto &out_t = out_slr.value(); auto &out_t = out_slr.value();
......
...@@ -438,26 +438,40 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -438,26 +438,40 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
return h; return h;
} }
VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep, VarHandlePtr GRPCClient::AsyncDistributeNotify(
const std::string& type, const std::string& ep, const platform::DeviceContext& ctx,
int64_t time_out) { const framework::Scope& scope, const std::string& var_name,
const auto ch = GetChannel(ep); int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch); const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kRequestNotify; const std::string method = kRequestNotify;
VarHandlePtr h( SendProcessor* s = new SendProcessor(ch);
new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr)); VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
sendrecv::VariableMessage req; framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
req.set_varname(type); auto* var = p_scope->FindVar(var_name_val);
platform::RecordRPCEvent record_event(method); ::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_); VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
// stub context
s->response_call_back_ = nullptr;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
});
req_count_++; req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) { if (UNLIKELY(platform::IsProfileEnabled())) {
......
...@@ -173,20 +173,6 @@ class CheckpointNotifyProcessor : public BaseProcessor { ...@@ -173,20 +173,6 @@ class CheckpointNotifyProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_; std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
}; };
class DistributeNotifyProcessor : public BaseProcessor {
public:
explicit DistributeNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor() {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~DistributeNotifyProcessor() {}
void ProcessImpl() override {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class GRPCClient : public RPCClient { class GRPCClient : public RPCClient {
public: public:
GRPCClient() : ok_(true), completed_(false), stopped_(false) {} GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
...@@ -240,7 +226,8 @@ class GRPCClient : public RPCClient { ...@@ -240,7 +226,8 @@ class GRPCClient : public RPCClient {
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify( VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override; int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete( VarHandlePtr AsyncSendComplete(
......
...@@ -400,33 +400,31 @@ class RequestNotify final : public RequestBase { ...@@ -400,33 +400,31 @@ class RequestNotify final : public RequestBase {
RequestHandler* request_handler, int req_id) RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) { : RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(), request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx())); request_handler->dev_ctx(),
!request_handler->sync_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify); int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
service_->RequestAsyncUnary( service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_, method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id))); reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
} }
virtual ~RequestNotify() {} virtual ~RequestNotify() {}
std::string GetReqName() override { return request_->Varname(); } std::string GetReqName() override { return request_->Varname(); }
void Process() override { void Process() override {
auto scope = request_->GetMutableLocalScope(); std::string varname = GetReqName();
VLOG(4) << "RequestNotify var_name:" << varname;
std::string varname = request_->Varname(); auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId(); int trainer_id = request_->GetTrainerId();
framework::Variable* outvar = nullptr;
VLOG(4) << "RequestNotify notify: " << varname request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
<< ", trainer id: " << trainer_id;
request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id);
Finish(reply_, &responder_); Finish(reply_, &responder_);
} }
protected: protected:
std::shared_ptr<GRPCVariableResponse> request_;
sendrecv::VoidMessage reply_; sendrecv::VoidMessage reply_;
std::shared_ptr<GRPCVariableResponse> request_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_; ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
}; };
......
...@@ -116,24 +116,44 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx, ...@@ -116,24 +116,44 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
row_offset += outs_dims[i][0]; row_offset += outs_dims[i][0];
} }
} }
if (rpc_ctx.use_send_handler) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) { for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i]; auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name; VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[i]; auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << "send var endpoint: " << endpoint; VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name); VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) { if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint; VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar( rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name)); endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done"; VLOG(4) << "send var " << send_var_name << " async handle done";
} else { } else {
VLOG(3) << "don't send non-initialized variable: " VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i]; << rpc_ctx.splited_var_names[i];
}
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: "
<< NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} }
} }
} else if (send_var->IsType<framework::SelectedRows>()) { } else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>(); auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections); auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
......
...@@ -63,7 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC"; ...@@ -63,7 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV" #define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV" #define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV" #define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV" #define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY" #define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY" #define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
...@@ -262,11 +262,25 @@ bool RequestNotifyHandler::Handle(const std::string& varname, ...@@ -262,11 +262,25 @@ bool RequestNotifyHandler::Handle(const std::string& varname,
const int trainer_id, const int trainer_id,
const std::string& out_var_name, const std::string& out_var_name,
const std::string& table_name) { const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler" << varname; VLOG(4) << "RequestNotifyHandler: " << varname;
if (varname == LEARNING_RATE_DECAY_MESSAGE) { VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
lr_decay_block_id, -1, lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke."); "when lr_decay_block_id = -1, there should be no RPC invoke.");
auto* origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto* send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t* origin_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t* send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0];
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_); executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
} }
return true; return true;
......
...@@ -81,7 +81,8 @@ class RPCClient { ...@@ -81,7 +81,8 @@ class RPCClient {
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify( virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type, const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0; int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete( virtual VarHandlePtr AsyncSendComplete(
......
...@@ -27,12 +27,15 @@ struct RpcContext { ...@@ -27,12 +27,15 @@ struct RpcContext {
RpcContext(const std::string &name, const std::vector<std::string> &names, RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap, const std::vector<std::string> &emap,
const std::vector<int64_t> &sections, int id) const std::vector<int64_t> &sections, int id,
bool merge_add_ = true, bool use_send_handler_ = true)
: var_name(name), : var_name(name),
splited_var_names(names), splited_var_names(names),
epmap(emap), epmap(emap),
height_sections(sections), height_sections(sections),
trainer_id(id) {} trainer_id(id),
merge_add(merge_add_),
use_send_handler(use_send_handler_) {}
RpcContext(const RpcContext &ctx) { RpcContext(const RpcContext &ctx) {
var_name = ctx.var_name; var_name = ctx.var_name;
...@@ -40,6 +43,8 @@ struct RpcContext { ...@@ -40,6 +43,8 @@ struct RpcContext {
epmap = ctx.epmap; epmap = ctx.epmap;
height_sections = ctx.height_sections; height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id; trainer_id = ctx.trainer_id;
merge_add = ctx.merge_add;
use_send_handler = ctx.use_send_handler;
} }
std::string var_name; std::string var_name;
...@@ -47,6 +52,8 @@ struct RpcContext { ...@@ -47,6 +52,8 @@ struct RpcContext {
std::vector<std::string> epmap; std::vector<std::string> epmap;
std::vector<int64_t> height_sections; std::vector<int64_t> height_sections;
int trainer_id; int trainer_id;
bool merge_add;
bool use_send_handler;
}; };
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
...@@ -70,6 +77,9 @@ inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) { ...@@ -70,6 +77,9 @@ inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
os << section << ", "; os << section << ", ";
} }
os << "]\n"; os << "]\n";
os << "merge add: " << rpc_ctx.merge_add;
os << "; send handler: " << rpc_ctx.use_send_handler << "\n";
os << "}"; os << "}";
return os; return os;
} }
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
class DistributedNotifyOp : public framework::OperatorBase {
public:
DistributedNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string type = Attr<std::string>("type");
int trainer_id = Attr<int>("trainer_id");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
for (size_t i = 0; i < epmap.size(); i++) {
rpc_client->AsyncDistributeNotify(epmap[i], type);
VLOG(4) << "distribute notify sending : " << type << " to " << epmap[i];
}
PADDLE_ENFORCE_EQ(rpc_client->Wait(), true, "internal error in RPCClient");
}
};
class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>("type",
"(string, default '') indicate the action type");
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddComment(R"DOC(
DistributeNotify operator
This operator will send a signal to listen_and_serve op at
the parameter server.
)DOC");
}
};
class DistributedNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(distributed_notify, ops::DistributedNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::DistributedNotifyOpMaker,
ops::DistributedNotifyOpShapeInference);
...@@ -383,7 +383,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, ...@@ -383,7 +383,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier, rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get()); request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify, rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), 1); request_notify_handler_.get(),
FLAGS_rpc_send_thread_num);
auto optimize_blocks = auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks); Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......
...@@ -45,6 +45,7 @@ class SendOp : public framework::OperatorBase { ...@@ -45,6 +45,7 @@ class SendOp : public framework::OperatorBase {
auto send_varnames = Attr<std::vector<std::string>>("send_varnames"); auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto height_sections = Attr<std::vector<int64_t>>("sections"); auto height_sections = Attr<std::vector<int64_t>>("sections");
auto use_send_handler = Attr<bool>("use_send_handler");
if (send_varnames.size() > 0) { if (send_varnames.size() > 0) {
if (ins.size() > 1) { if (ins.size() > 1) {
...@@ -62,13 +63,27 @@ class SendOp : public framework::OperatorBase { ...@@ -62,13 +63,27 @@ class SendOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) { if (use_send_handler) {
if (NeedSend(scope, ins[i])) { for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; if (NeedSend(scope, ins[i])) {
rets.push_back( VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i])); rets.push_back(
} else { rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
VLOG(3) << "don't send no-initialied variable: " << ins[i]; } else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
} else {
for (size_t i = 0; i < ins.size(); i++) {
for (size_t j = 0; j < epmap.size(); j++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[j];
rets.push_back(rpc_client->AsyncDistributeNotify(epmap[j], ctx,
scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
} }
} }
for (size_t i = 0; i < rets.size(); i++) { for (size_t i = 0; i < rets.size(); i++) {
...@@ -113,6 +128,15 @@ This operator will send variables to listen_and_serve op at the parameter server ...@@ -113,6 +128,15 @@ This operator will send variables to listen_and_serve op at the parameter server
"Number of sub-tensors. This must evenly divide " "Number of sub-tensors. This must evenly divide "
"Input.dims()[axis]") "Input.dims()[axis]")
.SetDefault(0); .SetDefault(0);
AddAttr<bool>("merge_add",
"(bool, default 0)"
"merge method, true represent add, false represent average")
.SetDefault(false);
AddAttr<bool>(
"use_send_handler",
"(bool, default 1)"
"if it's true, use send handler, other wise, use notify handler")
.SetDefault(true);
} }
}; };
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import gc
import paddle.fluid as fluid
class TranspilerAsyncLRDecayTest(unittest.TestCase):
def setUp(self):
self.trainer_id = 0
self.trainers = 2
self.pservers = 2
# NOTE: we do not actually bind this port
self.pserver_eps = "127.0.0.1:6174,127.0.0.1:6175"
self.pserver1_ep = "127.0.0.1:6174"
self.pserver2_ep = "127.0.0.1:6175"
self.sync_mode = False
self.transpiler = None
def net_conf(self):
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
y_predict = fluid.layers.fc(input=x,
size=1000,
act=None,
param_attr=fluid.ParamAttr(name='fc_w'),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=fluid.layers.exponential_decay(
learning_rate=0.1,
decay_steps=100,
decay_rate=0.99,
staircase=True))
sgd_optimizer.minimize(avg_cost)
def get_main_program(self):
main = fluid.Program()
main.random_seed = 1
with fluid.program_guard(main):
self.net_conf()
self.origin_prog = main.clone()
return main
def get_trainer(self, config=None):
src = fluid.default_startup_program().clone()
t = self._transpiler_instance(config)
trainer_main = t.get_trainer_program(wait_port=False)
trainer_startup = fluid.default_startup_program()
assert (src.num_blocks == 1)
assert (trainer_startup.num_blocks == src.num_blocks)
return trainer_main, trainer_startup
def get_pserver(self, ep, config=None, sync_mode=True):
t = self._transpiler_instance(config, sync_mode)
pserver = t.get_pserver_program(ep)
startup = t.get_startup_program(ep, pserver)
return pserver, startup
def _transpiler_instance(self, config=None, sync_mode=True):
if not self.transpiler:
main = self.get_main_program()
self.transpiler = fluid.DistributeTranspiler(config=config)
self.transpiler.transpile(
self.trainer_id,
program=main,
pservers=self.pserver_eps,
trainers=self.trainers,
sync_mode=sync_mode)
return self.transpiler
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep, sync_mode=False)
pserver2, startup2 = self.get_pserver(self.pserver2_ep, sync_mode=False)
trainer, trainer_startup = self.get_trainer()
src = [op.type for op in trainer_startup.global_block().ops]
dst = ['fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', \
'uniform_random', 'recv', 'recv', 'fetch_barrier', 'concat']
self.assertEqual(src, dst)
self.assertEqual([op.type for op in trainer.global_block().ops], [
'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
'fill_constant', 'mean_grad', 'square_grad', 'elementwise_sub_grad',
'elementwise_add_grad', 'send', 'mul_grad', 'split_byref', 'send',
'send', 'recv', 'recv', 'concat'
])
self.assertEqual(len(pserver.blocks), 4)
# block0: listen_and_serv
self.assertEqual([op.type for op in pserver.blocks[0].ops],
["listen_and_serv"])
# block1: sum,cast,scale,floor,fill_constant,elementwise_pow,scale
self.assertEqual([op.type for op in pserver.blocks[1].ops], [
"sum", "cast", "scale", "floor", "fill_constant", "elementwise_pow",
"scale"
])
# block1~2: optimize pass
self.assertEqual([op.type for op in pserver.blocks[2].ops], ["sgd"])
# confirm startup program
self.assertEqual([op.type for op in startup.global_block().ops], [
"fill_constant", "fill_constant", "fill_constant", "fill_constant",
"uniform_random"
])
def test_transpiler(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
self.transpiler_test_impl()
# NOTE: run gc.collect to eliminate pybind side objects to
# prevent random double-deallocate when inherited in python.
del self.transpiler
del main
del startup
gc.collect()
if __name__ == "__main__":
unittest.main()
...@@ -41,7 +41,7 @@ import logging ...@@ -41,7 +41,7 @@ import logging
import numpy as np import numpy as np
from .ps_dispatcher import RoundRobin, PSDispatcher from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name from .. import core, framework, unique_name, initializer
from ..framework import Program, default_main_program, \ from ..framework import Program, default_main_program, \
default_startup_program, Block, Parameter, grad_var_name default_startup_program, Block, Parameter, grad_var_name
from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed
...@@ -304,6 +304,7 @@ class DistributeTranspiler(object): ...@@ -304,6 +304,7 @@ class DistributeTranspiler(object):
PRINT_LOG = True PRINT_LOG = True
assert (self.config.min_block_size >= 8192) assert (self.config.min_block_size >= 8192)
assert (self.config.split_method.__bases__[0] == PSDispatcher) assert (self.config.split_method.__bases__[0] == PSDispatcher)
self.counter_var = None
def _transpile_nccl2(self, def _transpile_nccl2(self,
trainer_id, trainer_id,
...@@ -631,6 +632,7 @@ class DistributeTranspiler(object): ...@@ -631,6 +632,7 @@ class DistributeTranspiler(object):
np.random.shuffle(grad_var_mapping_items) np.random.shuffle(grad_var_mapping_items)
self.grad_name_to_send_dummy_out = dict() self.grad_name_to_send_dummy_out = dict()
for grad_varname, splited_vars in grad_var_mapping_items: for grad_varname, splited_vars in grad_var_mapping_items:
eplist = ps_dispatcher.dispatch(splited_vars) eplist = ps_dispatcher.dispatch(splited_vars)
...@@ -720,6 +722,31 @@ class DistributeTranspiler(object): ...@@ -720,6 +722,31 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
}) })
fetch_barrier_input.append(send_barrier_out) fetch_barrier_input.append(send_barrier_out)
else:
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0 and self.counter_var:
decay_dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
if self.config.runtime_split_send_recv:
## async mode, using communicator to merge and send
send_varnames = [self.counter_var.name]
else:
send_varnames = []
sections = []
program.global_block().append_op(
type="send",
inputs={"X": self.counter_var},
outputs={"Out": decay_dummy_output},
attrs={
"epmap": pserver_endpoints,
"sections": sections,
"send_varnames": send_varnames,
"merge_add": True,
"use_send_handler": False,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[self.counter_var.name, self.counter_var.name]
})
# step 3: insert recv op to receive parameters from parameter server # step 3: insert recv op to receive parameters from parameter server
recv_vars = [] recv_vars = []
...@@ -821,19 +848,6 @@ class DistributeTranspiler(object): ...@@ -821,19 +848,6 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
}) })
if not self.sync_mode:
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:
program.global_block().append_op(
type="distributed_notify",
inputs={},
outputs={},
attrs={
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
"type": "LRDECAY@RECV"
})
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist) self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
if self.has_distributed_lookup_table: if self.has_distributed_lookup_table:
...@@ -2380,11 +2394,57 @@ class DistributeTranspiler(object): ...@@ -2380,11 +2394,57 @@ class DistributeTranspiler(object):
def _get_lr_ops(self): def _get_lr_ops(self):
lr_ops = [] lr_ops = []
block = self.origin_program.global_block() block = self.origin_program.global_block()
for op in block.ops: for index, op in enumerate(block.ops):
role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME)) role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \ if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \ role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
int(OPT_OP_ROLE_ATTR_VALUE): int(OPT_OP_ROLE_ATTR_VALUE):
if self.sync_mode == False and op.type == 'increment':
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op)
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op)
for key in outputs:
counter_var = outputs[key]
all_trainer_counter_inputs = [
self.origin_program.global_block().create_var(
name="%s.trainer_%d" % (counter_var.name, id_),
type=counter_var.type,
shape=counter_var.shape,
dtype=counter_var.dtype,
persistable=counter_var.persistable)
for id_ in range(self.trainer_num)
]
for i, op in enumerate(self.startup_program.global_block()
.ops):
if op.type == 'fill_constant':
for key in op.output_names:
if len(op.output(key)) == 1 and op.output(key)[
0] == counter_var.name:
self.startup_program.global_block().ops[
i]._set_attr(
'value',
float(0.0 - self.trainer_num))
for var in all_trainer_counter_inputs:
if var.name == "%s.trainer_%d" % (counter_var.name,
self.trainer_id):
self.counter_var = var
self.startup_program.global_block().create_var(
name=var.name,
type=var.type,
dtype=var.dtype,
shape=var.shape,
persistable=var.persistable,
initializer=initializer.Constant(1))
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
block._remove_op(index)
op = block._insert_op(
index,
type='sum',
inputs={'X': all_trainer_counter_inputs},
outputs=outputs,
attrs={op_role_attr_name: LR_SCHED_OP_ROLE_ATTR_VALUE})
lr_ops.append(op) lr_ops.append(op)
log("append lr op: ", op.type) log("append lr op: ", op.type)
return lr_ops return lr_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册