Unverified · Commit 306236c2 authored by Wu Yi, committed by GitHub

feature/DC asgd (#12722)

* wip

* add ref_by_trainer_id op

* ready to test

* fix ref inputs

* refine rpc_op_handle

* fix merge bug
Parent: c3cbf0b8
@@ -29,22 +29,19 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
       place_(place) {}
 
 void RPCOpHandle::RunImpl() {
-  // TODO(wuyi): need further analysis whether wait VarDummyHandle.
-  // Wait input done
   for (auto *in : inputs_) {
     auto &p = static_cast<VarHandle *>(in)->place_;
-    // FIXME(Yancey1989): need a better solution instead of use DebugString()
-    if (ir::IsControlDepVar(*in->Node())) {  // HACK
+    if (ir::IsControlDepVar(*in->Node())) {
       continue;
     }
     if (in->GeneratedOp()) {
       in->GeneratedOp()->RecordWaitEventOnCtx(dev_ctxes_.at(p));
     }
   }
-  auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
-  // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
-  // lock.
-  op_->Run(*tmp_scope, place_);
+  this->RunAndRecordEvent([this] {
+    op_->Run(*local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(),
+             place_);
+  });
 }
 
 std::string RPCOpHandle::Name() const { return name_; }
...
@@ -85,8 +85,10 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
 
 void Executor::Close() {
 #ifdef PADDLE_WITH_DISTRIBUTE
+  // TODO(typhoonzero): complete message will need to use real trainer_id,
+  // except 0.
   ::paddle::operators::distributed::RPCClient::GetInstance<
-      ::paddle::operators::distributed::GRPCClient>()
+      ::paddle::operators::distributed::GRPCClient>(0)
       ->SendComplete();
 #endif
 }
...
@@ -38,9 +38,10 @@ class CheckpointNotifyOp : public framework::OperatorBase {
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     std::string dir = Attr<std::string>("dir");
     std::string lookup_table_name = Attr<std::string>("lookup_table");
+    int trainer_id = Attr<int>("trainer_id");
 
     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
     for (size_t i = 0; i < epmap.size(); i++) {
       auto lookup_table_save_dir =
           string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
@@ -63,6 +64,7 @@ class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
         "dir", "(string, default '') indicate the folder checkpoint will use");
     AddAttr<std::string>("lookup_table",
                          "(string, default '') the lookup table name");
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddComment(R"DOC(
 CheckpointNotify operator
...
@@ -79,7 +79,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
     auto* var = p_scope->FindVar(var_name_val);
 
     ::grpc::ByteBuffer req;
-    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req);
+    SerializeToByteBuffer(var_name_val, var, *p_ctx, &req, "", trainer_id_);
 
     VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
@@ -105,7 +105,10 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
 void ProcGetResponse(const VarHandle& var_h,
                      const ::grpc::ByteBuffer& ret_msg) {
   framework::Variable* outvar = nullptr;
-  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar);
+  // get response's trainer_id is not used
+  int trainer_id;
+  DeserializeFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
+                            &trainer_id);
 }
 
 template <typename T>
@@ -135,6 +138,7 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
+    req.set_trainer_id(trainer_id_);
     ::grpc::ByteBuffer buf;
     RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
...
@@ -34,8 +34,8 @@ namespace distributed {
 
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
-                           ::grpc::ByteBuffer* msg,
-                           const std::string& out_name) {
+                           ::grpc::ByteBuffer* msg, const std::string& out_name,
+                           const int trainer_id) {
   platform::RecordRPCEvent record_event("serial", &ctx);
   // Default DestroyCallback does nothing, When using GPU
   // the CPU buffer need to be freed.
@@ -45,6 +45,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
   size_t payload_size;
 
   request.set_varname(name);
+  request.set_trainer_id(trainer_id);
   // Note: normally the profiler is enabled in 1 trainer, hence only
   // 1 trainer returns true for ShouldSendProfileState(). It tells PS
   // servers the trainer's profiling state so that PS can follow the
@@ -147,11 +148,12 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
 void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
                                const framework::Scope* scope,
-                               framework::Variable** var) {
+                               framework::Variable** var, int* trainer_id) {
   platform::RecordRPCEvent record_event("deserial", &ctx);
   operators::distributed::GRPCVariableResponse resp(scope, &ctx);
   PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
   *var = resp.GetVar();
+  *trainer_id = resp.GetTrainerId();
 }
 
 }  // namespace distributed
...
@@ -38,12 +38,13 @@ typedef void (*DestroyCallback)(void*);
 
 void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
                            const platform::DeviceContext& ctx,
                            ::grpc::ByteBuffer* msg,
-                           const std::string& out_varname = std::string());
+                           const std::string& out_varname = std::string(),
+                           const int trainer_id = 0);
 
 void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
                                const platform::DeviceContext& ctx,
                                const framework::Scope* scope,
-                               framework::Variable** var);
+                               framework::Variable** var, int* trainer_id);
 
 }  // namespace distributed
 }  // namespace operators
...
@@ -102,9 +102,10 @@ class RequestSend final : public RequestBase {
     auto scope = request_->GetMutableLocalScope();
     auto invar = request_->GetVar();
+    int trainer_id = request_->GetTrainerId();
     framework::Variable* outvar = nullptr;
 
-    request_handler_->Handle(varname, scope, invar, &outvar);
+    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
     Finish(reply_, &responder_);
   }
@@ -133,13 +134,14 @@ class RequestGet final : public RequestBase {
   void Process() override {
     // proc request.
     std::string varname = request_.varname();
+    int trainer_id = request_.trainer_id();
     VLOG(4) << "RequestGet " << varname;
 
     auto scope = request_handler_->scope();
     auto invar = scope->FindVar(varname);
     framework::Variable* outvar = nullptr;
 
-    request_handler_->Handle(varname, scope, invar, &outvar);
+    request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
 
     if (outvar) {
       SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
@@ -179,6 +181,7 @@ class RequestPrefetch final : public RequestBase {
     // prefetch process...
     std::string in_var_name = request_->Varname();
     std::string out_var_name = request_->OutVarname();
+    int trainer_id = request_->GetTrainerId();
     VLOG(4) << "RequestPrefetch, in_var_name: " << in_var_name
             << " out_var_name: " << out_var_name;
@@ -187,7 +190,8 @@ class RequestPrefetch final : public RequestBase {
     // out var must be created in local scope!
     framework::Variable* outvar = scope->Var(out_var_name);
 
-    request_handler_->Handle(in_var_name, scope, invar, &outvar, out_var_name);
+    request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
+                             out_var_name);
 
     SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
                           &reply_);
@@ -225,12 +229,13 @@ class RequestCheckpointNotify final : public RequestBase {
     std::string checkpoint_notify = request_->Varname();
     std::string checkpoint_dir = request_->OutVarname();
+    int trainer_id = request_->GetTrainerId();
 
     VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
             << ", dir: " << checkpoint_dir;
 
     request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
-                             checkpoint_dir);
+                             trainer_id, checkpoint_dir);
     Finish(reply_, &responder_);
   }
...
@@ -293,6 +293,14 @@ int GRPCVariableResponse::Parse(Source* source) {
       }
       break;
     }
+    case sendrecv::VariableMessage::kTrainerIdFieldNumber: {
+      uint64_t trainer_id = 0;
+      if (!input.ReadVarint64(&trainer_id)) {
+        return tag;
+      }
+      meta_.set_trainer_id(trainer_id);
+      break;
+    }
     default: {
       // Unknown tag, return unknown error.
       return -1;
...
@@ -190,6 +190,7 @@ class RequestHandler {
   //  }
   virtual bool Handle(const std::string& varname, framework::Scope* scope,
                       framework::Variable* var, framework::Variable** outvar,
+                      const int trainer_id,
                       const std::string& out_var_name = "") = 0;
 
  protected:
...
@@ -36,6 +36,7 @@ bool RequestSendHandler::Handle(const std::string& varname,
                                 framework::Scope* scope,
                                 framework::Variable* invar,
                                 framework::Variable** outvar,
+                                const int trainer_id,
                                 const std::string& out_var_name) {
   VLOG(4) << "RequestSendHandler:" << varname;
 
@@ -76,6 +77,7 @@ bool RequestGetHandler::Handle(const std::string& varname,
                                framework::Scope* scope,
                                framework::Variable* invar,
                                framework::Variable** outvar,
+                               const int trainer_id,
                                const std::string& out_var_name) {
   VLOG(4) << "RequestGetHandler:" << varname;
   if (sync_mode_) {
@@ -88,6 +90,19 @@ bool RequestGetHandler::Handle(const std::string& varname,
       }
     }
   } else {
     if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
+      if (enable_dc_asgd_) {
+        // NOTE: the format is determined by distributed_transpiler.py
+        std::string param_bak_name =
+            string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
+        VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
+        auto var = scope_->FindVar(varname);
+        auto t_orig = var->Get<framework::LoDTensor>();
+        auto param_bak = scope_->Var(param_bak_name);
+        auto t = param_bak->GetMutable<framework::LoDTensor>();
+        t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
+        VLOG(3) << "copying " << varname << " to " << param_bak_name;
+        framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
+      }
       *outvar = scope_->FindVar(varname);
     }
   }
@@ -98,6 +113,7 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
                                     framework::Scope* scope,
                                     framework::Variable* invar,
                                     framework::Variable** outvar,
+                                    const int trainer_id,
                                     const std::string& out_var_name) {
   VLOG(4) << "RequestPrefetchHandler " << varname;
 
@@ -113,6 +129,7 @@ bool RequestCheckpointHandler::Handle(const std::string& varname,
                                       framework::Scope* scope,
                                       framework::Variable* invar,
                                       framework::Variable** outvar,
+                                      const int trainer_id,
                                       const std::string& out_var_name) {
   PADDLE_ENFORCE(
       checkpoint_notify_id != -1,
...
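For orientation, the new enable_dc_asgd_ branch in RequestGetHandler::Handle snapshots the parameter into a per-trainer backup variable right before the parameter is returned to an async trainer. A minimal standalone sketch of that behaviour, using plain C++ containers as stand-ins for Paddle's Scope/LoDTensor (the "<param>.trainer_<id>_bak" naming mirrors what distribute_transpiler.py generates further below):

#include <string>
#include <unordered_map>
#include <vector>

using Tensor = std::vector<float>;                      // stand-in for framework::LoDTensor
using Scope = std::unordered_map<std::string, Tensor>;  // stand-in for framework::Scope

// Before parameter `varname` is served to trainer `trainer_id`, copy its
// current value into "<varname>.trainer_<id>_bak" so the pserver can later
// measure how far the parameter drifted while the trainer computed its
// (stale) gradient.
const Tensor& HandleGet(Scope* scope, const std::string& varname,
                        int trainer_id, bool enable_dc_asgd) {
  if (enable_dc_asgd) {
    const std::string param_bak_name =
        varname + ".trainer_" + std::to_string(trainer_id) + "_bak";
    (*scope)[param_bak_name] = (*scope)[varname];  // snapshot current value
  }
  return (*scope)[varname];  // the value actually sent back to the trainer
}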
@@ -36,20 +36,34 @@ namespace distributed {
 
 class RequestSendHandler final : public RequestHandler {
  public:
-  explicit RequestSendHandler(bool sync_mode) : RequestHandler(sync_mode) {}
+  explicit RequestSendHandler(bool sync_mode, bool enable_dc_asgd = false)
+      : RequestHandler(sync_mode) {
+    enable_dc_asgd_ = enable_dc_asgd;
+  }
   virtual ~RequestSendHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;
+
+ private:
+  bool enable_dc_asgd_;
 };
 
 class RequestGetHandler final : public RequestHandler {
  public:
-  explicit RequestGetHandler(bool sync_mode) : RequestHandler(sync_mode) {}
+  explicit RequestGetHandler(bool sync_mode, bool enable_dc_asgd = false)
+      : RequestHandler(sync_mode) {
+    enable_dc_asgd_ = enable_dc_asgd;
+  }
   virtual ~RequestGetHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;
+
+ private:
+  bool enable_dc_asgd_;
 };
 
 class RequestPrefetchHandler final : public RequestHandler {
@@ -58,6 +72,7 @@ class RequestPrefetchHandler final : public RequestHandler {
   virtual ~RequestPrefetchHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;
 };
 
@@ -70,6 +85,7 @@ class RequestCheckpointHandler final : public RequestHandler {
   virtual ~RequestCheckpointHandler() {}
   bool Handle(const std::string& varname, framework::Scope* scope,
               framework::Variable* var, framework::Variable** outvar,
+              const int trainer_id,
               const std::string& out_var_name = "") override;
 
  private:
...
@@ -24,6 +24,7 @@ namespace distributed {
 
 std::once_flag RPCClient::init_flag_;
 std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
+int RPCClient::trainer_id_ = 0;
 
 }  // namespace distributed
 }  // namespace operators
...
@@ -72,14 +72,15 @@ class RPCClient {
   virtual bool Wait() = 0;
 
   template <typename T>
-  static RPCClient* GetInstance() {
-    std::call_once(init_flag_, &RPCClient::Init<T>);
+  static RPCClient* GetInstance(int trainer_id) {
+    std::call_once(init_flag_, &RPCClient::Init<T>, trainer_id);
     return rpc_client_.get();
   }
 
   // Init is called by GetInstance.
   template <typename T>
-  static void Init() {
+  static void Init(int trainer_id) {
+    trainer_id_ = trainer_id;
     if (rpc_client_.get() == nullptr) {
       rpc_client_.reset(new T());
       rpc_client_->InitImpl();
@@ -88,6 +89,8 @@ class RPCClient {
 
  protected:
   virtual void InitImpl() {}
+  // each trainer have exact one trainer id, it should be static
+  static int trainer_id_;
 
  private:
   static std::once_flag init_flag_;
...
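The GetInstance change above threads the trainer id into the client singleton through std::call_once, so whichever operator first asks for an RPCClient fixes the process-wide trainer_id_ before the client object is built. A minimal standalone sketch of that pattern (generic names, not the actual Paddle classes):

#include <memory>
#include <mutex>

class Client {
 public:
  // First caller wins: its trainer_id is recorded and the singleton is built
  // exactly once; later calls just return the existing instance.
  static Client* GetInstance(int trainer_id) {
    std::call_once(init_flag_, &Client::Init, trainer_id);
    return instance_.get();
  }
  static int trainer_id() { return trainer_id_; }

 private:
  static void Init(int trainer_id) {
    trainer_id_ = trainer_id;
    instance_.reset(new Client());
  }
  static std::once_flag init_flag_;
  static std::unique_ptr<Client> instance_;
  static int trainer_id_;
};

std::once_flag Client::init_flag_;
std::unique_ptr<Client> Client::instance_;
int Client::trainer_id_ = 0;

// usage: Client::GetInstance(/*trainer_id=*/3);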
@@ -125,7 +125,7 @@ TEST(PREFETCH, CPU) {
   g_req_handler.reset(new distributed::RequestPrefetchHandler(true));
   g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
   distributed::RPCClient* client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
 
   std::thread server_thread(StartServer, distributed::kRequestPrefetch);
   g_rpc_service->WaitServerReady();
@@ -165,7 +165,7 @@ TEST(COMPLETE, CPU) {
   g_req_handler.reset(new distributed::RequestSendHandler(true));
   g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 2));
   distributed::RPCClient* client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
   PADDLE_ENFORCE(client != nullptr);
   std::thread server_thread(StartServer, distributed::kRequestSend);
   g_rpc_service->WaitServerReady();
...
@@ -79,6 +79,7 @@ message VariableMessage {
   // server stops profiling and generates a profile to /tmp/profile_ps_*
   // when profile switches from 1 to 2.
   int64 profile = 11;
+  int64 trainer_id = 12;
 }
 
 message VoidMessage {}
@@ -92,6 +92,8 @@ class VariableResponse {
     return scope_->FindVar(meta_.varname());
   }
 
+  int GetTrainerId() { return static_cast<int>(meta_.trainer_id()); }
+
  protected:
   bool ReadRaw(::google::protobuf::io::CodedInputStream* input,
                const platform::DeviceContext& dev_ctx, platform::Place place,
...
@@ -37,7 +37,8 @@ class FetchBarrierOp : public framework::OperatorBase {
                const platform::Place& place) const override {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));
 
     PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
@@ -61,6 +62,7 @@ This operator will send a send barrier signal to list_and_serv op, so that
 the Parameter Server would knew all variables have been sent.
 )DOC");
 
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<std::string>>("endpoints",
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints to send variables to.")
...
@@ -61,7 +61,7 @@ class GenNCCLIdOp : public framework::OperatorBase {
     std::vector<std::string> endpoint_list =
         Attr<std::vector<std::string>>("endpoint_list");
     distributed::RPCClient* client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
 
     for (auto& ep : endpoint_list) {
       VLOG(3) << "sending nccl id to " << ep;
...
@@ -218,23 +218,26 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                    framework::ProgramDesc *program,
                                    framework::Scope *recv_scope) const {
   VLOG(2) << "RunAsyncLoop";
-  // grad name to block id
-  std::unordered_map<std::string, int32_t> grad_to_block_id;
-  std::unordered_map<int32_t, std::string> id_to_grad;
-
   auto grad_to_block_id_str =
       Attr<std::vector<std::string>>("grad_to_block_id");
-  for (const auto &grad_and_id : grad_to_block_id_str) {
+  DoubleFindMap<std::string, int32_t> grad_to_block_id;
+
+  auto append_block_maps = [](DoubleFindMap<std::string, int32_t> *out_map,
+                              const std::string &grad_and_id) {
     std::vector<std::string> pieces;
     split(grad_and_id, ':', &pieces);
-    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
+    VLOG(3) << "after split, key = " << pieces[0] << ", id=" << pieces[1];
     PADDLE_ENFORCE_EQ(pieces.size(), 2);
-    PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
+    PADDLE_ENFORCE_EQ(out_map->count(pieces[0]), 0);
     int block_id = std::stoi(pieces[1]);
-    grad_to_block_id[pieces[0]] = block_id;
-    id_to_grad[block_id] = pieces[0];
+    (*out_map)[pieces[0]] = block_id;
+  };
+
+  for (const auto &grad_and_id : grad_to_block_id_str) {
+    append_block_maps(&grad_to_block_id, grad_and_id);
   }
 
   size_t num_blocks = program->Size();
   PADDLE_ENFORCE_GE(num_blocks, 2,
                     "server program should have at least 2 blocks");
@@ -244,15 +247,22 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
     block_list.push_back(blkid);
   }
   auto optimize_prepared = executor->Prepare(*program, block_list);
-  // execute global block if needed
-  if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
+  // execute global block if needed, block id 1 in the program is global
+  // block if it's not bind to a grad var for it's update.
+  if (block_list[0] == 1 &&
+      grad_to_block_id.find_value(static_cast<int32_t>(1)) ==
+          grad_to_block_id.end()) {
     executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
   }
   std::unordered_map<std::string,
                      std::shared_ptr<framework::ExecutorPrepareContext>>
-      grad_to_prepared_ctx;
+      grad_to_prepared_ctx, param_to_prepared_ctx;
   for (size_t i = 0; i < block_list.size(); ++i) {
-    grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
+    auto blkid = block_list[i];
+    auto it = grad_to_block_id.find_value(blkid);
+    if (it != grad_to_block_id.end()) {
+      grad_to_prepared_ctx[it->first] = optimize_prepared[i];
+    }
   }
 
   request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
@@ -315,6 +325,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   framework::Scope &recv_scope = scope.NewScope();
 
   bool sync_mode = Attr<bool>("sync_mode");
+  bool dc_sgd = Attr<bool>("dc_asgd");
   auto fan_in = Attr<int>("Fanin");
 
   auto inputs = Inputs("X");
@@ -328,8 +339,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
   rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
 
-  request_send_handler_.reset(new distributed::RequestSendHandler(sync_mode));
-  request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
+  request_send_handler_.reset(
+      new distributed::RequestSendHandler(sync_mode, dc_sgd));
+  request_get_handler_.reset(
+      new distributed::RequestGetHandler(sync_mode, dc_sgd));
   request_prefetch_handler_.reset(
       new distributed::RequestPrefetchHandler(sync_mode));
   request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
@@ -443,6 +456,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
                   "a map from grad name to it's optimize block id")
         .SetDefault({});
     AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
+    AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.")
+        .SetDefault(false);
     AddAttr<std::vector<framework::BlockDesc *>>(
         kOptimizeBlocks, "Optimize blocks to run on server side.")
         .SetDefault({});
...
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <atomic>
 #include <set>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "paddle/fluid/framework/executor.h"
@@ -37,6 +38,17 @@ constexpr char kCheckpointBlockId[] = "checkpint_block_id";
 
 void RunServer(std::shared_ptr<distributed::RPCServer> service);
 
+template <class TKey, class TValue>
+class DoubleFindMap : public std::unordered_map<TKey, TValue> {
+ public:
+  typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
+    return std::find_if(this->begin(), this->end(),
+                        [&v](const std::pair<const std::string, int> p) {
+                          return p.second == v;
+                        });
+  }
+};
+
 class ListenAndServOp : public framework::OperatorBase {
  public:
   ListenAndServOp(const std::string& type,
...
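DoubleFindMap is what lets RunAsyncLoop above drop the separate id_to_grad reverse map: find_value() linearly scans the map and returns the entry whose mapped block id equals the query. A small standalone usage sketch with the same instantiation the op uses (<std::string, int32_t>):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> {
 public:
  // reverse lookup: value -> iterator to the (key, value) entry holding it
  typename std::unordered_map<TKey, TValue>::iterator find_value(TValue v) {
    return std::find_if(this->begin(), this->end(),
                        [&v](const std::pair<const std::string, int> p) {
                          return p.second == v;
                        });
  }
};

int main() {
  DoubleFindMap<std::string, int32_t> grad_to_block_id;
  grad_to_block_id["w@GRAD"] = 2;  // forward: grad name -> optimize block id
  grad_to_block_id["b@GRAD"] = 3;
  auto it = grad_to_block_id.find_value(2);  // reverse: block id -> grad name
  if (it != grad_to_block_id.end()) {
    std::cout << it->first << std::endl;  // prints "w@GRAD"
  }
  return 0;
}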
@@ -42,7 +42,8 @@ class PrefetchOp : public framework::OperatorBase {
     auto& ctx = *pool.Get(place);
 
     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));
 
     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < ins.size(); i++) {
@@ -69,6 +70,7 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
              "(LoDTensor) result "
              "to be fetched from parameter server")
         .AsDuplicable();
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<std::string>>(
         "epmap",
         "(string vector, default 127.0.0.1:6164)"
...
@@ -42,7 +42,8 @@ class RecvOp : public framework::OperatorBase {
     auto& ctx = *pool.Get(place);
 
     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));
 
     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < outs.size(); i++) {
@@ -73,6 +74,7 @@ This operator can get variables from server side.
                                       "Server endpoints in the order of input "
                                       "variables for mapping")
         .SetDefault({});
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
    AddAttr<int>("sync_mode",
                 "(int, default 0)"
                 "sync recv or async recv.")
...
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
#include <string>
namespace paddle {
namespace operators {
class RefByTrainerIdOp : public framework::OperatorWithKernel {
public:
RefByTrainerIdOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"),
"Input(X) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("TrainerId"),
"Input(TrainerId) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RefByTrainerIdOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("TrainerId").size(), 1,
"TrainerId should be a scalar.");
// Out's shape is determined at runtime.
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(
ctx.MultiInput<framework::Tensor>("X")[0]->type()),
ctx.GetPlace());
}
};
class RefByTrainerIdOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor) Input tensor list.").AsDuplicable();
AddInput("TrainerId", "(Tensor) Scalar int, the trainer id runtime value.");
AddOutput("Out", "(Tensor) Return one tensor reference of X[trainer_id]");
AddComment(R"DOC(
**RefByTrainerId operator**
Return a reference of a tensor, using trainer_id as the index to find from the input.
$$Out = X[TrainerId]$$
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(ref_by_trainer_id, ops::RefByTrainerIdOp,
ops::RefByTrainerIdOpMaker);
REGISTER_OP_CPU_KERNEL(
ref_by_trainer_id,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, float>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, double>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int>,
ops::RefByTrainerIdKernel<paddle::platform::CPUDeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/ref_by_trainer_id_op.h"
REGISTER_OP_CUDA_KERNEL(
ref_by_trainer_id,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
float>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
double>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int>,
paddle::operators::RefByTrainerIdKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <stdio.h>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class RefByTrainerIdKernel : public framework::OpKernel<T> {
public:
virtual void Compute(const framework::ExecutionContext& context) const {
auto* out = context.Output<framework::Tensor>("Out");
auto in_list = context.MultiInput<framework::Tensor>("X");
auto* trainer_id_t = context.Input<framework::Tensor>("TrainerId");
int64_t trainer_id;
auto* trainer_id_data = trainer_id_t->data<int64_t>();
if (platform::is_gpu_place(context.GetPlace())) {
#ifdef PADDLE_WITH_CUDA
auto stream = context.cuda_device_context().stream();
memory::Copy<>(platform::CPUPlace(), &trainer_id,
boost::get<platform::CUDAPlace>(context.GetPlace()),
trainer_id_data, sizeof(int64_t), stream);
#endif
} else {
trainer_id = *trainer_id_data;
}
printf("after get trainer_id %lu\n", trainer_id);
PADDLE_ENFORCE_LT(trainer_id, in_list.size());
out->mutable_data<T>(context.GetPlace());
out->ShareDataWith(*(in_list[trainer_id]));
}
};
} // namespace operators
} // namespace paddle
@@ -39,7 +39,8 @@ class SendBarrierOp : public framework::OperatorBase {
     std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
 
     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));
 
     VLOG(3) << "SendBarrierOp sync";
@@ -67,6 +68,7 @@ This operator will send a send barrier signal to list_and_serv op, so that
 the Parameter Server would knew all variables have been sent.
 )DOC");
 
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<std::string>>("endpoints",
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints to send variables to.")
...
@@ -44,7 +44,8 @@ class SendOp : public framework::OperatorBase {
     auto& ctx = *pool.Get(place);
 
     distributed::RPCClient* rpc_client =
-        distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+        distributed::RPCClient::GetInstance<RPCCLIENT_T>(
+            Attr<int>("trainer_id"));
 
     std::vector<distributed::VarHandlePtr> rets;
     for (size_t i = 0; i < ins.size(); i++) {
@@ -79,6 +80,7 @@ This operator will send variables to listen_and_serve op at the parameter server
                  "(int, default 0)"
                  "sync send or async send.")
         .SetDefault(0);
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
     AddAttr<std::vector<std::string>>("epmap",
                                       "(string vector, default 127.0.0.1:6164)"
                                       "Server endpoints in the order of input "
...
@@ -92,7 +92,7 @@ TEST(SendNcclId, RPCServer) {
   std::string ep = string::Sprintf("127.0.0.1:%d", port);
 
   distributed::RPCClient* client =
-      distributed::RPCClient::GetInstance<RPCCLIENT_T>();
+      distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
 
   LOG(INFO) << "connect to server" << ep;
   client->AsyncSendVar(ep, dev_ctx, scope, NCCL_ID_VARNAME);
...
@@ -37,10 +37,15 @@ class TestDistRunnerBase(object):
             "get_model should be implemented by child classes.")
 
     @staticmethod
-    def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers,
-                       sync_mode):
+    def get_transpiler(trainer_id,
+                       main_program,
+                       pserver_endpoints,
+                       trainers,
+                       sync_mode,
+                       dc_asgd=False):
         # NOTE: import fluid until runtime, or else forking processes will cause error.
         config = fluid.DistributeTranspilerConfig()
+        config.enable_dc_asgd = dc_asgd
         t = fluid.DistributeTranspiler(config=config)
         t.transpile(
             trainer_id=trainer_id,
@@ -55,7 +60,7 @@ class TestDistRunnerBase(object):
         # NOTE: pserver should not call memory optimize
         t = self.get_transpiler(args.trainer_id,
                                 fluid.default_main_program(), args.endpoints,
-                                args.trainers, args.sync_mode)
+                                args.trainers, args.sync_mode, args.dc_asgd)
         pserver_prog = t.get_pserver_program(args.current_endpoint)
         startup_prog = t.get_startup_program(args.current_endpoint,
                                              pserver_prog)
@@ -75,8 +80,7 @@ class TestDistRunnerBase(object):
             t = self.get_transpiler(args.trainer_id,
                                     fluid.default_main_program(),
                                     args.endpoints, args.trainers,
-                                    args.sync_mode)
+                                    args.sync_mode, args.dc_asgd)
             trainer_prog = t.get_trainer_program()
         else:
             trainer_prog = fluid.default_main_program()
@@ -155,6 +159,7 @@ def runtime_main(test_class):
     parser.add_argument('--mem_opt', action='store_true')
     parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--use_reduce', action='store_true')
+    parser.add_argument('--dc_asgd', action='store_true')
     parser.add_argument(
         '--use_reader_alloc', action='store_true', required=False)
     parser.add_argument('--batch_size', required=False, type=int, default=2)
@@ -200,6 +205,7 @@ class TestDistBase(unittest.TestCase):
         self._enforce_place = None
         self._mem_opt = False
         self._use_reduce = False
+        self._dc_asgd = False  # must use with async mode
         self._use_reader_alloc = True
         self._setup_config()
         self._after_setup_config()
...
@@ -53,6 +53,15 @@ class TestDistMnistAsync(TestDistBase):
         self.check_with_place("dist_mnist.py", delta=200)
 
+
+class TestDistMnistDcAsgd(TestDistBase):
+    def _setup_config(self):
+        self._sync_mode = False
+        self._dc_asgd = True
+
+    def test_se_resnext(self):
+        self.check_with_place("dist_mnist.py", delta=200)
+
 
 # FIXME(typhoonzero): enable these tests once we have 4
 # 4 GPUs on CI machine, and the base class should be updated.
 #
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestRefByTrainerIdOp(OpTest):
def setUp(self):
self.op_type = "ref_by_trainer_id"
param_baks = [("x%d" % x, np.random.random((10, 10)).astype("float32"))
for x in range(10)]
self.inputs = {
'X': param_baks,
'TrainerId': np.array([8]).astype("int64")
}
self.outputs = {'Out': param_baks[8][1]}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
@@ -38,7 +38,7 @@ import six
 import logging
 
 from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
-from .. import core, framework
+from .. import core, framework, unique_name
 from ..framework import Program, default_main_program, \
     default_startup_program, Block, \
     Parameter, grad_var_name
@@ -138,6 +138,7 @@ class DistributeTranspilerConfig(object):
     slice_var_up = True
     split_method = None
     min_block_size = 8192
+    enable_dc_asgd = False
     # supported modes: pserver, nccl2
     mode = "pserver"
     print_log = False
@@ -252,6 +253,8 @@ class DistributeTranspiler(object):
                 n workers, the id may range from 0 ~ n-1
             program (Program|None): program to transpile,
                 default is fluid.default_main_program().
+            startup_program (Program|None): startup_program to transpile,
+                default is fluid.default_startup_program().
             pservers (str): comma separated ip:port string for the pserver
                 list.
             trainers (int|str): in pserver mode this is the number of
@@ -383,6 +386,8 @@
             outputs={"Out": send_barrier_out},
             attrs={
                 "endpoints": pserver_endpoints,
+                "sync_mode": self.sync_mode,
+                "trainer_id": self.trainer_id,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })
@@ -426,6 +431,7 @@
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
+                    "trainer_id": self.trainer_id,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                     OP_ROLE_VAR_ATTR_NAME:
                     [param_varname, recv_op_role_var_name],
@@ -440,6 +446,7 @@
             outputs={"Out": all_recv_outputs},
             attrs={
                 "endpoints": pserver_endpoints,
+                "trainer_id": self.trainer_id,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })
@@ -651,6 +658,24 @@ in a single call.")
                                                      endpoint, op):
                 opt_op_on_pserver.append(op)
         # step 3.3
+        # prepare if dc asgd is enabled
+        if self.config.enable_dc_asgd == True:
+            assert (self.sync_mode == False)
+            self.param_bak_list = []
+            # add param_bak for each trainer
+            for p in self.param_grad_ep_mapping[endpoint]["params"]:
+                # each parameter should have w_bak for each trainer id
+                for i in range(self.trainer_num):
+                    param_bak_name = "%s.trainer_%d_bak" % (p.name, i)
+                    tmpvar = pserver_program.global_block().create_var(
+                        # NOTE: this var name format is used in `request_get_handler`
+                        name=param_bak_name,
+                        type=p.type,
+                        shape=p.shape,
+                        dtype=p.dtype)
+                    self.param_bak_list.append((p, tmpvar))
+        # step 3.4
         # Iterate through the ops, and if an op and the optimize ops
         # which located on current pserver are in one set, then
         # append it into the sub program.
@@ -741,7 +766,7 @@ in a single call.")
                                               grad_to_block_id, merged_var,
                                               lr_ops)
 
         # dedup grad to ids list
         grad_to_block_id = list(set(grad_to_block_id))
 
         # append global ops
         if global_ops:
@@ -787,6 +812,8 @@ in a single call.")
         if self.has_distributed_lookup_table:
             attrs['checkpint_block_id'] = checkpoint_block_id
+        if self.config.enable_dc_asgd:
+            attrs['dc_asgd'] = True
 
         if len(prefetch_var_name_to_block_id) > 0:
             attrs[
@@ -903,6 +930,15 @@ to transpile() call.")
                 inputs=new_inputs,
                 outputs=new_outputs,
                 attrs=op.all_attrs())
+        if self.config.enable_dc_asgd:
+            for p, p_bak in self.param_bak_list:
+                startup_param_var = s_prog.global_block().vars[p.name]
+                startup_tmpvar = s_prog.global_block().vars[p_bak.name]
+                # copy init random value to param_bak
+                s_prog.global_block().append_op(
+                    type="assign",
+                    inputs={"X": startup_param_var},
+                    outputs={"Out": startup_tmpvar})
+
         # add slice vars
         s_prog._slice_vars_and_attrs = self._get_slice_vars_and_attrs(endpoint)
@@ -1175,6 +1211,7 @@ to transpile() call.")
             attrs={
                 "sync_mode": not self.sync_mode,
                 "epmap": pserver_endpoints,
+                "trainer_id": self.trainer_id,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
                 OP_ROLE_VAR_ATTR_NAME: [
                     self.grad_name_to_param_name[table_grad_name],
@@ -1531,6 +1568,69 @@ to transpile() call.")
             attrs={"scale": 1.0 / float(self.trainer_num)})
         return merged_var
 
+    def _append_dc_asgd_ops(self, block, param_var, grad_var):
+        # NOTE: can not use grammar candy here, should put ops in specific block
+        local_param_bak = block.create_var(
+            name="%s.local_bak" % param_var.name,
+            shape=param_var.shape,
+            type=param_var.type,
+            dtype=param_var.dtype,
+            persistable=False)
+        # trainer_id_var is block local
+        trainer_id_var = block.create_var(
+            name="@TRAINER_ID@",
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            dtype=core.VarDesc.VarType.INT64,
+            shape=[1],
+            persistable=False)
+
+        # ref_inputs = [x[1] for x in self.param_bak_list]
+        ref_inputs = []
+        for p, p_bak in self.param_bak_list:
+            if p.name == param_var.name:
+                print("#### ref inputs: ", param_var.name, p_bak.name)
+                ref_inputs.append(p_bak)
+        block.append_op(
+            type="ref_by_trainer_id",
+            inputs={"X": ref_inputs,
+                    "TrainerId": trainer_id_var},
+            outputs={"Out": local_param_bak})
+
+        def __create_temp_var__():
+            return block.create_var(
+                name=unique_name.generate("tmp_dc_output"),
+                shape=param_var.shape,
+                type=param_var.type,
+                dtype=param_var.dtype,
+                persistable=False)
+
+        o1 = __create_temp_var__()
+        block.append_op(
+            type="elementwise_sub",
+            inputs={"X": param_var,
+                    "Y": local_param_bak},
+            outputs={"Out": o1})
+        o2 = __create_temp_var__()
+        block.append_op(
+            type="elementwise_mul",
+            inputs={"X": o1,
+                    "Y": grad_var},
+            outputs={"Out": o2})
+        o3 = __create_temp_var__()
+        block.append_op(
+            type="elementwise_mul",
+            inputs={"X": o2,
+                    "Y": grad_var},
+            outputs={"Out": o3})
+        # TODO(typhoonzero): append scale
+        o4 = __create_temp_var__()
+        block.append_op(
+            type="elementwise_add",
+            inputs={"X": grad_var,
+                    "Y": o3},
+            outputs={"Out": o4})
+        return o4
+
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
                             grad_to_block_id, origin_program, merged_var):
         program = optimize_block.program
@@ -1546,8 +1646,15 @@ to transpile() call.")
                     break
             return param_block
 
+        if self.config.enable_dc_asgd:
+            param_var = _get_param_block(opt_op)
+            dc = self._append_dc_asgd_ops(optimize_block, param_var, merged_var)
+
         for key in opt_op.input_names:
             if key == "Grad":
-                new_inputs[key] = merged_var
+                if self.config.enable_dc_asgd:
+                    new_inputs[key] = dc
+                else:
+                    new_inputs[key] = merged_var
             elif key == "Param":
                 param_block = _get_param_block(opt_op)
...
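For reference, the chain of elementwise ops appended by _append_dc_asgd_ops above computes the delay-compensated gradient of DC-ASGD. In the code's own names, o1 = param - local_param_bak, o2 = o1 * grad, o3 = o2 * grad and o4 = grad + o3, which is a sketch of the intended update (the scale factor is still a TODO in the code, so \lambda is effectively 1 here):

\tilde{g} \;=\; g + \lambda \, g \odot g \odot (w_{\mathrm{now}} - w_{\mathrm{bak}})

where g is the gradient received from a trainer (merged_var), w_now the parameter currently held by the pserver, w_bak the per-trainer backup saved in request_get_handler_impl at that trainer's last get, and \odot elementwise multiplication. The corrected gradient o4 then replaces the raw "Grad" input of the optimizer op in _append_pserver_ops.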