Unverified commit 7f2aa2db, authored by Chengmo, committed by GitHub

【paddle.fleet】Support Heter Parameter Server (#25998)

* Support Heter Parameter Server
Parent: ac63c7cd
@@ -56,7 +56,7 @@ endif()
cc_test(rpc_server_test SRCS rpc_server_test.cc
- DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op)
+ DEPS ${RPC_DEPS} executor scope proto_desc lookup_sparse_table_read_op scale_op)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler scope)
cc_library(parameter_prefetch SRCS parameter_prefetch.cc DEPS sendrecvop_rpc memory)
cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory)
...
@@ -132,6 +132,15 @@ void ProcGetResponse(const VarHandle& var_h,
&trainer_id);
}
void ProcGetRecvResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
VLOG(4) << "ProcGetRecvResponse";
framework::Variable* outvar = nullptr;
int trainer_id;
DeserializeRecvFromByteBuffer(ret_msg, *var_h.ctx(), var_h.scope(), &outvar,
&trainer_id);
}
template <typename T>
void RequestToByteBuffer(const T& proto, ::grpc::ByteBuffer* result) {
::grpc::Slice slice(proto.ByteSizeLong());
@@ -482,6 +491,79 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
return h;
}
VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& send_var_name,
const std::string& recv_var_name,
const std::string& table_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string send_var_name_val = send_var_name;
const std::string recv_var_name_val = recv_var_name;
const std::string table_name_val = table_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
const std::string method = kSendAndRecvRPC;
VLOG(4) << "GRPCClient::SendAndRecv Begin ,Send_var_name: "
<< send_var_name_val << " Recv_var_name: " << recv_var_name_val;
int retry_times_ = 0;
while (true) {
SendAndRecvProcessor* s = new SendAndRecvProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, method, send_var_name_val, p_ctx, p_scope));
VarHandlePtr h_recv(
new VarHandle(ep, method, recv_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
s->RecvPrepare(h_recv);
framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val,
p_scope, p_ctx, s, method, h, this] {
auto* send_var = p_scope->FindVar(send_var_name_val);
send_var->GetMutable<framework::LoDTensor>()->set_lod({});
::grpc::ByteBuffer buf;
VLOG(4) << "SerializeToByteBuffer: send_var_name_val: "
<< send_var_name_val
<< " recv_var_name_val: " << recv_var_name_val;
SerializeToByteBuffer(send_var_name_val, send_var, *p_ctx, &buf,
recv_var_name_val, trainer_id_, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetRecvResponse;
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendAndRecvVariable",
buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
h->Wait();
if (h->should_retry) {
VLOG(3) << "rpc call failed, retry times " << retry_times_;
retry_times_++;
std::random_device rd;
std::this_thread::sleep_for(std::chrono::milliseconds(rd() % 5));
continue;
}
}
return h;
}
}
bool GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return (req_count_ == 0 || ok_ == false); });
...
@@ -53,6 +53,8 @@ namespace distributed {
void ProcGetResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
void ProcGetRecvResponse(const VarHandle& var_h, const grpc::ByteBuffer& msg);
class BaseProcessor {
public:
BaseProcessor() { context_ = nullptr; }
@@ -131,6 +133,28 @@ class GetProcessor : public BaseProcessor {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};
class SendAndRecvProcessor : public BaseProcessor {
public:
explicit SendAndRecvProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(), stub_g_(ch) {}
virtual ~SendAndRecvProcessor() {}
void ProcessImpl() override {
if (response_call_back_) {
response_call_back_(*var_h_recv_.get(), reply_);
var_h_recv_->Finish(true);
}
}
void RecvPrepare(VarHandlePtr h_recv) { var_h_recv_ = h_recv; }
::grpc::ByteBuffer reply_;
::grpc::GenericStub stub_g_;
RequestGetCallBack response_call_back_ = ProcGetResponse;
VarHandlePtr var_h_recv_;
};
class BatchBarrierProcessor : public BaseProcessor {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
@@ -231,6 +255,14 @@ class GRPCClient : public RPCClient {
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendAndRecv(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& send_var_name,
const std::string& recv_var_name,
const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) override;
...
@@ -76,7 +76,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
PADDLE_THROW("Serialize does not support type: %s",
typeid(var->Type()).name());
}
std::string header;
request.AppendToString(&header);
auto buffer = std::unique_ptr<char[]>(new char[1024]);
@@ -101,7 +100,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
}
#endif
PADDLE_ENFORCE_NOT_NULL(payload);
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber,
payload->memory_size());
if (payload->memory_size() >= std::numeric_limits<int>::max()) {
@@ -140,7 +138,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
::grpc::Slice::STEAL_REF);
num_slices = 4;
}
::grpc::ByteBuffer tmp(&slices[0], num_slices);
msg->Swap(&tmp);
}
@@ -156,6 +153,19 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
*trainer_id = resp.GetTrainerId();
}
void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id) {
platform::RecordRPCEvent record_event("deserial");
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE_EQ(
resp.Parse(msg), 0,
platform::errors::InvalidArgument("parse bytebuffer to tensor error!"));
*var = resp.GetRecvVar();
*trainer_id = resp.GetTrainerId();
}
} // namespace distributed
} // namespace operators
} // namespace paddle
@@ -47,6 +47,11 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
void DeserializeRecvFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var, int* trainer_id);
} // namespace distributed
} // namespace operators
} // namespace paddle
@@ -28,6 +28,7 @@ DECLARE_int32(rpc_retry_bind_port);
namespace paddle {
namespace operators {
namespace distributed {
enum CallStatus { PROCESS = 0, FINISH };
// reference:
@@ -433,6 +434,51 @@ class RequestNotify final : public RequestBase {
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
class RequestSendAndRecv final : public RequestBase {
public:
explicit RequestSendAndRecv(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new GRPCVariableResponse(
request_handler->scope(), request_handler->dev_ctx(),
request_handler->distributed_mode()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kRequestSendAndRecv);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestSendAndRecv() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
std::string in_var_name = request_->Varname();
std::string out_var_name = request_->OutVarname();
std::string table_name = request_->TableName();
int trainer_id = request_->GetTrainerId();
VLOG(4) << "RequestSendAndRecv, in_var_name: " << in_var_name
<< " out_var_name: " << out_var_name << " trainer: " << trainer_id;
auto scope = request_->GetMutableLocalScope();
auto invar = scope->FindVar(in_var_name);
framework::Variable* outvar = nullptr;
request_handler_->Handle(in_var_name, scope, invar, &outvar, trainer_id,
out_var_name, table_name);
SerializeToByteBuffer(out_var_name, outvar, *request_handler_->dev_ctx(),
&reply_);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<GRPCVariableResponse> request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is waiting server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
@@ -586,6 +632,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestCheckpointNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestNotify) {
b = new RequestNotify(service_.get(), cq.get(), handler, req_id);
} else if (rpc_name == kRequestSendAndRecv) {
b = new RequestSendAndRecv(service_.get(), cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
...
@@ -85,10 +85,12 @@ enum class GrpcMethod {
kGetMonomerVariable,
kGetMonomerBarrier,
kRequestNotify,
kRequestSendAndRecv,
// when you add new handler, change kGrpcNumMethods at the same time!
};
static const int kGrpcNumMethods =
- static_cast<int>(GrpcMethod::kRequestNotify) + 1;
+ static_cast<int>(GrpcMethod::kRequestSendAndRecv) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
@@ -108,6 +110,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/CheckpointNotify";
case GrpcMethod::kRequestNotify:
return "/sendrecv.SendRecvService/DistributeNotify";
case GrpcMethod::kRequestSendAndRecv:
return "/sendrecv.SendRecvService/SendAndRecvVariable";
}
// Shouldn't be reached.
...
@@ -46,6 +46,7 @@ constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";
constexpr char kRequestSendAndRecv[] = "RequestSendAndRecv";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
@@ -57,6 +58,7 @@ constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
constexpr char kSendAndRecvRPC[] = "SendAndRecvRPC";
constexpr int64_t kPrefetchTimeout = 60000;
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
...
@@ -325,6 +325,22 @@ bool RequestNotifyHandler::Handle(const std::string &varname,
return true;
}
bool RequestSendAndRecvHandler::Handle(const std::string &varname,
framework::Scope *Scope,
framework::Variable *var,
framework::Variable **outvar,
const int trainer_id,
const std::string &out_var_name,
const std::string &table_name) {
VLOG(3) << "SendAndRecvHandle: " << varname
<< " out_var_name: " << out_var_name
<< " , trainer_id: " << trainer_id;
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), Scope);
*outvar = Scope->FindVar(out_var_name);
return true;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
@@ -176,6 +176,17 @@ class RequestNotifyHandler final : public RequestHandler {
std::unordered_map<int, int64_t> decay_counters;
};
class RequestSendAndRecvHandler final : public RequestHandler {
public:
explicit RequestSendAndRecvHandler(int distributed_mode)
: RequestHandler(distributed_mode) {}
virtual ~RequestSendAndRecvHandler() {}
bool Handle(const std::string& varname, framework::Scope* Scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
@@ -85,6 +85,12 @@ class RPCClient {
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendAndRecv(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& send_var_name,
const std::string& recv_var_name, const std::string& table_name = "",
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncSendComplete(
const std::string& ep, int64_t time_out = FLAGS_rpc_deadline) = 0;
...
@@ -35,27 +35,24 @@ namespace platform = paddle::platform;
namespace distributed = paddle::operators::distributed;
USE_NO_KERNEL_OP(lookup_sparse_table_read);
USE_OP(scale);
std::unique_ptr<distributed::RPCServer> g_rpc_service;
std::unique_ptr<distributed::RequestHandler> g_req_handler;
- framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
+ framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) {
auto root_block = program->MutableBlock(0);
auto* block = program->AppendBlock(*root_block);
- framework::VariableNameMap input({{"W", {"w"}}, {"Ids", {"ids"}}});
- framework::VariableNameMap output({{"Output", {"out"}}});
- auto op = block->AppendOp();
- op->SetType("lookup_sparse_table_read");
- op->SetInput("W", {"w"});
- op->SetInput("Ids", {"ids"});
- op->SetOutput("Out", {"out"});
- op->SetAttr("tablename", {"w"});
- op->SetAttr("value_names", {"Param"});
- auto& out = *root_block->Var("out");
+ framework::OpDesc* op = block->AppendOp();
+ op->SetType("scale");
+ op->SetInput("X", {"x"});
+ op->SetOutput("Out", {"res"});
+ op->SetAttr("scale", 0.5f);
+ auto& out = *root_block->Var("res");
out.SetType(framework::proto::VarType::LOD_TENSOR);
- out.SetShape({10, 10});
+ out.SetShape({1, 10});
return block;
}
@@ -69,6 +66,12 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) {
auto ids_var = scope->Var("ids");
ids_var->GetMutable<framework::LoDTensor>();
auto x_var = scope->Var("x");
x_var->GetMutable<framework::LoDTensor>();
auto res_var = scope->Var("res");
res_var->GetMutable<framework::LoDTensor>();
}
void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
@@ -78,6 +81,11 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place,
int64_t* ids_ptr =
ids_var->mutable_data<int64_t>(framework::DDim({rows_numel, 1}), *place);
for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2;
auto x_var = scope->Var("x")->GetMutable<framework::LoDTensor>();
float* x_ptr =
x_var->mutable_data<float>(framework::DDim({1, rows_numel}), *place);
for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0;
}
void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place,
@@ -124,6 +132,38 @@ void StartServer(const std::string& rpc_name) {
server_thread.join();
}
void StartSendAndRecvServer(const std::string& rpc_name) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
framework::Executor exe(place);
platform::CPUDeviceContext ctx(place);
auto block = AppendSendAndRecvBlock(&program);
std::string in_var_name("x");
std::vector<int> prefetch_block_ids{block->ID()};
auto prepared = exe.Prepare(program, prefetch_block_ids);
InitTensorsOnServer(&scope, &place, 10);
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
grad_to_prepared_ctx[in_var_name] = prepared[0];
g_req_handler->SetProgram(&program);
g_req_handler->SetGradToPreparedCtx(&grad_to_prepared_ctx);
g_req_handler->SetDevCtx(&ctx);
g_req_handler->SetScope(&scope);
g_req_handler->SetExecutor(&exe);
g_rpc_service->RegisterRPC(rpc_name, g_req_handler.get());
g_req_handler->SetRPCServer(g_rpc_service.get());
std::thread server_thread(
std::bind(&distributed::RPCServer::StartServer, g_rpc_service.get()));
server_thread.join();
}
TEST(COMPLETE, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
@@ -147,3 +187,46 @@ TEST(COMPLETE, CPU) {
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
TEST(SENDANDRECV, CPU) {
setenv("http_proxy", "", 1);
setenv("https_proxy", "", 1);
g_req_handler.reset(new distributed::RequestSendAndRecvHandler(
distributed::DistributedMode::kAsync));
g_rpc_service.reset(new RPCSERVER_T("127.0.0.1:0", 1));
distributed::RPCClient* client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);
PADDLE_ENFORCE_NE(client, nullptr,
platform::errors::InvalidArgument(
"Client Start Fail, Check Your Code & Env"));
std::thread server_thread(StartSendAndRecvServer,
distributed::kRequestSendAndRecv);
g_rpc_service->WaitServerReady();
int port = g_rpc_service->GetSelectedPort();
std::string ep = paddle::string::Sprintf("127.0.0.1:%d", port);
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
int64_t rows_numel = 10;
InitTensorsOnClient(&scope, &place, rows_numel);
std::string in_var_name("x");
std::string out_var_name("res");
client->AsyncSendAndRecv(ep, ctx, scope, in_var_name, out_var_name);
client->Wait();
auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::LoDTensor>();
auto ptr = value->mutable_data<float>(place);
for (int64_t i = 0; i < rows_numel; ++i) {
EXPECT_EQ(ptr[i], 0.5);
}
g_rpc_service->ShutDown();
server_thread.join();
LOG(INFO) << "begin reset";
g_rpc_service.reset(nullptr);
g_req_handler.reset(nullptr);
}
@@ -29,7 +29,7 @@ service SendRecvService {
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
rpc DistributeNotify(VariableMessage) returns (VoidMessage) {}
rpc SendAndRecvVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerVariable(VariableMessage) returns (VariableMessage) {}
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
}
...
@@ -96,6 +96,13 @@ class VariableResponse {
return scope_->FindVar(meta_.varname());
}
framework::Variable* GetRecvVar() {
if (create_scope_) {
return local_scope_->Var(meta_.out_varname());
}
return scope_->FindVar(meta_.out_varname());
}
int GetTrainerId() { return static_cast<int>(meta_.trainer_id()); }
protected:
...
@@ -268,7 +268,6 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
size_t num_blocks = program->Size();
PADDLE_ENFORCE_GE(num_blocks, 2,
"server program should have at least 2 blocks");
std::vector<int> block_list;
for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
block_list.push_back(blkid);
@@ -295,6 +294,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
request_send_and_recv_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
while (true) {
if (rpc_service_->IsExit()) {
@@ -394,6 +394,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestGetNoBarrierHandler());
request_notify_handler_.reset(
new distributed::RequestNotifyHandler(distributed_mode, fan_in));
request_send_and_recv_handler_.reset(
new distributed::RequestSendAndRecvHandler(distributed_mode));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(), rpc_send_thread_num);
@@ -408,6 +410,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_no_barrier_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestNotify,
request_notify_handler_.get(), rpc_send_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestSendAndRecv,
request_send_and_recv_handler_.get(),
rpc_get_thread_num);
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
@@ -416,6 +421,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
"optimize blocks is less than 1. Optimize blocks "
"should be 1 at least on the pserver side."));
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
@@ -488,6 +494,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
f(request_notify_handler_.get());
f(request_send_and_recv_handler_.get());
// register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
signal(SIGINT, SignalHandler::StopAndExit);
...
@@ -99,6 +99,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RequestHandler>
request_checkpoint_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_notify_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_send_and_recv_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
mutable std::vector<std::string> sparse_vars_;
...
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/distributed/communicator.h"
#include "paddle/fluid/operators/distributed/communicator_common.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"
#include "paddle/fluid/operators/distributed_ops/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SendAndRecvKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& scope = ctx.scope();
const auto& place = ctx.GetPlace();
auto send_var_name = ctx.Attr<std::string>("send_var_name");
auto recv_var_name = ctx.Attr<std::string>("recv_var_name");
auto epmap = ctx.Attr<std::string>("endpoint");
auto trainer_id = ctx.Attr<int>("trainer_id");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& context = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
VLOG(3) << "SendAndRecvOp Send_var_name: " << send_var_name
<< " Recv_var_name: " << recv_var_name;
distributed::VarHandlePtr rets = rpc_client->AsyncSendAndRecv(
epmap, context, scope, send_var_name, recv_var_name);
rets->Wait();
}
};
class SendAndRecvOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("X", "Tensor Input variable to be sent").AsDuplicable();
AddOutput("Out", "Tensor Output varibale to be recv").AsDuplicable();
AddAttr<std::string>("send_var_name", "Send Tensor's name")
.SetDefault(std::string(""));
AddAttr<std::string>("recv_var_name", "Recv Tensor's name")
.SetDefault(std::string(""));
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
AddAttr<std::string>("endpoint", "Server endpoint")
.SetDefault({"127.0.0.1:6164"});
AddComment(R"DOC(
SendAndRecv operator
This operator sends the input variable to the listen_and_serv op on the parameter server,
and receives the resulting variable back into the sender's scope.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker);
REGISTER_OP_CPU_KERNEL(
send_and_recv,
ops::SendAndRecvKernel<paddle::platform::CPUDeviceContext, float>)
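For context, a minimal sketch (not part of the commit) of how this operator could be wired into a program by hand. The variable names, shapes, and endpoint below are illustrative; the input/output and attribute names come from SendAndRecvOpMaker above.

import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
# "x" is sent to the server; "res" receives the result computed there.
x = block.create_var(name="x", shape=[1, 10], dtype="float32")
res = block.create_var(name="res", shape=[1, 10], dtype="float32")
block.append_op(
    type="send_and_recv",
    inputs={"X": [x]},
    outputs={"Out": [res]},
    attrs={
        "send_var_name": "x",
        "recv_var_name": "res",
        "endpoint": "127.0.0.1:6164",  # assumed pserver address
        "trainer_id": 0
    })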
@@ -200,7 +200,8 @@ class Fleet(object):
bool: True if this is a node of server,
False if not.
"""
- return self._role_maker.is_server()
+ return self._role_maker.is_server(
+ ) or self._role_maker._is_heter_worker()
@property
def util(self):
...
@@ -14,6 +14,7 @@
"""Defination of Role Makers."""
import os
import numpy as np
import warnings
from multiprocessing import Process, Manager
import paddle.fluid as fluid
@@ -23,6 +24,7 @@ import paddle.fluid as fluid
class Role:
WORKER = 1
SERVER = 2
HETER_WORKER = 3
class RoleMakerBase(object):
@@ -40,6 +42,11 @@ class RoleMakerBase(object):
self._role = None
self._current_id = -1
# for heter parameter server mode
self._heter_trainer_endpoints = []
self._heter_trainer_device = "CPU"
self._is_heter_parameter_server_mode = False
self._node_type = None
self._node_type_comm = None
self._all_comm = None
@@ -163,12 +170,58 @@ class RoleMakerBase(object):
"""
print("warning: RoleMakerBase does not have barrier worker.")
def _is_heter_worker(self):
"""
Return is_heter_worker() of current process
"""
warnings.warn("RoleMakerBase does not have function: _is_heter_worker.")
return False
def _heter_worker_num(self):
"""
Get current total heter-worker number.
Returns:
int: heter_worker number
"""
warnings.warn(
"RoleMakerBase does not have function: _heter_worker_num.")
return 0
def _get_heter_worker_endpoints(self):
"""
Returns:
string: all heter_trainers' endpoints
"""
assert self._heter_trainer_endpoints != []
return self._heter_trainer_endpoints
def _get_heter_worker_endpoint(self):
"""
Returns:
int: corresponding heter_trainer's endpoint
e.g.: if we have 4 cpu-trainers (default) and 2 gpu-trainers (heter),
then cpu-trainers No.0 and No.2 will work with gpu-trainer No.0,
and cpu-trainers No.1 and No.3 will work with gpu-trainer No.1
"""
assert self._heter_trainer_endpoints != []
return self._heter_trainer_endpoints[(self._current_id + 1) %
self._heter_worker_num()]
def _get_heter_worker_device(self):
"""
Returns:
string: heter_trainer's device of current node, e.g: CPU/GPU/XPU
"""
return self._heter_trainer_device.upper()
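To make the index arithmetic used by _get_heter_worker_endpoint above concrete, a small illustration with made-up endpoints:

# Example values only: 4 cpu trainers share 2 heter trainers.
heter_endpoints = ["10.0.0.10:7000", "10.0.0.11:7000"]
heter_worker_num = len(heter_endpoints)
for cpu_trainer_id in range(4):
    # same selection rule as _get_heter_worker_endpoint
    paired = heter_endpoints[(cpu_trainer_id + 1) % heter_worker_num]
    print("cpu trainer %d -> heter trainer %s" % (cpu_trainer_id, paired))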
class PaddleCloudRoleMaker(RoleMakerBase):
def __init__(self, is_collective=False, **kwargs):
super(PaddleCloudRoleMaker, self).__init__()
self._is_collective = is_collective
- self._init_gloo = False  #default no init gloo
+ self._init_gloo = False  # default no init gloo
self._kwargs = kwargs
self._role_is_generated = False
@@ -278,10 +331,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
"""
get index of current node
"""
- if self.is_server():
-     return self.server_index()
- elif self.is_worker():
-     return self.worker_index()
+ return self._current_id
def worker_num(self):
"""
@@ -323,6 +373,22 @@ class PaddleCloudRoleMaker(RoleMakerBase):
self.generate_role()
return self._server_endpoints
def _heter_worker_num(self):
"""
get heter worker nums
"""
if not self._role_is_generated:
self.generate_role()
return self._heter_trainers_num
def _is_heter_worker(self):
"""
whether current process is heter worker
"""
if not self._role_is_generated:
self.generate_role()
return self._role == Role.HETER_WORKER
def _get_rank(self):
"""
get current rank in all workers and pservers
@@ -342,17 +408,47 @@ class PaddleCloudRoleMaker(RoleMakerBase):
def _ps_env(self):
try:
# Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set
- # format: string(ip:port), eg. 127.0.0.1:6001
+ # format: string(ip:port,ip:port), eg. 127.0.0.1:6001,127.0.0.1:6002
- self._server_endpoints = os.environ[
-     "PADDLE_PSERVERS_IP_PORT_LIST"].split(",")
+ self._server_endpoints = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST",
+                                    "").split(",")
assert self._server_endpoints != ""
self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS",
"").split(",")
assert self._server_endpoints != ""
trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"])
training_role = os.environ["TRAINING_ROLE"]
- if training_role not in ["TRAINER", "PSERVER"]:
-     raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
+ if training_role not in ["TRAINER", "PSERVER", "HETER_TRAINER"]:
+     raise ValueError(
+         "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.".
+         format(training_role))
# For heter parameter server env setting
heter_trainer_eplist = os.getenv(
"PADDLE_HETER_TRAINER_IP_PORT_LIST", None)
heter_trainer_device = os.getenv("PADDLE_HETER_TRAINER_DEVICE",
None)
if heter_trainer_eplist and heter_trainer_device:
try:
heter_trainer_eplist = os.environ[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",")
except:
raise ValueError(
"Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ."
)
self._is_heter_parameter_server_mode = True
heter_trainers_num = len(heter_trainer_eplist)
current_node_device = heter_trainer_device.upper()
if current_node_device not in ["CPU", "GPU", "XPU"]:
raise ValueError(
"Heter Trainer doesn't support {} device now, please use CPU / GPU / XPU(KunLun)".
format(heter_trainer_device))
self._heter_trainer_device = current_node_device
else:
self._is_heter_parameter_server_mode = False
heter_trainers_num = 0
if training_role == "TRAINER":
role = Role.WORKER
@@ -365,17 +461,26 @@ class PaddleCloudRoleMaker(RoleMakerBase):
ip = os.environ["POD_IP"]
self._cur_endpoint = ip + ":" + port
current_id = self._server_endpoints.index(self._cur_endpoint)
elif training_role == "HETER_TRAINER":
role = Role.HETER_WORKER
cur_ip = os.environ["POD_IP"]
cur_port = os.environ["PADDLE_PORT"]
curr_endpoint = ":".join([cur_ip, cur_port])
current_id = heter_trainer_eplist.index(curr_endpoint)
else:
- raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER")
- except ValueError as ve:
-     raise ValueError(
-         "something wrong with PaddleCloud, please check environment")
+ raise ValueError(
+     "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER")
+ except ValueError as e:
+     raise ValueError(
+         "Something wrong with PaddleCloud, please check environment")
self._trainers_num = trainers_num
self._role = role
self._current_id = current_id
self._node_num = len(
set([x.split(':')[0] for x in self._worker_endpoints]))
self._heter_trainers_num = heter_trainers_num
self._heter_trainer_endpoints = heter_trainer_eplist
def _collective_env(self):
self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
...
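For reference, a sketch of the environment a heter trainer process would be launched with, based on the variables read in _ps_env above (addresses and counts are placeholders):

import os

# parameter servers and regular (cpu) trainers
os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:6001,127.0.0.1:6002"
os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:6170,127.0.0.1:6171"
os.environ["PADDLE_TRAINERS_NUM"] = "2"

# heter trainers and the device they run on
os.environ["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:6180"
os.environ["PADDLE_HETER_TRAINER_DEVICE"] = "gpu"

# identity of this process
os.environ["TRAINING_ROLE"] = "HETER_TRAINER"
os.environ["POD_IP"] = "127.0.0.1"
os.environ["PADDLE_PORT"] = "6180"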
@@ -15,10 +15,10 @@ from .amp_optimizer import AMPOptimizer
from .recompute_optimizer import RecomputeOptimizer
from .gradient_merge_optimizer import GradientMergeOptimizer
from .graph_execution_optimizer import GraphExecutionOptimizer
- from .async_optimizer import AsyncMetaOptimizer
+ from .parameter_server_optimizer import ParameterServerOptimizer
from .pipeline_optimizer import PipelineOptimizer
from .localsgd_optimizer import LocalSGDOptimizer
from .lars_optimizer import LarsOptimizer
- from .async_graph_execution_optimizer import AsyncGraphExecutionOptimizer
+ from .parameter_server_graph_optimizer import ParameterServerGraphOptimizer
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
@@ -13,12 +13,12 @@
from paddle import fluid
from paddle.fluid import compiler
- from .async_optimizer import AsyncMetaOptimizer
+ from .parameter_server_optimizer import ParameterServerOptimizer
- class AsyncGraphExecutionOptimizer(AsyncMetaOptimizer):
+ class ParameterServerGraphOptimizer(ParameterServerOptimizer):
def __init__(self, optimizer):
- super(AsyncGraphExecutionOptimizer, self).__init__(optimizer)
+ super(ParameterServerGraphOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
@@ -31,6 +31,9 @@ class AsyncGraphExecutionOptimizer(AsyncMetaOptimizer):
if self.role_maker.is_server():
return False
if self.role_maker._is_heter_parameter_server_mode:
return False
return True
def _disable_strategy(self, dist_strategy):
...
@@ -15,9 +15,9 @@ from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase
- class AsyncMetaOptimizer(MetaOptimizerBase):
+ class ParameterServerOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
- super(AsyncMetaOptimizer, self).__init__(optimizer)
+ super(ParameterServerOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = []
@@ -68,6 +68,21 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
_startup = worker.init_from_server_pass(_startup, compiled_config)
_startup = worker.delet_extra_optimizes_pass(_startup,
compiled_config)
# for heter program
if self.role_maker._is_heter_parameter_server_mode:
from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker
if self.role_maker._is_heter_worker():
# for heter worker
_main = heter_worker.split_heter_worker_ops_pass(
_main, compiled_config)
else:
# for default worker
_main = heter_worker.split_trainer_ops_pass(_main,
compiled_config)
# for startup change
_startup = heter_worker.delete_startup_useless_ops_var_pass(
_startup, _main, compiled_config)
else:
_main = worker.append_send_ops_pass(_main, compiled_config)
_startup = _startup
@@ -129,9 +144,12 @@ class AsyncMetaOptimizer(MetaOptimizerBase):
_origin_startup_program,
strategy, self.role_maker)
- main_program, startup_program = \
-     self._build_trainer_programs(compiled_config) if self.role_maker.is_worker() \
-     else self._build_pserver_programs(compiled_config)
+ if self.role_maker.is_worker() or self.role_maker._is_heter_worker():
+     main_program, startup_program = self._build_trainer_programs(
+         compiled_config)
+ elif self.role_maker.is_server():
+     main_program, startup_program = self._build_pserver_programs(
+         compiled_config)
loss.block.program = main_program
fluid.framework.switch_startup_program(startup_program)
...
@@ -196,6 +196,18 @@ class ParameterServerRuntime(RuntimeBase):
else:
warnings.warn("communicator has been initialized, skip")
def _get_executor(self):
if self.role_maker._is_heter_worker():
if self.role_maker._get_heter_worker_device() == "GPU":
gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
executor = Executor(fluid.CUDAPlace(gpu_id))
else:
raise ValueError("Not Support Device {}".format(
self.role_maker._get_heter_worker_device()))
else:
executor = fluid.Executor(fluid.CPUPlace())
return executor
def _init_server(self, *args, **kwargs):
if len(args) > 1:
raise ValueError("init server can only accept 1 args: `dirname`")
@@ -204,9 +216,15 @@ class ParameterServerRuntime(RuntimeBase):
else:
model_dirname = None
- executor = fluid.Executor(fluid.CPUPlace())
+ if self.role_maker._is_heter_worker():
+     self._init_worker()
+ executor = self._get_executor()
executor.run(fluid.default_startup_program())
+ if self.role_maker._is_heter_worker():
+     return
if not model_dirname:
return
@@ -237,12 +255,12 @@ class ParameterServerRuntime(RuntimeBase):
# self._load_sparse_params(dirname=model_dir, varnames=distribtued_varnames)
def _run_server(self):
- executor = fluid.Executor(fluid.CPUPlace())
+ executor = self._get_executor()
executor.run(fluid.default_main_program())
def _stop_worker(self):
self._communicator.stop()
- executor = fluid.Executor(fluid.CPUPlace())
+ executor = self._get_executor()
executor.close()
def _get_optimizer_status(self, op, param_name):
...
@@ -145,7 +145,7 @@ class Fleet(object):
Returns:
bool: True if this is a node of server,
- False if not.
+ False if not
"""
return self._role_maker.is_server()
...
@@ -343,7 +343,6 @@ class MPISymetricRoleMaker(MPIRoleMaker):
def get_pserver_endpoints(self):
"""
get pserver endpoints
Returns:
endpoints(list): pserver endpoints
"""
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import warnings
import paddle.fluid.core as core
import paddle.fluid.framework as framework
from paddle.fluid.transpiler.details.program_utils import delete_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_heter_ops
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_heter_program
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_trainer_program
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_block_joints
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_op_input_output
from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import get_vars_name_in_block
def split_heter_worker_ops_pass(program, config):
"""
split heter worker program from origin-program
1. find heter op (located on different device)
2. find input&output of every heter-block
3. create heter worker program, add listen&serv op
"""
default_device = "cpu"
program, heter_ops, _, program_block_ops = find_heter_ops(program,
default_device)
if len(heter_ops) == 0:
warnings.warn(
"Currently running in Heter Parameter Server mode, but no OP running on heterogeneous devices, Please check your code."
)
return program
current_device = "gpu"
if current_device not in heter_ops:
raise ValueError("Op which run on device {} not exist.".format(
current_device))
block_vars_detail = find_block_joints(program, program_block_ops, heter_ops)
heter_program = framework.Program()
create_heter_program(program, config, heter_program, heter_ops,
block_vars_detail, current_device)
return heter_program
def split_trainer_ops_pass(program, config):
"""
split cpu-trainer program from origin-program
1. find heter op (located on different device)
2. find input&output of every heter-block
3. create cpu-trainer program, add send&recv op
"""
# Todo: support user define default_device (MrChengmo)
default_device = "cpu"
program, heter_ops, _, program_block_ops = find_heter_ops(program,
default_device)
block_vars_detail = find_block_joints(program, program_block_ops, heter_ops)
create_trainer_program(program, config, heter_ops, block_vars_detail)
return program
def delete_startup_useless_ops_var_pass(startup_program, main_program, config):
"""
delete variables which are not used in the current main_program
"""
# find all op and its var
vars_in_main_program = get_vars_name_in_block(main_program.global_block())
block_nums = startup_program.num_blocks
for block_index in range(1, block_nums):
current_block = startup_program.block(block_index)
# delete useless op
need_delete_op = []
for op in current_block.ops:
inputs, outputs = find_op_input_output(startup_program,
current_block, op)
inputs += outputs
# Todo: delete some concat op
# the op is useless in startup if none of its vars appear in main_program
if not set(inputs) & set(vars_in_main_program):
need_delete_op.append(op)
delete_ops(current_block, need_delete_op)
# delete useless var
for var in current_block.vars:
if var.name not in vars_in_main_program:
startup_program._remove_var(var.name)
return startup_program
@@ -12,33 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright(c) 2020 PaddlePaddle Authors.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0(the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http: // www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from functools import reduce
import collections
import math
import os
import warnings
import six
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.core import CommContext
import paddle.fluid.framework as framework
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools
from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundRobin, PSDispatcher
from paddle.fluid.transpiler.details.program_utils import delete_ops
OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "@CLIP"
@@ -122,9 +112,20 @@ class MergedVariable:
self.offsets = offsets
def Singleton(cls):
_instance = {}
def _singleton(*args, **kargs):
if cls not in _instance:
_instance[cls] = cls(*args, **kargs)
return _instance[cls]
return _singleton
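As an aside, the decorator above caches a single instance per class; a tiny illustration with a hypothetical class:

@Singleton
class Config(object):
    def __init__(self, value):
        self.value = value

a = Config(1)
b = Config(2)  # __init__ does not run again; the cached instance is returned
assert a is b and a.value == 1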
@Singleton
class CompileTimeStrategy(object):
def __init__(self, main_program, startup_program, strategy, role_maker):
self.min_block_size = 8192
self.origin_main_program = main_program
@@ -177,6 +178,12 @@ class CompileTimeStrategy(object):
def get_ps_endpoints(self):
return self.role_maker.get_pserver_endpoints()
def get_heter_worker_endpoints(self):
return self.role_maker._get_heter_worker_endpoints()
def get_heter_worker_endpoint(self):
return self.role_maker._get_heter_worker_endpoint()
def get_origin_programs(self): def get_origin_programs(self):
return self.origin_main_program, self.origin_startup_program return self.origin_main_program, self.origin_startup_program
...@@ -810,6 +817,30 @@ class CompileTimeStrategy(object): ...@@ -810,6 +817,30 @@ class CompileTimeStrategy(object):
return sparse_param_grads, dense_param_grads return sparse_param_grads, dense_param_grads
def remove_var_pair_by_grad(self, var_name):
for index, pair in enumerate(self.merged_variables_pairs):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del self.merged_variables_pairs[index]
break
for index, pair in enumerate(self.merged_dense_pairs):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del self.merged_dense_pairs[index]
return
for index, pair in enumerate(self.merged_sparse_pairs):
var = pair[0]
var_grad = pair[1]
if var_grad.merged_var.name == var_name:
del self.merged_sparse_pairs[index]
return
print("Not find {} in self.merge_pairs".format(var_name))
def _is_opt_role_op(op): def _is_opt_role_op(op):
# NOTE : depend on oprole to find out whether this op is for # NOTE : depend on oprole to find out whether this op is for
......
...@@ -17,8 +17,9 @@ from __future__ import print_function ...@@ -17,8 +17,9 @@ from __future__ import print_function
import os import os
import logging import logging
import tarfile import tarfile
import tempfile
import random import random
import warnings
import paddle import paddle
import paddle.fluid.incubate.data_generator as data_generator import paddle.fluid.incubate.data_generator as data_generator
...@@ -57,7 +58,7 @@ def load_dnn_input_record(sent): ...@@ -57,7 +58,7 @@ def load_dnn_input_record(sent):
def load_lr_input_record(sent): def load_lr_input_record(sent):
res = [] res = []
for _ in [x.split(':') for x in sent.split()]: for _ in [x.split(':') for x in sent.split()]:
res.append(int(_[0])) res.append(int(_[0]) % 10000)
return res return res
...@@ -120,9 +121,62 @@ def prepare_data(): ...@@ -120,9 +121,62 @@ def prepare_data():
lr_input_dim = res[1] lr_input_dim = res[1]
logger.info('dnn input dim: %d' % dnn_input_dim) logger.info('dnn input dim: %d' % dnn_input_dim)
logger.info('lr input dim: %d' % lr_input_dim) logger.info('lr input dim: %d' % lr_input_dim)
return dnn_input_dim, lr_input_dim, train_file_path return dnn_input_dim, lr_input_dim, train_file_path
def gen_fake_line(dnn_data_num=7,
dnn_data_range=1e5,
lr_data_num=5,
lr_data_range=1e5):
line = ""
# for deep data
for index in range(dnn_data_num):
data = str(random.randint(0, int(dnn_data_range) - 1))
if index < dnn_data_num - 1:
data += " "
line += data
line += "\t"
# for wide data
for index in range(lr_data_num):
data = str(random.randint(0, int(lr_data_range) - 1)) + ":" + str(1)
if index < lr_data_num - 1:
data += " "
line += data
line += "\t"
# for label
line += str(random.randint(0, 1))
line += "\n"
return line
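# A line produced by gen_fake_line looks like (values are random):
#   "312 8 44 901 6 77 990\t17:1 4:1 86:1 5:1 42:1\t1\n"
# i.e. dnn_data_num space-separated deep ids, a tab, lr_data_num "id:1" wide
# features, a tab, and a 0/1 label.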
def prepare_fake_data(file_nums=8, file_lines=1000):
"""
Create fake data in the same format as avazu_ctr_data
"""
file_dir = tempfile.mkdtemp()
warnings.warn("Fake data write in {}".format(file_dir))
for file_index in range(file_nums):
with open(
os.path.join(file_dir,
"ctr_train_data_part_{}".format(file_index)),
'w+') as fin:
file_str = ""
for line_index in range(file_lines):
file_str += gen_fake_line()
fin.write(file_str)
warnings.warn("Write done ctr_train_data_part_{}".format(
file_index))
file_list = [os.path.join(file_dir, x) for x in os.listdir(file_dir)]
assert len(file_list) == file_nums
return file_list
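# The returned file list is later sharded across trainers with
# fleet_util.get_file_shard in do_dataset_training (dist_fleet_heter_ctr.py).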
if __name__ == "__main__": if __name__ == "__main__":
pairwise_reader = DatasetCtrReader() pairwise_reader = DatasetCtrReader()
pairwise_reader.run_from_stdin() pairwise_reader.run_from_stdin()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Distribute CTR model for test fleet api
"""
from __future__ import print_function
import shutil
import tempfile
import time
import paddle
import paddle.fluid as fluid
import os
import numpy as np
import ctr_dataset_reader
from test_dist_fleet_heter_base import runtime_main, FleetDistHeterRunnerBase
from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
from paddle.distributed.fleet.base.util_factory import fleet_util
# Fix seed for test
fluid.default_startup_program().random_seed = 1
fluid.default_main_program().random_seed = 1
class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
"""
Test the CTR model using the Fleet API
"""
def net(self, args, batch_size=4, lr=0.01):
"""
network definition
Args:
batch_size(int): the size of mini-batch for training
lr(float): learning rate of training
Returns:
avg_cost: LoDTensor of cost.
"""
dnn_input_dim, lr_input_dim = int(1e5), int(1e5)
dnn_data = fluid.layers.data(
name="dnn_data",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
lr_data = fluid.layers.data(
name="lr_data",
shape=[-1, 1],
dtype="int64",
lod_level=1,
append_batch_size=False)
label = fluid.layers.data(
name="click",
shape=[-1, 1],
dtype="float32",
lod_level=0,
append_batch_size=False)
datas = [dnn_data, lr_data, label]
if args.reader == "pyreader":
self.reader = fluid.io.PyReader(
feed_list=datas,
capacity=64,
iterable=False,
use_double_buffer=False)
# build dnn model
dnn_layer_dims = [128, 64, 32, 1]
dnn_embedding = fluid.layers.embedding(
is_distributed=False,
input=dnn_data,
size=[dnn_input_dim, dnn_layer_dims[0]],
param_attr=fluid.ParamAttr(
name="deep_embedding",
initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=True)
dnn_pool = fluid.layers.sequence_pool(
input=dnn_embedding, pool_type="sum")
dnn_out = dnn_pool
# build lr model
lr_embedding = fluid.layers.embedding(
is_distributed=False,
input=lr_data,
size=[lr_input_dim, 1],
param_attr=fluid.ParamAttr(
name="wide_embedding",
initializer=fluid.initializer.Constant(value=0.01)),
is_sparse=True)
lr_pool = fluid.layers.sequence_pool(input=lr_embedding, pool_type="sum")
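# The fc stack below is wrapped in fluid.device_guard("gpu"): in heter
# parameter-server mode these ops are expected to be scheduled on the heter
# (gpu) trainer, while the sparse embeddings and pooling above stay on the
# cpu trainer.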
with fluid.device_guard("gpu"):
for i, dim in enumerate(dnn_layer_dims[1:]):
fc = fluid.layers.fc(
input=dnn_out,
size=dim,
act="relu",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)),
name='dnn-fc-%d' % i)
dnn_out = fc
merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1)
label = fluid.layers.cast(label, dtype="int64")
predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
fluid.layers.Print(avg_cost, message="avg_cost")
self.feeds = datas
self.train_file_path = ["fake1", "fake2"]
self.avg_cost = avg_cost
self.predict = predict
return avg_cost
def check_model_right(self, dirname):
model_filename = os.path.join(dirname, "__model__")
with open(model_filename, "rb") as f:
program_desc_str = f.read()
program = fluid.Program.parse_from_string(program_desc_str)
with open(os.path.join(dirname, "__model__.proto"), "w") as wn:
wn.write(str(program))
def do_pyreader_training(self, fleet):
"""
do training using py_reader, iterating over a fake CTR reader
Args:
fleet(Fleet api): the fleet object of Parameter Server, defines the distributed training role
"""
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
batch_size = 4
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
self.reader.decorate_sample_list_generator(train_reader)
for epoch_id in range(1):
self.reader.start()
try:
pass_start = time.time()
while True:
exe.run(program=fluid.default_main_program())
pass_time = time.time() - pass_start
except fluid.core.EOFException:
self.reader.reset()
fleet.stop_worker()
def do_dataset_training(self, fleet):
train_file_list = ctr_dataset_reader.prepare_fake_data()
exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program())
thread_num = 1
batch_size = 128
filelist = fleet_util.get_file_shard(train_file_list)
print("filelist: {}".format(filelist))
# config dataset
dataset = paddle.distributed.fleet.DatasetFactory().create_dataset()
dataset.set_batch_size(batch_size)
dataset.set_use_var(self.feeds)
pipe_command = 'python ctr_dataset_reader.py'
dataset.set_pipe_command(pipe_command)
dataset.set_filelist(filelist)
dataset.set_thread(thread_num)
for epoch_id in range(1):
pass_start = time.time()
dataset.set_filelist(filelist)
exe.train_from_dataset(
program=fluid.default_main_program(),
dataset=dataset,
fetch_list=[self.avg_cost],
fetch_info=["cost"],
print_period=2,
debug=int(os.getenv("Debug", "0")))
pass_time = time.time() - pass_start
print("do_dataset_training done. using time {}".format(pass_time))
if os.getenv("SAVE_MODEL") == "1":
model_dir = tempfile.mkdtemp()
fleet.save_inference_model(exe, model_dir,
[feed.name for feed in self.feeds],
self.avg_cost)
self.check_model_right(model_dir)
shutil.rmtree(model_dir)
fleet.stop_worker()
print("do_dataset_training stop worker.")
if __name__ == "__main__":
runtime_main(TestHeterPsCTR2x2)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
"""
high level unit test for distribute fleet.
"""
import os
import sys
import subprocess
import six
import shutil
import numpy as np
import argparse
from contextlib import closing
import socket
import time
import tempfile
import unittest
import paddle
import paddle.fluid as fluid
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet.base.util_factory import fleet_util
from paddle.distributed.fleet import fleet
__all__ = ['FleetDistHeterRunnerBase', 'TestFleetHeterBase', 'runtime_main']
RUN_STEP = 5
LEARNING_RATE = 0.01
DIST_UT_PORT = 0
class FleetDistHeterRunnerBase(object):
"""
run_pserver, run_trainer : after the role is initialized, split the program with the transpiler
net : implemented by child classes, defines the network of the model
do training : the executor runs the program
"""
def build_role(self, args):
environs = {}
environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints
environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints
environs[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"] = args.heter_trainer_endpoints
environs["PADDLE_HETER_TRAINER_DEVICE"] = args.heter_trainer_device
environs["TRAINING_ROLE"] = args.role.upper()
environs["PADDLE_TRAINERS_NUM"] = args.trainers
environs["PADDLE_TRAINER_ID"] = args.current_id
if args.role.upper() == "PSERVER":
environs["POD_IP"] = args.endpoints.split(",")[int(
args.current_id)].split(":")[0]
environs["PADDLE_PORT"] = args.endpoints.split(",")[int(
args.current_id)].split(":")[1]
elif args.role.upper() == "HETER_TRAINER":
environs["POD_IP"] = args.heter_trainer_endpoints.split(",")[int(
args.current_id)].split(":")[0]
environs["PADDLE_PORT"] = args.heter_trainer_endpoints.split(",")[
int(args.current_id)].split(":")[1]
environs["FLAGS_selected_gpus"] = args.current_id
for k, v in environs.items():
os.environ[k] = str(v)
self.role = role_maker.PaddleCloudRoleMaker()
return self.role
def build_strategy(self, args):
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = True
return self.strategy
def build_optimizer(self, avg_cost, strategy):
optimizer = fluid.optimizer.SGD(LEARNING_RATE)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
def run_pserver(self, args):
fleet.init_server()
fleet.run_server()
def run_dataset_trainer(self, args):
out = self.do_dataset_training(fleet)
def run_pyreader_trainer(self, args):
out = self.do_pyreader_training(fleet)
def net(self, args, batch_size=4, lr=0.01):
raise NotImplementedError(
"get_model should be implemented by child classes.")
def do_dataset_training(self, fleet):
raise NotImplementedError(
"do_dataset_training should be implemented by child classes.")
def do_pyreader_training(self, fleet):
raise NotImplementedError(
"do_pyreader_training should be implemented by child classes.")
class TestFleetHeterBase(unittest.TestCase):
"""
start_pserver, start_trainer : build the start commands used by the test
run_cluster : launch multiple processes to test the distributed program
"""
def _setup_config(self):
raise NotImplementedError("tests should have _setup_config implemented")
def tearDown(self):
t = time.time() - self.startTime
print('%s: %.3f' % (self.__class__.__name__, t))
def setUp(self):
self.startTime = time.time()
self._mode = "async"
self._reader = "pyreader"
self._trainers = 2
self._pservers = 2
self._port_set = set()
self._heter_device = "gpu"
global DIST_UT_PORT
if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"):
DIST_UT_PORT = int(os.getenv("PADDLE_DIST_UT_PORT"))
if DIST_UT_PORT:
print("set begin_port:", DIST_UT_PORT)
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
DIST_UT_PORT, DIST_UT_PORT + 1)
self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
DIST_UT_PORT + 2, DIST_UT_PORT + 3)
self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
DIST_UT_PORT + 4, DIST_UT_PORT + 5)
DIST_UT_PORT += 6
else:
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
self._geo_sgd_need_push_nums = 5
self._grad_clip_mode = 0
self._setup_config()
def _find_free_port(self):
def __free_port():
with closing(socket.socket(socket.AF_INET,
socket.SOCK_STREAM)) as s:
s.bind(('', 0))
return s.getsockname()[1]
while True:
port = __free_port()
if port not in self._port_set:
self._port_set.add(port)
return port
def _start_pserver(self, cmd, required_envs):
ps0_cmd, ps1_cmd = cmd.format(0), cmd.format(1)
ps0_pipe = open(tempfile.gettempdir() + "/ps0_err.log", "wb+")
ps1_pipe = open(tempfile.gettempdir() + "/ps1_err.log", "wb+")
ps0_proc = subprocess.Popen(
ps0_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps0_pipe,
env=required_envs)
ps1_proc = subprocess.Popen(
ps1_cmd.strip().split(" "),
stdout=subprocess.PIPE,
stderr=ps1_pipe,
env=required_envs)
return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe
def _start_trainer(self, cmd, required_envs):
tr0_cmd, tr1_cmd = cmd.format(0), cmd.format(1)
tr0_pipe = open(tempfile.gettempdir() + "/tr0_err.log", "wb+")
tr1_pipe = open(tempfile.gettempdir() + "/tr1_err.log", "wb+")
tr0_out = open(tempfile.gettempdir() + "/tr0_out.log", "wb+")
tr1_out = open(tempfile.gettempdir() + "/tr1_out.log", "wb+")
tr0_proc = subprocess.Popen(
tr0_cmd.strip().split(" "),
stdout=tr0_out,
stderr=tr0_pipe,
env=required_envs)
tr1_proc = subprocess.Popen(
tr1_cmd.strip().split(" "),
stdout=tr1_out,
stderr=tr1_pipe,
env=required_envs)
return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe
def _start_heter_trainer(self, cmd, required_envs):
heter0_cmd, heter1_cmd = cmd.format(0), cmd.format(1)
heter0_pipe = open(tempfile.gettempdir() + "/heter0_err.log", "wb+")
heter1_pipe = open(tempfile.gettempdir() + "/heter1_err.log", "wb+")
heter0_out = open(tempfile.gettempdir() + "/heter0_out.log", "wb+")
heter1_out = open(tempfile.gettempdir() + "/heter1_out.log", "wb+")
heter0_proc = subprocess.Popen(
heter0_cmd.strip().split(" "),
stdout=heter0_out,
stderr=heter0_pipe,
env=required_envs)
heter1_proc = subprocess.Popen(
heter1_cmd.strip().split(" "),
stdout=heter1_out,
stderr=heter1_pipe,
env=required_envs)
return heter0_proc, heter1_proc, heter0_pipe, heter1_pipe
def _run_cluster(self, model, envs):
env = {'GRAD_CLIP': str(self._grad_clip_mode)}
python_path = self._python_interp
gloo_path = tempfile.mkdtemp()
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '')
python_path += " -m coverage run --branch -p"
env.update(envs)
tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
python_path, model, self._ps_endpoints, self._tr_endpoints,
self._trainers, self._mode, self._geo_sgd_need_push_nums,
self._reader, gloo_path, self._heter_endpoints, self._heter_device)
ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
python_path, model, self._ps_endpoints, self._tr_endpoints,
self._trainers, self._mode, self._geo_sgd_need_push_nums,
self._reader, gloo_path, self._heter_endpoints, self._heter_device)
heter_cmd = "{0} {1} --role heter_trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format(
python_path, model, self._ps_endpoints, self._tr_endpoints,
self._trainers, self._mode, self._geo_sgd_need_push_nums,
self._reader, gloo_path, self._heter_endpoints, self._heter_device)
# Run dist train to compare with local results
ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env)
tr0, tr1, tr0_pipe, tr1_pipe = self._start_trainer(tr_cmd, env)
heter0, heter1, heter0_pipe, heter1_pipe = self._start_heter_trainer(
heter_cmd, env)
# Wait until trainer process terminate
while True:
stat0 = tr0.poll()
time.sleep(0.1)
if stat0 is not None:
break
while True:
stat1 = tr1.poll()
time.sleep(0.1)
if stat1 is not None:
break
tr0_out, tr0_err = tr0.communicate()
tr1_out, tr1_err = tr1.communicate()
print("tr end communicate")
tr0_ret = tr0.returncode
tr1_ret = tr1.returncode
print("tr get returncode: {}".format(tr0_ret))
if tr0_ret != 0:
print(
"========================Error tr0_err begin==========================="
)
os.system("cat {}".format(tempfile.gettempdir() + "/tr0_err.log"))
print(
"========================Error tr0_err end==========================="
)
if tr1_ret != 0:
print(
"========================Error tr1_err begin==========================="
)
os.system("cat {}".format(tempfile.gettempdir() + "/tr1_err.log"))
print(
"========================Error tr1_err end==========================="
)
self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")
# close trainer file
tr0_pipe.close()
tr1_pipe.close()
ps0_pipe.close()
ps1_pipe.close()
heter0_pipe.close()
heter1_pipe.close()
ps0.terminate()
ps1.terminate()
heter0.terminate()
heter1.terminate()
shutil.rmtree(gloo_path)
return 0, 0
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def runtime_main(test_class):
parser = argparse.ArgumentParser(description='Run Fleet test.')
parser.add_argument(
'--role',
type=str,
required=True,
choices=['pserver', 'trainer', 'heter_trainer'])
parser.add_argument('--endpoints', type=str, required=False, default="")
parser.add_argument(
'--trainer_endpoints', type=str, required=False, default="")
parser.add_argument(
'--heter_trainer_endpoints', type=str, required=False, default="")
parser.add_argument(
'--heter_trainer_device', type=str, required=False, default="gpu")
parser.add_argument('--gloo_path', type=str, required=False, default="")
parser.add_argument('--current_id', type=int, required=False, default=0)
parser.add_argument('--trainers', type=int, required=False, default=1)
parser.add_argument('--mode', type=str, required=False, default='async')
parser.add_argument(
'--geo_sgd_need_push_nums', type=int, required=False, default=2)
parser.add_argument('--reader', type=str, required=False, default='dataset')
args = parser.parse_args()
model = test_class()
role = model.build_role(args)
fleet.init(role)
strategy = model.build_strategy(args)
avg_cost = model.net(args)
model.build_optimizer(avg_cost, strategy)
fleet_util._set_strategy(strategy)
fleet_util._set_role_maker(role)
if args.role == "pserver" or args.role == "heter_trainer":
model.run_pserver(args)
else:
if args.reader == "dataset":
model.run_dataset_trainer(args)
else:
model.run_pyreader_trainer(args)
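# Illustrative launch sketch (an assumption, not part of this commit): the flags
# parsed above could also be supplied by hand to start a single role, e.g.
#
#     python dist_fleet_heter_ctr.py --role pserver \
#         --endpoints 127.0.0.1:36012,127.0.0.1:36013 \
#         --trainer_endpoints 127.0.0.1:36014,127.0.0.1:36015 \
#         --heter_trainer_endpoints 127.0.0.1:36016,127.0.0.1:36017 \
#         --current_id 0 --trainers 2 --mode async --reader dataset
#
# In the unit tests these commands are assembled and spawned by
# TestFleetHeterBase._run_cluster.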
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import tempfile
from test_dist_fleet_heter_base import TestFleetHeterBase
class TestDistHeterDatasetAsync2x2(TestFleetHeterBase):
def _setup_config(self):
self._mode = "async"
self._reader = "dataset"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": "",
"CPU_NUM": "1"
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "4"
required_envs["GLOG_logtostderr"] = "1"
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
def test_dist_train(self):
self.check_with_place(
"dist_fleet_heter_ctr.py", delta=1e-5, check_error_log=True)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import os
import math
import paddle.fluid as fluid
import paddle.distributed.fleet.base.role_maker as role_maker
from paddle.distributed.fleet.base.util_factory import fleet_util
from paddle.distributed.fleet import fleet
class TestDistFleetHeterProgram(unittest.TestCase):
def build_role(self):
environs = {}
environs[
"PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36012,127.0.0.1:36013"
environs["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36014,127.0.0.1:36015"
environs[
"PADDLE_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36016,127.0.0.1:36017"
environs["PADDLE_HETER_TRAINER_DEVICE"] = "gpu"
environs["TRAINING_ROLE"] = "HETER_TRAINER"
environs["PADDLE_TRAINERS_NUM"] = 2
environs["PADDLE_TRAINER_ID"] = 0
environs["POD_IP"] = "127.0.0.1"
environs["PADDLE_PORT"] = "36016"
environs["FLAGS_selected_gpus"] = 0
for k, v in environs.items():
os.environ[k] = str(v)
self.role = role_maker.PaddleCloudRoleMaker()
return self.role
def build_strategy(self):
self.strategy = paddle.distributed.fleet.DistributedStrategy()
self.strategy.a_sync = True
return self.strategy
def build_input(self):
dense_input = fluid.layers.data(
name="dense_input", shape=[10], dtype="float32")
sparse_input_ids = [
fluid.layers.data(
name="C" + str(i), shape=[1], lod_level=1, dtype="int64")
for i in range(1, 27)
]
label = fluid.layers.data(name="label", shape=[1], dtype="float32")
inputs = [dense_input] + sparse_input_ids + [label]
return inputs
def build_net(self, inputs):
def embedding_layer(input):
return fluid.layers.embedding(
input=input,
is_sparse=True,
size=[100001, 10],
param_attr=fluid.ParamAttr(
name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()), )
sparse_embed_seq = list(map(embedding_layer, inputs[1:-1]))
concated = fluid.layers.concat(sparse_embed_seq + inputs[0:1], axis=1)
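# The alternating device_guard("gpu") / device_guard("cpu") scopes below mark
# which fc layers are intended for the heter (gpu) worker and which stay on the
# cpu trainer when the program is split for heter parameter-server training
# (see PADDLE_HETER_TRAINER_DEVICE in build_role above).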
with fluid.device_guard("gpu"):
fc1 = fluid.layers.fc(
input=concated,
size=400,
act="relu",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(concated.shape[1]))),
name="fc1")
with fluid.device_guard("cpu"):
fc2 = fluid.layers.fc(input=fc1,
size=400,
act="relu",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc1.shape[1]))),
name="fc2")
with fluid.device_guard("gpu"):
fc3 = fluid.layers.fc(input=fc2,
size=400,
act="relu",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc2.shape[1]))),
name="fc3")
with fluid.device_guard("cpu"):
predict = fluid.layers.fc(
input=fc3,
size=2,
act="softmax",
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc3.shape[1]))), )
with fluid.device_guard("gpu"):
labels = fluid.layers.cast(inputs[-1], dtype="int64")
cost = fluid.layers.cross_entropy(input=predict, label=labels)
avg_cost = fluid.layers.reduce_sum(cost)
return avg_cost
def build_optimizer(self, avg_cost, strategy):
optimizer = fluid.optimizer.SGD(1e-2)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost)
def test(self):
role = self.build_role()
fleet.init(role)
strategy = self.build_strategy()
inputs = self.build_input()
avg_cost = self.build_net(inputs)
self.build_optimizer(avg_cost, strategy)
if __name__ == "__main__":
unittest.main()