未验证 提交 20cdff0e 编写于 作者: 1 123malin 提交者: GitHub

Optimize decay (#20816)

* update pserver decay blocks

* update distributed notify handler
上级 16596f64
......@@ -62,8 +62,16 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
node->Op()->GetNullableAttr("sections"));
auto trainer_id =
boost::get<int>(node->Op()->GetNullableAttr("trainer_id"));
auto merge_add =
boost::get<bool>(node->Op()->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler =
boost::get<bool>(node->Op()->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id);
send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (node->Name() == "recv") {
......
......@@ -130,8 +130,15 @@ void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
auto height_section =
boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
auto merge_add = boost::get<bool>(op->GetNullableAttr("merge_add"));
if (!merge_add) {
merge_add = FLAGS_communicator_is_sgd_optimizer;
}
auto use_send_handler =
boost::get<bool>(op->GetNullableAttr("use_send_handler"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id);
send_var_name, send_varnames, epmap, height_section, trainer_id,
merge_add, use_send_handler);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (op->Type() == "recv") {
......@@ -208,12 +215,17 @@ void AsyncCommunicator::SendThread() {
}
}
auto before_merge = GetCurrentUS();
MergeVars(var_name, vars, send_scope_.get());
auto &ctx = send_varname_to_ctx_.at(var_name);
if (ctx.use_send_handler) {
MergeVars<float>(var_name, vars, send_scope_.get(), ctx.merge_add);
} else {
MergeVars<int64_t>(var_name, vars, send_scope_.get(),
ctx.merge_add);
}
auto after_merge = GetCurrentUS();
VLOG(3) << "merge " << merged_var_num << " " << var_name
<< " use time " << after_merge - before_merge;
auto send_functor = distributed::ParameterSend<float>();
auto &ctx = send_varname_to_ctx_.at(var_name);
if (!FLAGS_communicator_fake_rpc) {
send_functor(ctx, *send_scope_, true, 1);
}
......
......@@ -107,21 +107,21 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
inline void MergeVars(const std::string& var_name,
const std::vector<std::shared_ptr<Variable>>& vars,
Scope* scope) {
Scope* scope, bool merge_add = true) {
PADDLE_ENFORCE(!vars.empty(), "should have value to merge!");
auto cpu_place = platform::CPUPlace();
auto& var0 = vars[0];
auto* out_var = scope->Var(var_name);
if (var0->IsType<framework::LoDTensor>()) {
auto dims = var0->Get<framework::LoDTensor>().dims();
VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims;
VLOG(3) << "merge " << var_name << " LoDTensor dims " << dims
<< "; merge add: " << merge_add;
// init output tensor
auto* out_t = out_var->GetMutable<framework::LoDTensor>();
out_t->mutable_data<float>(dims, cpu_place);
out_t->mutable_data<T>(dims, cpu_place);
// check the input dims
for (auto& var : vars) {
auto& var_t = var->Get<framework::LoDTensor>();
......@@ -130,44 +130,41 @@ inline void MergeVars(const std::string& var_name,
// set output tensor to 0.
auto cpu_ctx = paddle::platform::CPUDeviceContext();
math::SetConstant<paddle::platform::CPUDeviceContext, float>
constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<float>(0));
math::SetConstant<paddle::platform::CPUDeviceContext, T> constant_functor;
constant_functor(cpu_ctx, out_t, static_cast<T>(0));
// sum all vars to out
auto result = EigenVector<float>::Flatten(*out_t);
auto result = EigenVector<T>::Flatten(*out_t);
for (auto& var : vars) {
auto& in_t = var->Get<framework::LoDTensor>();
auto in = EigenVector<float>::Flatten(in_t);
auto in = EigenVector<T>::Flatten(in_t);
result.device(*cpu_ctx.eigen_device()) = result + in;
}
if (!FLAGS_communicator_is_sgd_optimizer) {
if (!merge_add) {
result.device(*cpu_ctx.eigen_device()) =
result / static_cast<float>(vars.size());
result / static_cast<T>(vars.size());
}
} else if (var0->IsType<framework::SelectedRows>()) {
auto& slr0 = var0->Get<framework::SelectedRows>();
auto* out_slr = out_var->GetMutable<framework::SelectedRows>();
out_slr->mutable_rows()->clear();
out_slr->mutable_value()->mutable_data<float>({{}}, cpu_place);
out_slr->mutable_value()->mutable_data<T>({{}}, cpu_place);
std::vector<const paddle::framework::SelectedRows*> inputs;
inputs.reserve(vars.size());
for (auto& var : vars) {
inputs.push_back(&var->Get<framework::SelectedRows>());
}
auto dev_ctx = paddle::platform::CPUDeviceContext();
if (FLAGS_communicator_is_sgd_optimizer) {
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, float>
merge_add;
if (merge_add) {
math::scatter::MergeAdd<paddle::platform::CPUDeviceContext, T> merge_add;
merge_add(dev_ctx, inputs, out_slr);
} else {
math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, float>
math::scatter::MergeAverage<paddle::platform::CPUDeviceContext, T>
merge_average;
merge_average(dev_ctx, inputs, out_slr);
}
VLOG(3) << "merge " << var_name << " SelectedRows height: " << slr0.height()
<< " dims: " << slr0.value().dims();
<< " dims: " << slr0.value().dims() << "; merge add: " << merge_add;
} else {
PADDLE_THROW("unsupported var type!");
}
......
......@@ -47,7 +47,7 @@ TEST(communicator, merge_lod_tensors) {
scope.reset(new framework::Scope());
scope->Var(out_name);
for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get());
MergeVars<float>(out_name, in_vars, scope.get());
}
auto &out_tensor = scope->FindVar(out_name)->Get<LoDTensor>();
auto *out_data = out_tensor.data<float>();
......@@ -86,7 +86,7 @@ TEST(communicator, merge_selected_rows) {
scope.reset(new framework::Scope());
scope->Var(out_name);
for (auto i = 0; i < 10; ++i) {
MergeVars(out_name, in_vars, scope.get());
MergeVars<float>(out_name, in_vars, scope.get());
}
auto &out_slr = scope->FindVar(out_name)->Get<SelectedRows>();
auto &out_t = out_slr.value();
......
......@@ -438,26 +438,40 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
return h;
}
VarHandlePtr GRPCClient::AsyncDistributeNotify(const std::string& ep,
const std::string& type,
int64_t time_out) {
const auto ch = GetChannel(ep);
DistributeNotifyProcessor* s = new DistributeNotifyProcessor(ch);
VarHandlePtr GRPCClient::AsyncDistributeNotify(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kRequestNotify;
VarHandlePtr h(
new VarHandle(ep, method, LEARNING_RATE_DECAY_MESSAGE, nullptr, nullptr));
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(type);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
platform::RecordRPCEvent record_event(method);
::grpc::ByteBuffer req;
SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
auto rpc = s->stub_->AsyncDistributeNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = nullptr;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/DistributeNotify", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
});
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
......
......@@ -173,20 +173,6 @@ class CheckpointNotifyProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
// Processor used for the DistributeNotify RPC. On construction it builds a
// SendRecvService stub on the supplied channel; it also holds the VoidMessage
// that receives the reply. ProcessImpl() is deliberately empty.
class DistributeNotifyProcessor : public BaseProcessor {
 public:
  // ch: channel the service stub is created on.
  explicit DistributeNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
      : BaseProcessor(), stub_(sendrecv::SendRecvService::NewStub(ch)) {}

  virtual ~DistributeNotifyProcessor() = default;

  // Nothing to post-process for this request type.
  void ProcessImpl() override {}

  // Reply slot for the call.
  sendrecv::VoidMessage reply_;
  // Service stub bound to the channel passed to the constructor.
  std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class GRPCClient : public RPCClient {
public:
GRPCClient() : ok_(true), completed_(false), stopped_(false) {}
......@@ -240,7 +226,8 @@ class GRPCClient : public RPCClient {
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type,
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete(
......
......@@ -402,33 +402,31 @@ class RequestNotify final : public RequestBase {
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
request_handler->dev_ctx(),
!request_handler->sync_mode()));
int method_id = static_cast<int>(distributed::GrpcMethod::kRequestNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
auto scope = request_->GetMutableLocalScope();
std::string varname = GetReqName();
VLOG(4) << "RequestNotify var_name:" << varname;
std::string varname = request_->Varname();
auto scope = request_->GetMutableLocalScope();
auto invar = request_->GetVar();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestNotify notify: " << varname
<< ", trainer id: " << trainer_id;
request_handler_->Handle(varname, scope, nullptr, nullptr, trainer_id);
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
sendrecv::VoidMessage reply_;
std::shared_ptr<GRPCVariableResponse> request_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
......
......@@ -116,24 +116,44 @@ void ParameterSend<T>::operator()(const RpcContext &rpc_ctx,
row_offset += outs_dims[i][0];
}
}
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
if (rpc_ctx.use_send_handler) {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[i];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: " << NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncSendVar(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
} else {
for (size_t i = 0; i < rpc_ctx.splited_var_names.size(); i++) {
for (size_t j = 0; j < rpc_ctx.epmap.size(); j++) {
auto &send_var_name = rpc_ctx.splited_var_names[i];
VLOG(4) << "send var name: " << send_var_name;
auto &endpoint = rpc_ctx.epmap[j];
VLOG(4) << "send var endpoint: " << endpoint;
VLOG(4) << "need send: "
<< NeedSend(*local_scope.get(), send_var_name);
if (NeedSend(*local_scope.get(), send_var_name)) {
VLOG(3) << "sending " << send_var_name << " to " << endpoint;
rets.push_back(rpc_client->AsyncDistributeNotify(
endpoint, cpu_ctx, *local_scope.get(), send_var_name));
VLOG(4) << "send var " << send_var_name << " async handle done";
} else {
VLOG(3) << "don't send non-initialized variable: "
<< rpc_ctx.splited_var_names[i];
}
}
}
}
} else if (send_var->IsType<framework::SelectedRows>()) {
auto &send_slr = send_var->Get<framework::SelectedRows>();
auto abs_sections = ToAbsoluteSection(rpc_ctx.height_sections);
......
......@@ -63,7 +63,7 @@ constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
......@@ -262,11 +262,25 @@ bool RequestNotifyHandler::Handle(const std::string& varname,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestNotifyHandler" << varname;
if (varname == LEARNING_RATE_DECAY_MESSAGE) {
VLOG(4) << "RequestNotifyHandler: " << varname;
VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;
string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, decay_piece)) {
VLOG(3) << "LearningRate Decay Counter Update";
PADDLE_ENFORCE_NE(
lr_decay_block_id, -1,
"when lr_decay_block_id = -1, there should be no RPC invoke.");
auto* origin_var = scope_->FindVar(varname);
auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
auto* send_var = scope->FindVar(varname);
auto send_var_tensor = send_var->Get<framework::LoDTensor>();
int64_t* origin_value =
origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
int64_t* send_value =
send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
origin_value[0] += send_value[0];
executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
}
return true;
......
......@@ -81,7 +81,8 @@ class RPCClient {
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncDistributeNotify(
const std::string& ep, const std::string& type,
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete(
......
......@@ -27,12 +27,15 @@ struct RpcContext {
RpcContext(const std::string &name, const std::vector<std::string> &names,
const std::vector<std::string> &emap,
const std::vector<int64_t> &sections, int id)
const std::vector<int64_t> &sections, int id,
bool merge_add_ = true, bool use_send_handler_ = true)
: var_name(name),
splited_var_names(names),
epmap(emap),
height_sections(sections),
trainer_id(id) {}
trainer_id(id),
merge_add(merge_add_),
use_send_handler(use_send_handler_) {}
RpcContext(const RpcContext &ctx) {
var_name = ctx.var_name;
......@@ -40,6 +43,8 @@ struct RpcContext {
epmap = ctx.epmap;
height_sections = ctx.height_sections;
trainer_id = ctx.trainer_id;
merge_add = ctx.merge_add;
use_send_handler = ctx.use_send_handler;
}
std::string var_name;
......@@ -47,6 +52,8 @@ struct RpcContext {
std::vector<std::string> epmap;
std::vector<int64_t> height_sections;
int trainer_id;
bool merge_add;
bool use_send_handler;
};
inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
......@@ -70,6 +77,9 @@ inline std::ostream &operator<<(std::ostream &os, const RpcContext &rpc_ctx) {
os << section << ", ";
}
os << "]\n";
os << "merge add: " << rpc_ctx.merge_add;
os << "; send handler: " << rpc_ctx.use_send_handler << "\n";
os << "}";
return os;
}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
// Operator that issues one AsyncDistributeNotify call per endpoint listed in
// the "epmap" attribute, passing the "type" attribute as the payload, and then
// enforces that RPCClient::Wait() returned true.
class DistributedNotifyOp : public framework::OperatorBase {
 public:
  DistributedNotifyOp(const std::string& type,
                      const framework::VariableNameMap& inputs,
                      const framework::VariableNameMap& outputs,
                      const framework::AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  // Neither `scope` nor `place` is read here; all inputs come from the
  // operator attributes "epmap", "type" and "trainer_id".
  void RunImpl(const framework::Scope& scope,
               const platform::Place& place) const override {
    const auto endpoints = Attr<std::vector<std::string>>("epmap");
    const auto notify_type = Attr<std::string>("type");
    const int trainer = Attr<int>("trainer_id");

    distributed::RPCClient* client =
        distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer);

    // The returned handles are not kept; completion is checked once, via
    // Wait(), after every endpoint has been contacted.
    for (const auto& endpoint : endpoints) {
      client->AsyncDistributeNotify(endpoint, notify_type);
      VLOG(4) << "distribute notify sending : " << notify_type << " to "
              << endpoint;
    }

    PADDLE_ENFORCE_EQ(client->Wait(), true, "internal error in RPCClient");
  }
};
// Proto/checker maker for the distributed_notify op: declares its three
// attributes ("epmap", "type", "trainer_id") and its doc comment.
class DistributedNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
// Endpoints to notify; defaults to the single address 127.0.0.1:6164.
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
// NOTE(review): the help text mentions "default ''", but no SetDefault is
// chained on this attribute here - confirm whether callers must always
// supply "type" explicitly.
AddAttr<std::string>("type",
"(string, default '') indicate the action type");
// Trainer id, defaulting to 0.
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
// The raw string below is the op's doc comment; its text is runtime data,
// so no comments may be placed inside it.
AddComment(R"DOC(
DistributeNotify operator
This operator will send a signal to listen_and_serve op at
the parameter server.
)DOC");
}
};
// Shape inference for distributed_notify. The call operator has an empty
// body, so no shape is inferred or checked.
class DistributedNotifyOpShapeInference : public framework::InferShapeBase {
public:
// Intentionally empty; `ctx` is not used.
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macro below.
namespace ops = paddle::operators;
// Registers the op under the name "distributed_notify" together with its
// attribute maker and shape-inference functor. Both grad-op makers passed here
// are the EmptyGradOpMaker variants (for OpDesc and for imperative OpBase).
REGISTER_OPERATOR(
distributed_notify, ops::DistributedNotifyOp,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::DistributedNotifyOpMaker, ops::DistributedNotifyOpShapeInference);
......@@ -383,7 +383,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), 1);
request_notify_handler_.get(),
FLAGS_rpc_send_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......
......@@ -45,6 +45,7 @@ class SendOp : public framework::OperatorBase {
auto send_varnames = Attr<std::vector<std::string>>("send_varnames");
auto height_sections = Attr<std::vector<int64_t>>("sections");
auto use_send_handler = Attr<bool>("use_send_handler");
if (send_varnames.size() > 0) {
if (ins.size() > 1) {
......@@ -62,13 +63,27 @@ class SendOp : public framework::OperatorBase {
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
if (use_send_handler) {
for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
rets.push_back(
rpc_client->AsyncSendVar(epmap[i], ctx, scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
} else {
for (size_t i = 0; i < ins.size(); i++) {
for (size_t j = 0; j < epmap.size(); j++) {
if (NeedSend(scope, ins[i])) {
VLOG(3) << "sending " << ins[i] << " to " << epmap[j];
rets.push_back(rpc_client->AsyncDistributeNotify(epmap[j], ctx,
scope, ins[i]));
} else {
VLOG(3) << "don't send no-initialied variable: " << ins[i];
}
}
}
}
for (size_t i = 0; i < rets.size(); i++) {
......@@ -113,6 +128,15 @@ This operator will send variables to listen_and_serve op at the parameter server
"Number of sub-tensors. This must evenly divide "
"Input.dims()[axis]")
.SetDefault(0);
AddAttr<bool>("merge_add",
"(bool, default 0)"
"merge method, true represent add, false represent average")
.SetDefault(false);
AddAttr<bool>(
"use_send_handler",
"(bool, default 1)"
"if it's true, use send handler, other wise, use notify handler")
.SetDefault(true);
}
};
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import gc
import paddle.fluid as fluid
class TranspilerAsyncLRDecayTest(unittest.TestCase):
    """Checks the op sequences DistributeTranspiler produces for a small
    network trained with SGD and an exponentially decayed learning rate,
    transpiled with sync_mode=False for two trainers and two pservers."""

    def setUp(self):
        """Sets the cluster layout used by every helper below."""
        self.trainer_id = 0
        self.trainers = 2
        self.pservers = 2
        # NOTE: we do not actually bind this port
        self.pserver1_ep = "127.0.0.1:6174"
        self.pserver2_ep = "127.0.0.1:6175"
        self.pserver_eps = ",".join([self.pserver1_ep, self.pserver2_ep])
        self.sync_mode = False
        self.transpiler = None

    @staticmethod
    def _op_types(block):
        """Returns the list of op type names of `block`, in order."""
        return [op.type for op in block.ops]

    def net_conf(self):
        """Builds the network in the current default program: one fc layer,
        squared-error loss, and SGD with exponential learning-rate decay."""
        features = fluid.layers.data(name='x', shape=[1000], dtype='float32')
        prediction = fluid.layers.fc(
            input=features,
            size=1000,
            act=None,
            param_attr=fluid.ParamAttr(name='fc_w'),
            bias_attr=fluid.ParamAttr(name='fc_b'))
        label = fluid.layers.data(name='y', shape=[1], dtype='float32')
        loss = fluid.layers.square_error_cost(input=prediction, label=label)
        mean_loss = fluid.layers.mean(loss)
        decayed_lr = fluid.layers.exponential_decay(
            learning_rate=0.1, decay_steps=100, decay_rate=0.99, staircase=True)
        optimizer = fluid.optimizer.SGD(learning_rate=decayed_lr)
        optimizer.minimize(mean_loss)

    def get_main_program(self):
        """Creates a fresh Program with seed 1, builds the network into it,
        stores a clone in self.origin_prog and returns the program."""
        program = fluid.Program()
        program.random_seed = 1
        with fluid.program_guard(program):
            self.net_conf()
        self.origin_prog = program.clone()
        return program

    def get_trainer(self, config=None):
        """Returns (trainer main program, default startup program)."""
        startup_before = fluid.default_startup_program().clone()
        transpiler = self._transpiler_instance(config)
        trainer_main = transpiler.get_trainer_program(wait_port=False)
        trainer_startup = fluid.default_startup_program()

        assert (startup_before.num_blocks == 1)
        assert (trainer_startup.num_blocks == startup_before.num_blocks)

        return trainer_main, trainer_startup

    def get_pserver(self, ep, config=None, sync_mode=True):
        """Returns (pserver program, pserver startup program) for endpoint
        `ep`."""
        transpiler = self._transpiler_instance(config, sync_mode)
        pserver_prog = transpiler.get_pserver_program(ep)
        startup_prog = transpiler.get_startup_program(ep, pserver_prog)
        return pserver_prog, startup_prog

    def _transpiler_instance(self, config=None, sync_mode=True):
        """Returns the cached transpiler, creating and running it on first
        use. `config` and `sync_mode` only matter for that first call."""
        if self.transpiler:
            return self.transpiler

        main_program = self.get_main_program()
        self.transpiler = fluid.DistributeTranspiler(config=config)
        self.transpiler.transpile(
            self.trainer_id,
            program=main_program,
            pservers=self.pserver_eps,
            trainers=self.trainers,
            sync_mode=sync_mode)
        return self.transpiler

    def transpiler_test_impl(self):
        """Transpiles in async mode and compares the resulting op lists."""
        pserver, startup = self.get_pserver(self.pserver1_ep, sync_mode=False)
        # The second pserver's programs are generated but not inspected.
        _pserver2, _startup2 = self.get_pserver(
            self.pserver2_ep, sync_mode=False)
        trainer, trainer_startup = self.get_trainer()

        # trainer startup program
        actual_startup_ops = self._op_types(trainer_startup.global_block())
        expected_startup_ops = ['fill_constant'] * 4 + [
            'uniform_random', 'recv', 'recv', 'fetch_barrier', 'concat'
        ]
        self.assertEqual(actual_startup_ops, expected_startup_ops)

        # trainer main program
        self.assertEqual(
            self._op_types(trainer.global_block()), [
                'mul', 'elementwise_add', 'elementwise_sub', 'square', 'mean',
                'fill_constant', 'mean_grad', 'square_grad',
                'elementwise_sub_grad', 'elementwise_add_grad', 'send',
                'mul_grad', 'split_byref', 'send', 'send', 'recv', 'recv',
                'concat'
            ])

        self.assertEqual(len(pserver.blocks), 4)
        # block 0: the serving op
        self.assertEqual(
            self._op_types(pserver.blocks[0]), ["listen_and_serv"])
        # block 1: sum of the trainer counters followed by the decay ops
        self.assertEqual(
            self._op_types(pserver.blocks[1]), [
                "sum", "cast", "scale", "floor", "fill_constant",
                "elementwise_pow", "scale"
            ])
        # block 2: optimizer op
        self.assertEqual(self._op_types(pserver.blocks[2]), ["sgd"])

        # pserver startup program
        self.assertEqual(
            self._op_types(startup.global_block()),
            ["fill_constant"] * 4 + ["uniform_random"])

    def test_transpiler(self):
        """Runs the checks inside fresh main/startup programs and a fresh
        unique-name scope, then drops the created objects."""
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.unique_name.guard(), fluid.program_guard(main, startup):
            self.transpiler_test_impl()
        # NOTE: run gc.collect to eliminate pybind side objects to
        # prevent random double-deallocate when inherited in python.
        del self.transpiler
        del main
        del startup
        gc.collect()
if __name__ == "__main__":
unittest.main()
......@@ -41,7 +41,7 @@ import logging
import numpy as np
from .ps_dispatcher import RoundRobin, PSDispatcher
from .. import core, framework, unique_name
from .. import core, framework, unique_name, initializer
from ..framework import Program, default_main_program, \
default_startup_program, Block, Parameter, grad_var_name
from .details import wait_server_ready, UnionFind, VarStruct, VarsDistributed
......@@ -304,6 +304,7 @@ class DistributeTranspiler(object):
PRINT_LOG = True
assert (self.config.min_block_size >= 8192)
assert (self.config.split_method.__bases__[0] == PSDispatcher)
self.counter_var = None
def _transpile_nccl2(self,
trainer_id,
......@@ -631,6 +632,7 @@ class DistributeTranspiler(object):
np.random.shuffle(grad_var_mapping_items)
self.grad_name_to_send_dummy_out = dict()
for grad_varname, splited_vars in grad_var_mapping_items:
eplist = ps_dispatcher.dispatch(splited_vars)
......@@ -720,6 +722,31 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
fetch_barrier_input.append(send_barrier_out)
else:
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0 and self.counter_var:
decay_dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
if self.config.runtime_split_send_recv:
## async mode, using communicator to merge and send
send_varnames = [self.counter_var.name]
else:
send_varnames = []
sections = []
program.global_block().append_op(
type="send",
inputs={"X": self.counter_var},
outputs={"Out": decay_dummy_output},
attrs={
"epmap": pserver_endpoints,
"sections": sections,
"send_varnames": send_varnames,
"merge_add": True,
"use_send_handler": False,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
OP_ROLE_VAR_ATTR_NAME:
[self.counter_var.name, self.counter_var.name]
})
# step 3: insert recv op to receive parameters from parameter server
recv_vars = []
......@@ -821,19 +848,6 @@ class DistributeTranspiler(object):
RPC_OP_ROLE_ATTR_NAME: DIST_OP_ROLE_ATTR_VALUE
})
if not self.sync_mode:
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:
program.global_block().append_op(
type="distributed_notify",
inputs={},
outputs={},
attrs={
"epmap": pserver_endpoints,
"trainer_id": self.trainer_id,
"type": "LRDECAY@RECV"
})
self._get_trainer_startup_program(recv_vars=recv_vars, eplist=eplist)
if self.has_distributed_lookup_table:
......@@ -2380,11 +2394,57 @@ class DistributeTranspiler(object):
def _get_lr_ops(self):
lr_ops = []
block = self.origin_program.global_block()
for op in block.ops:
for index, op in enumerate(block.ops):
role_id = int(op.attr(RPC_OP_ROLE_ATTR_NAME))
if role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) or \
role_id == int(LR_SCHED_OP_ROLE_ATTR_VALUE) | \
int(OPT_OP_ROLE_ATTR_VALUE):
if self.sync_mode == False and op.type == 'increment':
inputs = self._get_input_map_from_op(
self.origin_program.global_block().vars, op)
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op)
for key in outputs:
counter_var = outputs[key]
all_trainer_counter_inputs = [
self.origin_program.global_block().create_var(
name="%s.trainer_%d" % (counter_var.name, id_),
type=counter_var.type,
shape=counter_var.shape,
dtype=counter_var.dtype,
persistable=counter_var.persistable)
for id_ in range(self.trainer_num)
]
for i, op in enumerate(self.startup_program.global_block()
.ops):
if op.type == 'fill_constant':
for key in op.output_names:
if len(op.output(key)) == 1 and op.output(key)[
0] == counter_var.name:
self.startup_program.global_block().ops[
i]._set_attr(
'value',
float(0.0 - self.trainer_num))
for var in all_trainer_counter_inputs:
if var.name == "%s.trainer_%d" % (counter_var.name,
self.trainer_id):
self.counter_var = var
self.startup_program.global_block().create_var(
name=var.name,
type=var.type,
dtype=var.dtype,
shape=var.shape,
persistable=var.persistable,
initializer=initializer.Constant(1))
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName(
)
block._remove_op(index)
op = block._insert_op(
index,
type='sum',
inputs={'X': all_trainer_counter_inputs},
outputs=outputs,
attrs={op_role_attr_name: LR_SCHED_OP_ROLE_ATTR_VALUE})
lr_ops.append(op)
log("append lr op: ", op.type)
return lr_ops
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册