diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc index a356b77e73733ed9b657a7603adf57c5228bf3c5..92dcde99cccb0b484119f6326d97a2057f109c9f 100644 --- a/paddle/fluid/distributed/service/brpc_utils.cc +++ b/paddle/fluid/distributed/service/brpc_utils.cc @@ -138,23 +138,11 @@ void SerializeSelectedRows(framework::Variable* var, var_data->clear(); var_data->resize(rows->size() * sizeof(int64_t)); char* data_ptr = const_cast(var_data->data()); - - if (platform::is_cpu_place(tensor->place())) { - memcpy(data_ptr, &(*rows)[0], rows->size() * sizeof(int64_t)); - } else { -#ifdef PADDLE_WITH_CUDA - auto stream = - reinterpret_cast(ctx).stream(); - memory::Copy(platform::CPUPlace(), data_ptr, - BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), - &(*rows)[0], rows->size() * sizeof(int64_t), stream); -#endif - } + memcpy(data_ptr, &((*rows)[0]), rows->size() * sizeof(int64_t)); var_msg->set_data_type(static_cast(tensor->type())); for (auto& dim : framework::vectorize(tensor->dims())) { var_msg->add_dims(dim); } - // IO Buffer if (platform::is_cpu_place(tensor->place())) { auto data_len = tensor->numel() * framework::SizeOfType(tensor->type()); @@ -273,8 +261,8 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, auto* slr = var->GetMutable(); framework::Tensor* tensor = slr->mutable_value(); slr->set_height(msg.slr_height()); - std::vector tmp_rows(msg.slr_height()); - memcpy(&tmp_rows[0], msg.data().data(), msg.slr_height() * sizeof(int64_t)); + std::vector tmp_rows(msg.dims()[0]); + memcpy(tmp_rows.data(), msg.data().data(), msg.dims()[0] * sizeof(int64_t)); slr->set_rows(tmp_rows); std::vector vec_dim; for (auto& x : msg.dims()) { diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc index 30529d73fa1995c4141eb24e331d9cc10609af0b..a016c478846cbb8ca7ce4e521e027c17e00671f6 100644 --- a/paddle/fluid/distributed/service/communicator.cc +++ b/paddle/fluid/distributed/service/communicator.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/distributed/service/communicator.h" - #include #include "gflags/gflags.h" @@ -361,6 +360,8 @@ void Communicator::InitParams(const RecvCtxMap &recv_varname_to_ctx) { << " from 0' trainer done"; } } + std::this_thread::sleep_for( + std::chrono::milliseconds(100 + trainer_id_ * 10)); BarrierWithTable(1); return; } @@ -518,7 +519,6 @@ void AsyncCommunicator::SendByCommunicator() { MergeVars(var_name, vars[i], send_scope_.get(), 1); } } - if (ctx.is_tensor_table) { SendGlobalStep(ctx, merged_var_num, send_scope_.get()); } else if (ctx.is_sparse) { diff --git a/paddle/fluid/distributed/service/heter_client.cc b/paddle/fluid/distributed/service/heter_client.cc index 10fc8368a26a92671c0df7cf3fd236db455db65c..d9ec6b21fd37717d31906c6d6e5a668a7d295d31 100644 --- a/paddle/fluid/distributed/service/heter_client.cc +++ b/paddle/fluid/distributed/service/heter_client.cc @@ -25,6 +25,36 @@ namespace distributed { std::shared_ptr HeterClient::s_instance_ = NULL; bool HeterClient::is_initialized_ = false; +int GetMicroId(const platform::DeviceContext& ctx, + const framework::Scope* scope) { + framework::Variable* var = scope->FindVar("microbatch_id"); + PADDLE_ENFORCE_EQ(var->IsType(), true, + platform::errors::InvalidArgument( + "the type of micro id shoulde be LoDTensor.")); + auto micro_id = -1; + auto* tensor = var->GetMutable(); + if (platform::is_cpu_place(tensor->place())) { + auto data = reinterpret_cast(tensor->data()); + micro_id = static_cast(data[0]); + } else { +#ifdef PADDLE_WITH_CUDA + std::vector temp; + temp.resize(tensor->numel() * framework::SizeOfType(tensor->type())); + char* temp_ptr = temp.data(); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(platform::CPUPlace(), temp_ptr, + BOOST_GET_CONST(platform::CUDAPlace, tensor->place()), + tensor->data(), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); + float* temp_ptr_float = reinterpret_cast(temp_ptr); + micro_id = static_cast(temp_ptr_float[0]); +#endif + } + return micro_id; +} + void HeterClient::MainThread() { while (running_) { RpcProfilerControl(); @@ -99,43 +129,68 @@ void HeterClient::CreateClient2XpuConnection() { } } } + previous_xpu_channels_.resize(previous_xpu_list_.size()); + for (size_t i = 0; i < previous_xpu_list_.size(); ++i) { + previous_xpu_channels_[i].reset(new brpc::Channel()); + if (previous_xpu_channels_[i]->Init(previous_xpu_list_[i].c_str(), "", + &options) != 0) { + VLOG(0) << "HeterClient channel init fail. Try Again"; + auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':'); + std::string ip = ip_port[0]; + int port = std::stoi(ip_port[1]); + std::string int_ip_port = GetIntTypeEndpoint(ip, port); + if (previous_xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) != + 0) { + LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port; + } + } + } } void HeterClient::SendAndRecvAsync( - const std::vector& ep, const platform::DeviceContext& ctx, - const framework::Scope& scope, const std::string& message_name, + const platform::DeviceContext& ctx, const framework::Scope& scope, + const std::string& message_name, const std::vector& send_var_name, - const std::vector& recv_var_name) { + const std::vector& recv_var_name, const std::string& mode) { platform::RecordEvent record_event("HeterClient->SendAndRecvAsync"); const platform::DeviceContext* p_ctx = &ctx; const framework::Scope* p_scope = &scope; const std::string message_name_val = message_name; const std::vector send_var_name_val = send_var_name; const std::vector recv_var_name_val = recv_var_name; - - VLOG(3) << "GRPCClient::SendAndRecv Begin, message_name: " + VLOG(3) << "BRPCClient::SendAndRecv Begin, message_name: " << message_name_val; - // Todo: get correct channel - int num = trainer_id_ % xpu_channels_.size(); - - brpc::Controller cntl; - cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); - distributed::MultiVarMsg request, response; - auto& request_io_buffer = cntl.request_attachment(); - ::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get()); + brpc::Channel* channel = nullptr; + distributed::MultiVarMsg request; + OnHeterRpcDone* closure = new OnHeterRpcDone([p_ctx, p_scope](void* done) { + auto* closure = reinterpret_cast(done); + PADDLE_ENFORCE_NE( + closure->cntl.Failed(), true, + platform::errors::Unimplemented( + "HeterClient::SendAndRecv meets brpc error, error message is %s", + closure->cntl.ErrorText())); + + VLOG(4) << "call heter_worker success"; + }); + closure->cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); + auto& request_io_buffer = closure->cntl.request_attachment(); distributed::SerializeToMultiVarMsgAndIOBuf( message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, &request, &request_io_buffer); - stub.SendAndRecvVariable(&cntl, &request, &response, NULL); - PADDLE_ENFORCE_NE( - cntl.Failed(), true, - platform::errors::Unimplemented( - "HeterClient::SendAndRecv meets brpc error, error message is %s", - cntl.ErrorText())); - VLOG(4) << "call heter_worker success"; - auto& response_io_buffer = cntl.response_attachment(); - distributed::DeserializeFromMultiVarMsgAndIOBuf(response, &response_io_buffer, - ctx, p_scope); + + int micro_id = GetMicroId(ctx, p_scope); + auto minibatch_id = micro_id / 10; + // select channel according to micro id + if (mode == "forward") { + int num = minibatch_id % xpu_channels_.size(); + channel = xpu_channels_[num].get(); + } else if (mode == "backward") { + int num = minibatch_id % previous_xpu_channels_.size(); + channel = previous_xpu_channels_[num].get(); + } + ::paddle::distributed::PsService_Stub stub(channel); + stub.SendAndRecvVariable(&closure->cntl, &request, &closure->response, + closure); } std::future HeterClient::SendCmd( diff --git a/paddle/fluid/distributed/service/heter_client.h b/paddle/fluid/distributed/service/heter_client.h index 31227386c5c980abe1100bd2614ca76d5df6961a..5fa49bc2411c9d34a647dcc598bbdc7338a4930c 100644 --- a/paddle/fluid/distributed/service/heter_client.h +++ b/paddle/fluid/distributed/service/heter_client.h @@ -76,20 +76,23 @@ class HeterClient { void CreateClient2XpuConnection(); - void SendAndRecvAsync(const std::vector& ep, - const platform::DeviceContext& ctx, + void SendAndRecvAsync(const platform::DeviceContext& ctx, const framework::Scope& scope, const std::string& message_name, const std::vector& send_var_name, - const std::vector& recv_var_name); + const std::vector& recv_var_name, + const std::string& mode = "forward"); // HeterClient singleton static std::shared_ptr GetInstance( - const std::vector& endpoint, const int& trainer_id) { + const std::vector& endpoint, + const std::vector& previous_endpoint, + const int& trainer_id) { if (NULL == s_instance_) { is_initialized_ = true; s_instance_.reset(new paddle::distributed::HeterClient()); s_instance_->SetXpuList(endpoint); + s_instance_->SetPreviousXpuList(previous_endpoint); s_instance_->SetTrainerID(trainer_id); s_instance_->CreateClient2XpuConnection(); } @@ -118,6 +121,10 @@ class HeterClient { xpu_list_ = xpu_list; } + void SetPreviousXpuList(const std::vector& xpu_list) { + previous_xpu_list_ = xpu_list; + } + void SetTrainerID(const int& trainer_id) { trainer_id_ = trainer_id; } private: @@ -125,9 +132,11 @@ class HeterClient { static bool is_initialized_; std::unique_ptr main_thread_{nullptr}; std::vector> xpu_channels_; + std::vector> previous_xpu_channels_; DISABLE_COPY_AND_ASSIGN(HeterClient); std::vector xpu_list_; + std::vector previous_xpu_list_; bool running_ = false; int trainer_id_; diff --git a/paddle/fluid/distributed/service/heter_server.cc b/paddle/fluid/distributed/service/heter_server.cc index 57a1a16a723830450fcc7c9ae094e3a35e0ff3ee..035668f2bc7af067a210fa2d8d22af1e6dea4ffb 100644 --- a/paddle/fluid/distributed/service/heter_server.cc +++ b/paddle/fluid/distributed/service/heter_server.cc @@ -46,20 +46,20 @@ void HeterServer::StartHeterService() { ready_ = 1; } condition_ready_.notify_all(); - std::unique_lock running_lock(mutex_); + stoped_ = false; cv_.wait(running_lock, [&] { VLOG(1) << "Heter Server is Stop? " << stoped_; return stoped_; }); } -void HeterServer::SetEndPoint(std::string& endpoint) { +void HeterServer::SetEndPoint(const std::string& endpoint) { endpoint_ = endpoint; service_.SetEndpoint(endpoint); } -void HeterServer::SetFanin(int& fan_in) { service_.SetFanin(fan_in); } +void HeterServer::SetFanin(const int& fan_in) { service_.SetFanin(fan_in); } void HeterServer::WaitServerReady() { std::unique_lock lock(this->mutex_ready_); diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h index 93fa37454a5743918160732768fbb88e245f9343..76717a4674add8acef7a97f1dbef57b8adfd37a6 100644 --- a/paddle/fluid/distributed/service/heter_server.h +++ b/paddle/fluid/distributed/service/heter_server.h @@ -27,6 +27,7 @@ limitations under the License. */ #include "brpc/server.h" #include "paddle/fluid/distributed/service/brpc_utils.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h" +#include "paddle/fluid/framework/blocking_queue.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" @@ -45,6 +46,7 @@ namespace paddle { namespace framework { class Executor; class ProgramDesc; +class Scope; } // namespace framework namespace platform { class DeviceContext; @@ -61,7 +63,7 @@ using VarMsg = ::paddle::distributed::VariableMessage; class HeterService; typedef int32_t (HeterService::*serviceHandlerFunc)( - const PsRequestMessage& request, PsResponseMessage& response, + const PsRequestMessage& request, PsResponseMessage& response, // NOLINT brpc::Controller* cntl); typedef std::function HeterRpcCallbackFunc; @@ -124,19 +126,27 @@ class HeterService : public ::paddle::distributed::PsService { handler_map_[message_name] = func; } + int32_t ForceExit() { + VLOG(3) << "heter service force exit"; + is_exit_ = true; + return 0; + } + void SetEndpoint(const std::string& end_point) { endpoint_ = end_point; } void SetFanin(const int& fan_in) { fan_in_ = fan_in; } bool IsExit() { return is_exit_; } private: int32_t stop_profiler(const PsRequestMessage& request, - PsResponseMessage& response, brpc::Controller* cntl); + PsResponseMessage& response, // NOLINT + brpc::Controller* cntl); int32_t start_profiler(const PsRequestMessage& request, - PsResponseMessage& response, brpc::Controller* cntl); + PsResponseMessage& response, // NOLINT + brpc::Controller* cntl); int32_t stop_heter_worker(const PsRequestMessage& request, - PsResponseMessage& response, + PsResponseMessage& response, // NOLINT brpc::Controller* cntl); private: @@ -148,19 +158,182 @@ class HeterService : public ::paddle::distributed::PsService { bool is_exit_ = false; }; +using SharedMiniScope = + std::shared_ptr>; +using SharedMicroScope = std::shared_ptr>>>; +using SharedTaskQueue = std::shared_ptr< + std::unordered_map>>>>; + +class HeterRequestHandler { + public: + HeterRequestHandler() : dev_ctx_(nullptr), scope_(nullptr) {} + + virtual ~HeterRequestHandler() {} + + void SetScope(const framework::Scope* scope) { scope_ = scope; } + void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } + + virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) = 0; + + protected: + const platform::DeviceContext* dev_ctx_; + const framework::Scope* scope_; +}; + +class RequestSendAndRecvHandler final : public HeterRequestHandler { + public: + RequestSendAndRecvHandler() { + this->num_microbatch_ = 0; + this->num_minibatch_ = 0; + } + + virtual ~RequestSendAndRecvHandler() {} + + void SetMiniScopes(SharedMiniScope mini_scopes) { + mini_scopes_ = mini_scopes; + num_minibatch_ = mini_scopes_->size(); + } + + void SetMicroScopes(SharedMicroScope micro_scopes) { + micro_scopes_ = micro_scopes; + for (auto& scope_pair : (*micro_scopes_)) { + // auto mini_idx = scope_pair.first; + auto& micro_scopes = scope_pair.second; + num_microbatch_ = micro_scopes->size(); + break; + } + } + + int GetThreadNum() { + std::unique_lock lk(scope_mutex_); + return (*task_queue_).size(); + } + + void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; } + + int Handle(const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) override { + platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle"); + FLAGS_eager_delete_tensor_gb = -1; + + // get microID from request + // deserialize variable to micro scope + // Push to heter worker's task_queue + std::unique_ptr local_scope_ptr( + new paddle::framework::Scope()); + auto& local_scope = *(local_scope_ptr.get()); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::CPUPlace cpu_place; + auto& cpu_dev_ctx = *pool.Get(cpu_place); + + auto message_name = request->message_name(); + auto& request_io_buffer = cntl->request_attachment(); + + distributed::DeserializeFromMultiVarMsgAndIOBuf( + *request, &request_io_buffer, cpu_dev_ctx, &local_scope); + + auto* var = local_scope.FindVar("microbatch_id"); + PADDLE_ENFORCE_NE(var, nullptr, + platform::errors::InvalidArgument( + "Not find variable microbatch_id in scope.")); + auto* tensor = var->GetMutable(); + auto data = reinterpret_cast(tensor->data()); + auto micro_id = static_cast(data[0]); + + int minibatch_index = micro_id / 10; + int microbatch_index = micro_id % 10; + + // check minibatch_index is in mini_scopes_ + std::unique_lock lk(scope_mutex_); + if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) { + lk.unlock(); + // PADDLE_ENFORCE_EQ( + // (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1, + // platform::errors::InvalidArgument( + // "minibatch index should in current trainer")); + PADDLE_ENFORCE_EQ( + (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1, + platform::errors::InvalidArgument( + "minibatch index should in current trainer")); + + } else { + // create mini scope & micro scopes + auto* minibatch_scope = &(scope_->NewScope()); + (*mini_scopes_)[minibatch_index] = minibatch_scope; + (*micro_scopes_)[minibatch_index].reset( + new std::vector{}); + for (int i = 0; i < num_microbatch_; i++) { + auto* micro_scope = &(minibatch_scope->NewScope()); + (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope); + } + (*task_queue_)[minibatch_index].reset( + new ::paddle::framework::BlockingQueue< + std::pair>()); + lk.unlock(); + } + + auto* micro_scope = + (*((*micro_scopes_)[minibatch_index]))[microbatch_index]; + + distributed::DeserializeFromMultiVarMsgAndIOBuf( + *request, &request_io_buffer, *dev_ctx_, micro_scope); + // blocking queue handles multi thread + (*task_queue_)[minibatch_index]->Push( + std::make_pair(message_name, microbatch_index)); + auto response_var_nums = request->recv_var_names_size(); + std::vector response_var_names(response_var_nums), + empty_var_names{}; + for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) { + response_var_names[var_idx] = request->recv_var_names(var_idx); + } + auto& response_io_buffer = cntl->response_attachment(); + distributed::SerializeToMultiVarMsgAndIOBuf( + message_name, response_var_names, empty_var_names, *dev_ctx_, + &local_scope, response, &response_io_buffer); + return 0; + } + + private: + // share with HeterPipelineTrainer + SharedMiniScope mini_scopes_{nullptr}; + SharedMicroScope micro_scopes_{nullptr}; + + int num_microbatch_; + int num_minibatch_; + std::mutex scope_mutex_; + + bool is_first_stage_ = false; + bool is_last_stage_ = false; + + SharedTaskQueue task_queue_; +}; + class HeterServer { public: virtual ~HeterServer() {} void Stop() { - VLOG(3) << "HeterServer Stop()"; std::unique_lock lock(mutex_); + if (stoped_ == true) return; + if (!IsExit()) service_.ForceExit(); + VLOG(3) << "HeterServer Stop()"; stoped_ = true; cv_.notify_all(); server_.Stop(1000); server_.Join(); } + bool IsStop() { + std::unique_lock lock(mutex_); + if (stoped_ == true) + return true; + else + return false; + } + bool IsExit() { return service_.IsExit(); } HeterServer() {} @@ -170,8 +343,27 @@ class HeterServer { void StartHeterService(); - void SetEndPoint(std::string& endpoint); - void SetFanin(int& fan_in); + void SetEndPoint(const std::string& endpoint); + void SetFanin(const int& fan_in); + + void SetRequestHandler( + std::shared_ptr request_handler) { + request_handler_ = request_handler; + } + + void SetMiniBatchScopes(SharedMiniScope mini_scopes) { + request_handler_->SetMiniScopes(mini_scopes); + } + + void SetMicroBatchScopes(SharedMicroScope micro_scopes) { + request_handler_->SetMicroScopes(micro_scopes); + } + + int GetThreadNum() { return request_handler_->GetThreadNum(); } + + void SetTaskQueue(SharedTaskQueue task_queue) { + request_handler_->SetTaskQueue(task_queue); + } // HeterWrapper singleton static std::shared_ptr GetInstance() { @@ -188,84 +380,19 @@ class HeterServer { mutable std::mutex mutex_; std::condition_variable cv_; std::condition_variable condition_ready_; - bool stoped_ = false; + bool stoped_ = true; std::string endpoint_; protected: brpc::Server server_; HeterService service_; + std::shared_ptr request_handler_; + DISABLE_COPY_AND_ASSIGN(HeterServer); std::mutex mutex_ready_; int ready_; }; -class HeterRequestHandler { - public: - HeterRequestHandler() - : dev_ctx_(nullptr), - executor_(nullptr), - scope_(nullptr), - program_(nullptr) {} - - virtual ~HeterRequestHandler() {} - - void SetScope(framework::Scope* scope) { scope_ = scope; } - void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; } - void SetProgram(framework::ProgramDesc* program) { program_ = program; } - void SetExecutor(framework::Executor* executor) { executor_ = executor; } - - void SetGradToPreparedCtx( - std::unordered_map< - std::string, std::shared_ptr>* g) { - message_to_prepared_ctx_ = g; - } - - virtual int Handle(const MultiVarMsg* request, MultiVarMsg* response, - brpc::Controller* cntl) = 0; - - protected: - const platform::DeviceContext* dev_ctx_; - framework::Executor* executor_; - framework::Scope* scope_; - framework::ProgramDesc* program_; - - std::unordered_map>* - message_to_prepared_ctx_; -}; - -class RequestSendAndRecvHandler final : public HeterRequestHandler { - public: - RequestSendAndRecvHandler() {} - virtual ~RequestSendAndRecvHandler() {} - int Handle(const MultiVarMsg* request, MultiVarMsg* response, - brpc::Controller* cntl) override { - platform::RecordEvent record_event("RequestSendAndRecvHandler->Handle"); - FLAGS_eager_delete_tensor_gb = -1; - auto& local_scope = scope_->NewScope(); - auto message_name = request->message_name(); - auto& request_io_buffer = cntl->request_attachment(); - distributed::DeserializeFromMultiVarMsgAndIOBuf( - *request, &request_io_buffer, *dev_ctx_, &local_scope); - executor_->RunPreparedContext( - (*message_to_prepared_ctx_)[message_name].get(), &local_scope, false); - - auto response_var_nums = request->recv_var_names_size(); - std::vector response_var_names(response_var_nums), - empty_var_names{}; - - for (int var_idx = 0; var_idx < response_var_nums; ++var_idx) { - response_var_names[var_idx] = request->recv_var_names(var_idx); - } - auto& response_io_buffer = cntl->response_attachment(); - distributed::SerializeToMultiVarMsgAndIOBuf( - message_name, response_var_names, empty_var_names, *dev_ctx_, - &local_scope, response, &response_io_buffer); - scope_->DeleteScope(&local_scope); - return 0; - } -}; - } // end namespace distributed } // end namespace paddle diff --git a/paddle/fluid/distributed/service/service.cc b/paddle/fluid/distributed/service/service.cc index 2759e4614e66e1d69c6427e0320ae44292757ffd..29941e36ea0513fb9458a1ced7e2669e96f839a1 100644 --- a/paddle/fluid/distributed/service/service.cc +++ b/paddle/fluid/distributed/service/service.cc @@ -51,7 +51,7 @@ void PSCore::init_gflag(const std::string& gflags) { std::vector flags = paddle::string::split_string(gflags); if (flags.size() < 1) { flags.push_back("-max_body_size=314217728"); - flags.push_back("-bthread_concurrency=40"); + flags.push_back("-bthread_concurrency=200"); flags.push_back("-socket_max_unwritten_bytes=2048000000"); flags.push_back("-max_connection_pool_size=1950"); } diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index fc562a26c199f76ea8eb099f2c16fccd4b48df46..3d1b4fdb485acf6179e50df7c1b94aec693bd7dd 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -294,17 +294,24 @@ if(WITH_DISTRIBUTE) elseif(WITH_PSCORE) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc - heterxpu_trainer.cc + heterxpu_trainer.cc heter_pipeline_trainer.cc data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc - pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry + pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method - graph_to_program_pass variable_helper timer monitor heter_service_proto fleet) + graph_to_program_pass variable_helper timer monitor heter_service_proto fleet heter_server brpc) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) + set(DISTRIBUTE_COMPILE_FLAGS + "${DISTRIBUTE_COMPILE_FLAGS} -faligned-new") + endif() set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(device_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(multi_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(hogwild_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(heter_section_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties(heter_pipeline_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc @@ -359,6 +366,8 @@ if(WITH_PSCORE) get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS) cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS conditional_block_op executor gloo_wrapper ${RPC_DEPS}) + cc_test(heter_pipeline_trainer_test SRCS heter_pipeline_trainer_test.cc DEPS + conditional_block_op scale_op heter_listen_and_serv_op executor heter_server gloo_wrapper eigen_function ${RPC_DEPS}) else() cc_test(dist_multi_trainer_test SRCS dist_multi_trainer_test.cc DEPS conditional_block_op executor gloo_wrapper) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 810e9a087d1220483347ce08a5130fb0987478b2..cbabf721634e55d28459579d77f9d3243e3cbb55 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -614,5 +614,95 @@ class SectionWorker : public DeviceWorker { }; #endif +#if defined(PADDLE_WITH_PSCORE) +class HeterSectionWorker : public DeviceWorker { + public: + HeterSectionWorker() {} + ~HeterSectionWorker() override {} + + void Initialize(const TrainerDesc& desc) override; + void CreateDeviceResource(const ProgramDesc& main_prog) override{}; + + void TrainFiles() override; + void TrainFilesWithProfiler() override; + + void BindingDataFeedMemory() override {} + void BindingDataFeedMemory(int micro_id); + void PrintFetchVars() override; + const platform::Place& place() const { return place_; } + + void SetDeviceIndex(int tid) override { thread_id_ = tid; } + void SetThreadNum(int thread_num) { thread_num_ = thread_num; } + void SetMicrobatchNum(int num) { num_microbatches_ = num; } + void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; } + void SetPipelineStage(int stage) { pipeline_stage_ = stage; } + std::shared_ptr> GetMicrobatchScopes() { + return microbatch_scopes_; + } + void SetMicrobatchScopes( + std::shared_ptr> microbatch_scopes) { + microbatch_scopes_ = microbatch_scopes; + } + using SHARED_THREAD_QUEUE = std::shared_ptr< + ::paddle::framework::BlockingQueue>>; + + SHARED_THREAD_QUEUE GetThreadQueue() { return thread_queue_; } + void SetThreadQueue(SHARED_THREAD_QUEUE thread_queue) { + thread_queue_ = thread_queue; + } + void CopyParameters(int microbatch_id, const ProgramDesc& program, + const platform::Place& place); + void SetMinibatchScope(Scope* scope) { minibatch_scope_ = scope; } + void SetTrainerId(int trainer_id) { this->trainer_id_ = trainer_id; } + void SetTrainers(int trainers) { this->trainers_ = trainers; } + void CreateMicrobatchScopes(); + void RunForward(int micro_id); + void RunBackward(int micro_id); + void RunListen(); + void MiniBatchBarrier(); + void Run(); + void BatchPostProcess(); + void SetDebug(bool debug) { debug_ = debug; } + Scope* GetThreadScope() override { return minibatch_scope_; } + + // multi-stream + // #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + // void SetStream(const gpuStream_t stream) override {} + // void SetEvent(const gpuEvent_t event) override {} + // #endif + + protected: + int trainer_id_; + int trainers_; + int thread_num_; + int thread_id_; + int num_microbatches_; + int num_pipeline_stages_; + int pipeline_stage_; + bool epoch_finish_; + + std::shared_ptr> microbatch_scopes_; + Scope* minibatch_scope_; + std::vector micro_ids_{}; + std::unique_ptr listen_op_{nullptr}; + std::vector> forward_ops_; + std::vector> backward_ops_; + std::shared_ptr program_; + std::shared_ptr< + ::paddle::framework::BlockingQueue>> + thread_queue_; + static uint64_t batch_id_; + uint64_t total_ins_num_ = 0; + platform::DeviceContext* dev_ctx_ = nullptr; + + bool debug_ = false; + std::vector op_total_time_; + std::vector op_name_; + platform::Timer timeline_; + double total_time_ = 0.0; + double read_time_ = 0.0; +}; +#endif + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc index b6f87811bbdb813fadd5ac8a20bd7bf55415d01f..8259d43cb9a4715508467d1e67edb08e6c836995 100644 --- a/paddle/fluid/framework/device_worker_factory.cc +++ b/paddle/fluid/framework/device_worker_factory.cc @@ -65,6 +65,11 @@ std::shared_ptr DeviceWorkerFactory::CreateDeviceWorker( REGISTER_DEVICE_WORKER_CLASS(HogwildWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorker); REGISTER_DEVICE_WORKER_CLASS(DownpourWorkerOpt); + +#if defined(PADDLE_WITH_PSCORE) +REGISTER_DEVICE_WORKER_CLASS(HeterSectionWorker); +#endif + #ifdef PADDLE_WITH_PSLIB REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker); #endif diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 5f681ec7ea241f59b60ae665680c28984c4eadbf..93f4f8952fc675022ce6e142ac10b194958bd238 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -129,7 +129,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, std::shared_ptr Executor::InitForDataset( const ProgramDesc& main_program, const std::string& trainer_desc_str, Scope* scope, Dataset* dataset) { - VLOG(3) << "Start to RunFromDataset in executor"; + VLOG(3) << "Start to InitForDataset in executor"; TrainerDesc trainer_desc; bool success = trainer_desc.ParseFromString(trainer_desc_str); PADDLE_ENFORCE_EQ(success, true, diff --git a/paddle/fluid/framework/heter_pipeline_trainer.cc b/paddle/fluid/framework/heter_pipeline_trainer.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb939f38ff3d9678e09e5cae433317031a47d78f --- /dev/null +++ b/paddle/fluid/framework/heter_pipeline_trainer.cc @@ -0,0 +1,312 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if defined(PADDLE_WITH_PSCORE) +#include "paddle/fluid/distributed/service/heter_server.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/trainer.h" +#include "paddle/fluid/framework/trainer_desc.pb.h" + +namespace paddle { +namespace framework { + +class Variable; + +using MiniScope = std::unordered_map; +using MicroScope = + std::unordered_map>>; +using TaskQueue = + std::unordered_map>>>; + +void HeterPipelineTrainer::ResetDataset(Dataset* dataset) { + if (pipeline_stage_ == 0) { + SetDataset(dataset); + const std::vector readers = + dataset->GetReaders(); + VLOG(3) << "readers num: " << readers.size(); + // change thread num is not supported + PADDLE_ENFORCE_EQ(thread_num_, readers.size(), + platform::errors::InvalidArgument( + "change Dataset thread_num is not supported")); + int cnt = -1; + for (auto& worker_pair : workers_) { + cnt++; + auto device_worker = worker_pair.second; + auto this_worker = + std::dynamic_pointer_cast( + device_worker); + this_worker->SetDataFeed(readers[cnt]); + this_worker->SetReaderPlace(place_); + } + } +} + +void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc, + Dataset* dataset) { + thread_num_ = trainer_desc.thread_num(); + ParseDumpConfig(trainer_desc); + SetDebug(trainer_desc.debug()); + const std::vector readers = + dataset->GetReaders(); + VLOG(3) << "readers num: " << readers.size(); + // change thread num to readers num + thread_num_ = readers.size(); + VLOG(3) << "worker thread num: " << thread_num_; + const auto& heter_section_params = trainer_desc.heter_section_param(); + num_pipeline_stages_ = heter_section_params.num_pipeline_stages(); + pipeline_stage_ = heter_section_params.pipeline_stage(); + num_microbatches_ = heter_section_params.num_microbatches(); + VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; + trainer_desc_ = trainer_desc; + trainer_id_ = trainer_desc.trainer_id(); + for (int i = 0; i < num_pipeline_stages_; ++i) { + auto trainer_num = trainer_desc.trainers(i); + trainers_.push_back(trainer_num); + } + int cpu_trainer_num = trainers_[0]; + // int cur_stage_trainer_num = trainers_[pipeline_stage_]; + // int global_thread_num = cpu_trainer_num * thread_num_; + // int previous_trainers = 0; + // for (int i = 0; i < pipeline_stage_; i++) previous_trainers += + // trainers_[i]; + // int stage_trainer_id = + // trainer_id_ - previous_trainers; // trainer id in current stage + + if (pipeline_stage_ == 0) { // for cpu trainer + int cnt = -1; + int real_thread_id = trainer_id_; + for (int i = 0; i < thread_num_; i++) { + cnt++; + workers_[real_thread_id] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc.device_worker_name()); + auto this_worker = + std::dynamic_pointer_cast( + workers_[real_thread_id]); + this_worker->SetDebug(debug_); + this_worker->SetNeedDumpField(need_dump_field_); + this_worker->SetNeedDumpParam(need_dump_param_); + this_worker->SetDumpFieldVector(dump_fields_); + this_worker->SetDumpParamVector(dump_param_); + this_worker->InitRandomDumpConfig(trainer_desc); + this_worker->SetDeviceIndex(real_thread_id); + real_thread_id += cpu_trainer_num; + // if (pipeline_stage_ == 0) { + this_worker->SetDataFeed(readers[cnt]); + //} + this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); + } + } else { // for heter_trainer + // heter trainer with thread_id == -1 is not for + // real training + workers_[-1] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc.device_worker_name()); + auto this_worker = + std::dynamic_pointer_cast( + workers_[-1]); + this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); + this_worker->SetDeviceIndex(-1); + } +} + +void HeterPipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) { + if (need_dump_field_) { + InitDumpEnv(); + } +} + +std::string HeterPipelineTrainer::GetDumpPath(int tid) { + return string::format_string("%s/part-%05d", dump_fields_path_.c_str(), tid); +} + +void HeterPipelineTrainer::InitDumpEnv() { + queue_ = paddle::framework::MakeChannel(); + for (int i = 0; i < thread_num_; ++i) { + workers_[i]->SetChannelWriter(queue_.get()); + } + dump_thread_num_ = 1; + for (int i = 0; i < dump_thread_num_; i++) { + dump_thread_.push_back( + std::thread(std::bind(&TrainerBase::DumpWork, this, i))); + } +} + +void HeterPipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program, + const platform::Place& place) { + place_ = place; + PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument( + "root_scope_ can not be nullptr")); + // initialize mini_scopes & micro_scopes + mini_scopes_.reset(new MiniScope{}); + micro_scopes_.reset(new MicroScope{}); + task_queue_.reset(new TaskQueue{}); + for (auto& worker_pair : workers_) { + auto worker_index = worker_pair.first; + auto device_worker = worker_pair.second; + auto this_worker = + std::dynamic_pointer_cast( + device_worker); + this_worker->SetPlace(place); + this_worker->Initialize(trainer_desc_); + if (pipeline_stage_ == 0) { + this_worker->SetReaderPlace(place); + } + this_worker->SetRootScope(root_scope_); + // generate mini_batch scope for every worker + auto* minibatch_scope = &root_scope_->NewScope(); + (*mini_scopes_)[worker_index] = minibatch_scope; + this_worker->SetMinibatchScope(minibatch_scope); + // after set micro num & mini batch scope + this_worker->CreateMicrobatchScopes(); + (*micro_scopes_)[worker_index] = this_worker->GetMicrobatchScopes(); + (*task_queue_)[worker_index] = this_worker->GetThreadQueue(); + } +} + +void HeterPipelineTrainer::Run() { + VLOG(3) << "Going to run HeterPipelineTrainer::Run()"; + if (listen_ptr_ == nullptr) { + for (auto& worker_pair : workers_) { + auto& device_worker = worker_pair.second; + auto worker_0 = + std::dynamic_pointer_cast( + device_worker); + listen_ptr_.reset(new std::thread( + std::bind(&HeterSectionWorker::RunListen, worker_0.get()))); + break; + } + } + auto heter_server = paddle::distributed::HeterServer::GetInstance(); + heter_server->WaitServerReady(); + heter_server->SetMiniBatchScopes(mini_scopes_); + heter_server->SetMicroBatchScopes(micro_scopes_); + heter_server->SetTaskQueue(task_queue_); + // main training logic + if (pipeline_stage_ == 0) { // for cpu trainer + for (auto& worker_pair : workers_) { + auto device_worker = worker_pair.second; + if (!debug_) { + threads_.push_back( + std::thread(&DeviceWorker::TrainFiles, device_worker.get())); + } else { + threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler, + device_worker.get())); + } + } + } else { // for heter worker + // start thread_worker with thread_id = -1 + for (auto& worker_pair : workers_) { + auto device_worker = worker_pair.second; + if (!debug_) { + threads_.push_back( + std::thread(&DeviceWorker::TrainFiles, device_worker.get())); + } else { + threads_.push_back(std::thread(&DeviceWorker::TrainFilesWithProfiler, + device_worker.get())); + } + } + bool epoch_finish = false; + auto heter_server = paddle::distributed::HeterServer::GetInstance(); + while (!epoch_finish) { + if (heter_server->IsStop()) { + epoch_finish = true; + continue; + } + // create new thread_worker + // size_t thread_num = (*micro_scopes_).size(); + // size_t thread_num = (*task_queue_).size(); + size_t thread_num = heter_server->GetThreadNum(); + while (thread_num > threads_.size()) { + for (auto& worker_pair : (*micro_scopes_)) { + auto worker_index = worker_pair.first; + if (workers_.find(worker_index) != workers_.end()) continue; + workers_[worker_index] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc_.device_worker_name()); + auto this_worker = + std::dynamic_pointer_cast( + workers_[worker_index]); + this_worker->SetDebug(debug_); + this_worker->SetNeedDumpField(need_dump_field_); + this_worker->SetNeedDumpParam(need_dump_param_); + this_worker->SetDumpFieldVector(dump_fields_); + this_worker->SetDumpParamVector(dump_param_); + this_worker->InitRandomDumpConfig(trainer_desc_); + this_worker->SetDeviceIndex(worker_index); + this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); + this_worker->SetPlace(place_); + this_worker->Initialize(trainer_desc_); + this_worker->SetRootScope(root_scope_); + + // generate mini_batch scope for every worker + // auto* minibatch_scope = &root_scope_->NewScope(); + auto* minibatch_scope = (*mini_scopes_)[worker_index]; + // (*mini_scopes_)[worker_index] = minibatch_scope; + this_worker->SetMinibatchScope(minibatch_scope); + // after set micro num & mini batch scope + this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]); + this_worker->CreateMicrobatchScopes(); + // this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]); + this_worker->SetThreadQueue((*task_queue_)[worker_index]); + if (!debug_) { + threads_.push_back( + std::thread(&DeviceWorker::TrainFiles, this_worker.get())); + } else { + threads_.push_back(std::thread( + &DeviceWorker::TrainFilesWithProfiler, this_worker.get())); + } + } + } + } + } + for (auto& th : threads_) { + th.join(); + } + if (threads_.size() > 0) { + threads_.clear(); + } + VLOG(3) << "Epoch Trainging done"; +} + +void HeterPipelineTrainer::Finalize() { + VLOG(3) << "HeterPipelineTrainer Finalize"; + auto heter_server = paddle::distributed::HeterServer::GetInstance(); + heter_server->Stop(); + if (listen_ptr_) { + (listen_ptr_.get())->join(); + listen_ptr_.reset(nullptr); + } + if (need_dump_field_) { + FinalizeDumpEnv(); + } + root_scope_->DropKids(); +} + +Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) { + if (workers_.find(thread_id) != workers_.end()) { + return workers_[thread_id]->GetThreadScope(); + } else { + return nullptr; + } +} + +} // end namespace framework +} // end namespace paddle +#endif diff --git a/paddle/fluid/framework/heter_pipeline_trainer_test.cc b/paddle/fluid/framework/heter_pipeline_trainer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..af8eca32ee2f4e8914fa48fc4b129b9a08cca818 --- /dev/null +++ b/paddle/fluid/framework/heter_pipeline_trainer_test.cc @@ -0,0 +1,218 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSCORE) +#include "gtest/gtest.h" +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/device_worker_factory.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/framework/trainer.h" +#include "paddle/fluid/framework/trainer_factory.h" + +#if defined _WIN32 || defined __APPLE__ +#else +#define _LINUX +#endif + +USE_OP(scale); +USE_NO_KERNEL_OP(heter_listen_and_serv); +namespace paddle { +namespace framework { + +framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { + auto root_block = program->MutableBlock(0); + auto* block = program->AppendBlock(*root_block); + auto* block2 = program->AppendBlock(*root_block); + + framework::OpDesc* op = block->AppendOp(); + op->SetType("scale"); + op->SetInput("X", {"x"}); + op->SetOutput("Out", {"res"}); + op->SetAttr("scale", 0.5f); + + framework::OpDesc* op2 = block2->AppendOp(); + op2->SetType("scale"); + op2->SetInput("X", {"x"}); + op2->SetOutput("Out", {"res"}); + op2->SetAttr("scale", 0.5f); + + auto& out = *root_block->Var("res"); + out.SetType(framework::proto::VarType::LOD_TENSOR); + out.SetShape({1, 10}); + + auto& persistable_var = *root_block->Var("p_var"); + persistable_var.SetType(framework::proto::VarType::LOD_TENSOR); + persistable_var.SetShape({1, 10}); + persistable_var.SetPersistable(true); + + return block; +} + +void GetHeterListenAndServProgram(framework::ProgramDesc* program) { + auto root_block = program->MutableBlock(0); + + auto* sub_block = AppendSendAndRecvBlock(program); + std::vector optimize_blocks; + optimize_blocks.push_back(sub_block); + + std::vector message_to_block_id = {"x:1"}; + std::string endpoint = "127.0.0.1:19944"; + + framework::OpDesc* op = root_block->AppendOp(); + op->SetType("heter_listen_and_serv"); + op->SetInput("X", {}); + op->SetAttr("message_to_block_id", message_to_block_id); + op->SetAttr("optimize_blocks", optimize_blocks); + op->SetAttr("endpoint", endpoint); + op->SetAttr("fanin", 1); + op->SetAttr("pserver_id", 0); +} + +TEST(HeterPipelineTrainerTest, GPU) { +#ifdef _LINUX + TrainerDesc t, t2, t3; + // t2 + t.set_class_name("HeterPipelineTrainer"); + t.set_device_worker_name("HeterSectionWorker"); + t.set_thread_num(1); + t.set_trainer_id(0); + t.add_trainers(1); + t.add_trainers(1); + t.add_trainers(1); + auto* heter_section_param = t.mutable_heter_section_param(); + heter_section_param->set_num_pipeline_stages(3); + heter_section_param->set_pipeline_stage(0); + heter_section_param->set_num_microbatches(1); + // t2 + t2.set_class_name("HeterPipelineTrainer"); + t2.set_device_worker_name("HeterSectionWorker"); + t2.set_thread_num(1); + t2.set_trainer_id(1); + t2.add_trainers(1); + t2.add_trainers(1); + t2.add_trainers(1); + auto* heter_section_param2 = t2.mutable_heter_section_param(); + heter_section_param2->set_num_pipeline_stages(3); + heter_section_param2->set_pipeline_stage(1); + heter_section_param2->set_num_microbatches(1); + // t3 + t3.set_class_name("HeterPipelineTrainer"); + t3.set_device_worker_name("HeterSectionWorker"); + t3.set_thread_num(1); + t3.set_trainer_id(1); + t3.add_trainers(1); + t3.add_trainers(1); + t3.add_trainers(1); + t3.add_dump_fields("hello"); + t3.add_dump_param("fc_0"); + auto* heter_section_param3 = t3.mutable_heter_section_param(); + heter_section_param3->set_num_pipeline_stages(3); + heter_section_param3->set_pipeline_stage(2); + heter_section_param3->set_num_microbatches(1); + + std::string str; + str += "name: \"MultiSlotDataFeed\"\nbatch_size: 2\nmulti_slot_desc {\n"; + str += "slots {\nname: \"words\"\ntype: \"uint64\"\nis_dense: false\n"; + str += "is_used: true\n}\nslots {\nname: \"label\"\ntype: \"uint64\"\n"; + str += "is_dense: false\nis_used: true\n}\n}\n"; + std::shared_ptr dataset = + std::make_shared(); + dataset->SetFileList(std::vector{"a1.txt", "a2.txt"}); + dataset->SetThreadNum(1); + dataset->SetTrainerNum(1); + dataset->SetDataFeedDesc(str); + dataset->CreateReaders(); + + ProgramDesc p; + // construct program + // AppendSendAndRecvBlock(&p); + GetHeterListenAndServProgram(&p); + auto* section_config = heter_section_param->mutable_section_config(); + proto::ProgramDesc* pd = new proto::ProgramDesc(*(p.Proto())); + section_config->set_allocated_program_desc(pd); + + ProgramDesc p2; + // construct program + // AppendSendAndRecvBlock(&p2); + GetHeterListenAndServProgram(&p2); + auto* section_config2 = heter_section_param2->mutable_section_config(); + proto::ProgramDesc* pd2 = new proto::ProgramDesc(*(p2.Proto())); + section_config2->set_allocated_program_desc(pd2); + + ProgramDesc p3; + // construct program + // AppendSendAndRecvBlock(&p3); + GetHeterListenAndServProgram(&p3); + auto* section_config3 = heter_section_param3->mutable_section_config(); + proto::ProgramDesc* pd3 = new proto::ProgramDesc(*(p3.Proto())); + section_config3->set_allocated_program_desc(pd3); + + Scope root_scope, root_scope2, root_scope3; + paddle::platform::CPUPlace place; + paddle::platform::CUDAPlace place2; + + // tmp1 + std::shared_ptr tmp1; + tmp1 = TrainerFactory::CreateTrainer(t.class_name()); + tmp1->SetScope(&root_scope); + tmp1->Initialize(t, dataset.get()); + tmp1->InitTrainerEnv(p, place); + tmp1->InitOtherEnv(p); + tmp1->GetWorkerScope(0); + tmp1->ResetDataset(dataset.get()); + tmp1->Finalize(); + + // tmp2 + std::shared_ptr tmp2; + tmp2 = TrainerFactory::CreateTrainer(t2.class_name()); + tmp2->SetScope(&root_scope2); + tmp2->Initialize(t2, dataset.get()); + tmp2->InitTrainerEnv(p2, place2); + tmp2->InitOtherEnv(p2); + tmp2->GetWorkerScope(0); + tmp2->ResetDataset(dataset.get()); + tmp2->Finalize(); + + // tmp3 + std::shared_ptr tmp3; + tmp3 = TrainerFactory::CreateTrainer(t3.class_name()); + tmp3->SetScope(&root_scope3); + tmp3->Initialize(t3, dataset.get()); + tmp3->InitTrainerEnv(p3, place); + tmp3->InitOtherEnv(p3); + + // tmp3->GetDumpPath(0); + // tmp3->InitDumpEnv(); + // tmp3->FinalizeDumpEnv(); + + tmp3->GetWorkerScope(0); + tmp3->ResetDataset(dataset.get()); + tmp3->Finalize(); + + // tmp4 for coverage + std::shared_ptr tmp4; + tmp4 = TrainerFactory::CreateTrainer("MultiTrainer"); + tmp4->ResetDataset(dataset.get()); + + // heter_section_worker test + std::shared_ptr w_0; + w_0 = DeviceWorkerFactory::CreateDeviceWorker("HeterSectionWorker"); + w_0->CreateDeviceResource(p3); + w_0->BindingDataFeedMemory(); +#endif +} +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/heter_section_worker.cc b/paddle/fluid/framework/heter_section_worker.cc new file mode 100644 index 0000000000000000000000000000000000000000..ace6ac49255c85c7d370ef80db36f75b805eb379 --- /dev/null +++ b/paddle/fluid/framework/heter_section_worker.cc @@ -0,0 +1,445 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined(PADDLE_WITH_PSCORE) +#include +#include "paddle/fluid/distributed/service/heter_server.h" +#include "paddle/fluid/framework/device_worker.h" +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/lodtensor_printer.h" + +namespace paddle { +namespace framework { + +void SetMicroId(paddle::framework::Scope* scope, + platform::DeviceContext* dev_ctx, const platform::Place& place, + int micro_id) { + // create microbatch_id variable + // and set micro id value + auto* ptr = scope->Var("microbatch_id"); + InitializeVariable(ptr, proto::VarType::LOD_TENSOR); + framework::Variable* var = scope->FindVar("microbatch_id"); + PADDLE_ENFORCE_EQ(var->IsType(), 1, + platform::errors::InvalidArgument( + "the type of microbatch_id should be LoDTensor")); + auto* tensor = var->GetMutable(); + std::vector dims{1}; + tensor->Resize(framework::make_ddim(dims)); + void* tensor_data = + tensor->mutable_data(place, framework::proto::VarType::FP32); + if (platform::is_gpu_place(place)) { +#ifdef PADDLE_WITH_CUDA + std::vector temp; + temp.resize(tensor->numel() * framework::SizeOfType(tensor->type())); + char* temp_ptr = temp.data(); + float* temp_ptr_float = reinterpret_cast(temp_ptr); + temp_ptr_float[0] = micro_id; + auto stream = + reinterpret_cast(*dev_ctx).stream(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), tensor_data, + platform::CPUPlace(), reinterpret_cast(temp_ptr), + tensor->numel() * framework::SizeOfType(tensor->type()), + stream); +#endif + } else { + float* temp_ptr = reinterpret_cast(tensor_data); + temp_ptr[0] = micro_id; + } +} + +class TrainerDesc; + +uint64_t HeterSectionWorker::batch_id_(0); + +void HeterSectionWorker::Initialize(const TrainerDesc& desc) { + trainer_desc_ = desc; + fetch_config_ = desc.fetch_config(); + dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); + program_.reset(new ProgramDesc( + desc.heter_section_param().section_config().program_desc())); + thread_queue_.reset( + new ::paddle::framework::BlockingQueue>()); + bool is_first_stage = (pipeline_stage_ == 0); + bool is_last_stage = (pipeline_stage_ + 1 == num_pipeline_stages_); + + if (is_first_stage) { + for (auto& op_desc : program_->Block(0).AllOps()) { + auto op = std::move(OpRegistry::CreateOp(*op_desc)); + auto op_type = op->Type(); + if (listen_op_ == nullptr && op_type == "heter_listen_and_serv") { + listen_op_ = std::move(op); + } else { + forward_ops_.push_back(std::move(op)); + } + } + for (auto& op_desc : program_->Block(1).AllOps()) { + backward_ops_.push_back(OpRegistry::CreateOp(*op_desc)); + } + } else if (is_last_stage) { + for (auto& op_desc : program_->Block(0).AllOps()) { + if (listen_op_ == nullptr) { + listen_op_ = std::move(OpRegistry::CreateOp(*op_desc)); + } + } + for (auto& op_desc : program_->Block(1).AllOps()) { + auto op = std::move(OpRegistry::CreateOp(*op_desc)); + int op_role = op->Attr(std::string("op_role")); + bool is_forward_op = (op_role == static_cast(OpRole::kForward)) || + (op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss))) || + (op_role == static_cast(OpRole::kLRSched)); + if (is_forward_op) { + forward_ops_.push_back(std::move(op)); + } else { + backward_ops_.push_back(std::move(op)); + } + } + } else { + for (auto& op_desc : program_->Block(0).AllOps()) { + if (listen_op_ == nullptr) { + listen_op_ = std::move(OpRegistry::CreateOp(*op_desc)); + } + } + for (auto& op_desc : program_->Block(1).AllOps()) { + forward_ops_.push_back(OpRegistry::CreateOp(*op_desc)); + } + for (auto& op_desc : program_->Block(2).AllOps()) { + backward_ops_.push_back(OpRegistry::CreateOp(*op_desc)); + } + } +} + +void HeterSectionWorker::RunBackward(int micro_id) { + for (size_t i = 0; i < backward_ops_.size(); i++) { + auto& op = backward_ops_[i]; + VLOG(3) << "Backward: start to run op " << op->Type() << " for micro-batch " + << micro_id; + if (debug_) { + timeline_.Start(); + } + op->Run(*((*microbatch_scopes_)[micro_id]), place_); + dev_ctx_->Wait(); + if (debug_) { + timeline_.Pause(); + int offset = forward_ops_.size(); + op_total_time_[i + offset] += timeline_.ElapsedSec(); + total_time_ += timeline_.ElapsedSec(); + } + VLOG(3) << "Backward: finish running op " << op->Type() + << " for micro-batch " << micro_id; + } +} + +void HeterSectionWorker::MiniBatchBarrier() { + // get micro id & deserialize data + std::set micro_ids; + while (micro_ids.size() < micro_ids_.size()) { + auto task = (*thread_queue_).Pop(); + auto message_name = task.first; + auto micro_id = task.second; + PADDLE_ENFORCE_EQ(message_name.find("backward") != std::string::npos, true, + platform::errors::InvalidArgument( + "cpu trainers only receive backward data")); + PADDLE_ENFORCE_EQ( + micro_ids.find(micro_id) == micro_ids.end(), true, + platform::errors::InvalidArgument("minibatch_scope_ can not be nullptr " + "when create MicroBatch Scope")); + micro_ids.insert(micro_id); + // backward data has been deserialized to micro scope + // now run backward computation + RunBackward(micro_id); + batch_num_++; + BatchPostProcess(); + } + micro_ids_.clear(); +} + +void HeterSectionWorker::RunListen() { listen_op_->Run(*root_scope_, place_); } + +void HeterSectionWorker::RunForward(int micro_id) { + if (pipeline_stage_ == 0) { + BindingDataFeedMemory(micro_id); + if (debug_) { + timeline_.Start(); + } + int cur_micro_batch = device_reader_->Next(); + if (cur_micro_batch <= 0) { + epoch_finish_ = true; + return; + } + if (debug_) { + timeline_.Pause(); + read_time_ += timeline_.ElapsedSec(); + total_time_ += timeline_.ElapsedSec(); + total_ins_num_ += cur_micro_batch; + } + VLOG(3) << "read a batch in thread " << thread_id_ << " micro " << micro_id; + } + for (size_t i = 0; i < forward_ops_.size(); i++) { + auto& op = forward_ops_[i]; + VLOG(3) << "Forward: start to run op " << op->Type() << " for micro-batch " + << micro_id; + if (debug_) { + timeline_.Start(); + } + op->Run(*((*microbatch_scopes_)[micro_id]), place_); + dev_ctx_->Wait(); + if (debug_) { + timeline_.Pause(); + op_total_time_[i] += timeline_.ElapsedSec(); + total_time_ += timeline_.ElapsedSec(); + } + VLOG(3) << "Forward: finish running op " << op->Type() + << " for micro-batch " << micro_id; + } +} + +void HeterSectionWorker::BindingDataFeedMemory(int micro_id) { + const std::vector& input_feed = + device_reader_->GetUseSlotAlias(); + for (auto name : input_feed) { + device_reader_->AddFeedVar((*microbatch_scopes_)[micro_id]->FindVar(name), + name); + } +} + +void HeterSectionWorker::CreateMicrobatchScopes() { + PADDLE_ENFORCE_NOT_NULL( + minibatch_scope_, + platform::errors::InvalidArgument( + "minibatch_scope_ can not be nullptr when create MicroBatch Scopes")); + if (microbatch_scopes_.get() == nullptr) { + microbatch_scopes_.reset(new std::vector{}); + (*microbatch_scopes_).resize(num_microbatches_); + VLOG(3) << "Create microbatch scopes..."; + for (int j = 0; j < num_microbatches_; ++j) { + (*microbatch_scopes_)[j] = &minibatch_scope_->NewScope(); + } + } + if (thread_id_ >= 0) { + std::shared_ptr program; + program.reset(new ProgramDesc( + trainer_desc_.heter_section_param().section_config().program_desc())); + for (int j = 0; j < num_microbatches_; ++j) { + CopyParameters(j, *program, place_); + } + } +} + +void HeterSectionWorker::CopyParameters(int microbatch_id, + const ProgramDesc& program, + const platform::Place& place) { + auto& global_block = program.Block(0); + auto var_list = global_block.AllVars(); + if (program.Size() > 1) { + auto& heter_block = program.Block(1); + auto heter_var_list = heter_block.AllVars(); + var_list.insert(var_list.end(), heter_var_list.begin(), + heter_var_list.end()); + } + if (program.Size() > 2) { + auto& heter_block = program.Block(2); + auto heter_var_list = heter_block.AllVars(); + var_list.insert(var_list.end(), heter_var_list.begin(), + heter_var_list.end()); + } + auto global_micro_id = thread_id_ * 10 + microbatch_id; + SetMicroId((*microbatch_scopes_)[microbatch_id], dev_ctx_, place, + global_micro_id); + for (auto& var : var_list) { + if (var->Persistable() && microbatch_id == 0) { + if (root_scope_->FindVar(var->Name()) != nullptr) continue; + auto* ptr = root_scope_->Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + VLOG(5) << "Create persistable var: " << var->Name() + << ", which pointer is " << ptr; + } else if (!var->Persistable()) { + if ((*microbatch_scopes_)[microbatch_id]->FindVar(var->Name()) != nullptr) + continue; + auto* ptr = (*microbatch_scopes_)[microbatch_id]->Var(var->Name()); + VLOG(5) << "Create variable " << var->Name() << " for microbatch " + << microbatch_id << ", which pointer is " << ptr; + InitializeVariable(ptr, var->GetType()); + } + } +} + +void HeterSectionWorker::Run() { + if (debug_) { + size_t total_ops_size = forward_ops_.size() + backward_ops_.size(); + op_name_.resize(total_ops_size); + op_total_time_.resize(total_ops_size); + platform::SetNumThreads(1); + // forward op + backward op + for (auto& op : forward_ops_) { + op_name_.push_back(op->Type()); + } + for (auto& op : backward_ops_) { + op_name_.push_back(op->Type()); + } + for (size_t i = 0; i < op_total_time_.size(); ++i) { + op_total_time_[i] = 0.0; + } + } + bool is_first_stage = (pipeline_stage_ == 0); + bool is_last_stage = (pipeline_stage_ + 1 == num_pipeline_stages_); + if (is_first_stage) { // for cpu trainer + while (!epoch_finish_) { + // forward + for (int i = 0; i < num_microbatches_; i++) { + VLOG(5) << "Run " << i << " microbatch"; + RunForward(i); + if (epoch_finish_ == true) { + break; + } + micro_ids_.push_back(i); + } + // backward + if (micro_ids_.size() > 0) { + MiniBatchBarrier(); + } + } + } else { // for heter worker + auto heter_server = paddle::distributed::HeterServer::GetInstance(); + while (true) { + if (heter_server->IsStop()) { + epoch_finish_ = true; + break; + } + auto task = (*thread_queue_).Pop(); + auto message_name = task.first; + auto micro_id = task.second; + if (is_last_stage) { + PADDLE_ENFORCE_EQ(message_name.find("forward") != std::string::npos, 1, + platform::errors::InvalidArgument( + "last stage only receive forward data")); + RunForward(micro_id); + RunBackward(micro_id); + batch_num_++; + BatchPostProcess(); + } else { + if (message_name.find("forward") != std::string::npos) { + RunForward(micro_id); + } else if (message_name.find("backward") != std::string::npos) { + RunBackward(micro_id); + batch_num_++; + BatchPostProcess(); + } + } + } + } +} + +void HeterSectionWorker::BatchPostProcess() { + PrintFetchVars(); + // dump param & field + if (need_dump_field_) { + DumpField(*((*microbatch_scopes_)[0]), dump_mode_, dump_interval_); + } + if (need_dump_param_ && thread_id_ == 0) { + DumpParam(*((*microbatch_scopes_)[0]), batch_num_); + } + // print each op time + if (thread_id_ == 0) { + size_t total_ops_size = forward_ops_.size() + backward_ops_.size(); + if (batch_num_ > 0 && batch_num_ % 100 == 0) { + for (size_t i = 0; i < total_ops_size; ++i) { + fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i, + op_name_[i].c_str(), op_total_time_[i] / batch_num_); + } + if (pipeline_stage_ == 0) { + fprintf(stderr, "mean read time: %fs\n", read_time_ / batch_num_); + fprintf(stderr, "IO percent: %f\n", read_time_ / total_time_ * 100); + } + fprintf(stderr, "%6.2f instances/s\n", total_ins_num_ / total_time_); + } + } +} + +void HeterSectionWorker::TrainFiles() { + if (thread_id_ >= 0) { + total_ins_num_ = 0; + batch_num_ = 0; + platform::SetNumThreads(1); + timeline_.Start(); + VLOG(3) << "begin section_worker TrainFiles"; + epoch_finish_ = false; + if (pipeline_stage_ == 0) { + device_reader_->Start(); + } + while (!epoch_finish_) { + Run(); + dev_ctx_->Wait(); + } + timeline_.Pause(); + VLOG(3) << "worker " << thread_id_ << " train cost " + << timeline_.ElapsedSec() + << " seconds, ins_num: " << total_ins_num_; + } +} + +void HeterSectionWorker::PrintFetchVars() { + // call count + int batch_per_print = fetch_config_.print_period(); + int fetch_var_num = fetch_config_.fetch_var_names_size(); + if (fetch_var_num == 0) { + return; + } + if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) { + time_t curtime; + time(&curtime); + char mbstr[80]; + std::strftime(mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S", + std::localtime(&curtime)); + std::stringstream ss; + ss << "time: [" << mbstr << "], "; + ss << "batch: [" << batch_num_ << "], "; + for (int i = 0; i < fetch_var_num; ++i) { + platform::PrintVar((*microbatch_scopes_)[0], + fetch_config_.fetch_var_names(i), + fetch_config_.fetch_var_str_format(i), &ss); + if (i < fetch_var_num - 1) { + ss << ", "; + } + } + std::cout << ss.str() << std::endl; + } +} + +void HeterSectionWorker::TrainFilesWithProfiler() { + if (thread_id_ >= 0) { + VLOG(3) << "begin section_worker TrainFilesWithProfiler"; + batch_num_ = 0; + epoch_finish_ = false; + total_ins_num_ = 0; + op_name_.clear(); + op_total_time_.clear(); + if (pipeline_stage_ == 0) { + device_reader_->Start(); + } + while (!epoch_finish_) { + Run(); + dev_ctx_->Wait(); + if (epoch_finish_) { + // dump param for debug + if (need_dump_field_ || need_dump_param_) { + writer_.Flush(); + } + } + } + } +} + +} // namespace framework +} // namespace paddle +#endif diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 0f34c84549f2b9ad56046ebde60e15d7d9ffff10..e885bcfe003de6fceaaeddef73d4ebe30485a21e 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -27,7 +27,6 @@ limitations under the License. */ #include "paddle/fluid/framework/data_set.h" #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/fleet/heter_context.h" -//#include "paddle/fluid/framework/fleet/heter_wrapper.h" #include "paddle/fluid/framework/heter_util.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/program_desc.h" @@ -72,6 +71,7 @@ class TrainerBase { virtual Scope* GetWorkerScope(int thread_id) = 0; virtual void InitDumpEnv() = 0; virtual void DumpWork(int tid); + virtual void ResetDataset(Dataset* dataset_ptr) {} protected: virtual std::string GetDumpPath(int tid) = 0; @@ -323,5 +323,51 @@ class PipelineTrainer : public TrainerBase { }; #endif +#if defined(PADDLE_WITH_PSCORE) +class HeterPipelineTrainer : public TrainerBase { + public: + HeterPipelineTrainer() {} + ~HeterPipelineTrainer() override {} + void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set) override; + void InitTrainerEnv(const ProgramDesc& main_program, + const platform::Place& place) override; + void InitOtherEnv(const ProgramDesc& main_program) override; + void Run() override; + void Finalize() override; + Scope* GetWorkerScope(int thread_id) override; + void InitDumpEnv() override; + std::string GetDumpPath(int tid) override; + void ResetDataset(Dataset* dataset_ptr) override; + + protected: + int trainer_id_; // stage_trainer_id + std::vector trainers_; // std::vector trainers + int thread_num_; + std::vector threads_; + + int num_microbatches_; + platform::Place place_; + TrainerDesc trainer_desc_; + + int num_pipeline_stages_; + int pipeline_stage_; + std::unordered_map> + workers_; + + std::shared_ptr>>>> + task_queue_; + + platform::DeviceContext* dev_ctx_ = nullptr; + + std::shared_ptr> mini_scopes_; + std::shared_ptr>>> + micro_scopes_; + + std::unique_ptr listen_ptr_ = nullptr; +}; +#endif + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 6f487d6984cc43a5643b9784be287fa7795fb3fb..96d312437b34cf1fafc4fbcaeec91201a1fa934a 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -63,11 +63,15 @@ message TrainerDesc { optional string user_define_dump_filename = 33; optional bool scale_sparse_gradient_with_batch_size = 34 [ default = true ]; + repeated int32 trainers = 35; + optional int32 trainer_id = 36; + // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; optional DownpourWorkerParameter downpour_param = 103; optional PullDenseWorkerParameter pull_dense_param = 102; optional SectionWorkerParameter section_param = 104; + optional HeterSectionWorkerParameter heter_section_param = 105; // datafeed desc optional DataFeedDesc data_desc = 201; } @@ -99,6 +103,17 @@ message SectionWorkerParameter { optional int32 schedule_mode = 9 [ default = 0 ]; } +message HeterSectionWorkerParameter { + optional SectionConfig section_config = 1; + optional int32 queue_size = 2 [ default = 1 ]; + optional int64 sync_steps = 3 [ default = 1 ]; + optional int32 start_cpu_core_id = 4 [ default = 1 ]; + repeated string param_need_sync = 5; + optional int32 num_microbatches = 6; + optional int32 num_pipeline_stages = 7 [ default = 1 ]; + optional int32 pipeline_stage = 8 [ default = 1 ]; +} + message SectionConfig { enum Place { CPUPlace = 0; diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index 660511b1f268d910629199bd122561a2a24a1b0a..6f003c2f497b68527768af7fefd8ff69e9c23d3f 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -66,6 +66,11 @@ std::shared_ptr TrainerFactory::CreateTrainer( REGISTER_TRAINER_CLASS(MultiTrainer); REGISTER_TRAINER_CLASS(DistMultiTrainer); + +#if defined(PADDLE_WITH_PSCORE) +REGISTER_TRAINER_CLASS(HeterPipelineTrainer); +#endif + #if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_HIP || \ defined PADDLE_WITH_XPU) && \ (defined PADDLE_WITH_PSLIB) diff --git a/paddle/fluid/operators/pscore/CMakeLists.txt b/paddle/fluid/operators/pscore/CMakeLists.txt index e4d654008d3d03f5136493bf3719636a6c7daf96..baf82a9df31cba709fa0375351ff975c5fbccc4a 100644 --- a/paddle/fluid/operators/pscore/CMakeLists.txt +++ b/paddle/fluid/operators/pscore/CMakeLists.txt @@ -29,5 +29,11 @@ set(OPERATOR_DEPS ${OPERATOR_DEPS} ${DISTRIBUTE_DEPS} PARENT_SCOPE) set_source_files_properties(heter_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(heter_server_test SRCS heter_server_test.cc DEPS ${RPC_DEPS} ${DISTRIBUTE_DEPS} executor scope proto_desc scale_op eigen_function) +set_source_files_properties(send_and_recv_op_cpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(send_and_recv_cpu_test SRCS send_and_recv_op_cpu_test.cc DEPS executor scope proto_desc scale_op send_and_recv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) + +set_source_files_properties(send_and_recv_op_gpu_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +cc_test(send_and_recv_gpu_test SRCS send_and_recv_op_gpu_test.cc DEPS executor scope proto_desc scale_op send_and_recv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) + set_source_files_properties(heter_listen_and_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(heter_listen_and_server_test SRCS heter_listen_and_server_test.cc DEPS executor scope proto_desc scale_op heter_listen_and_serv_op ${RPC_DEPS} ${DISTRIBUTE_DEPS} eigen_function) diff --git a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc index cd1bdc4d60c7496878d2d2a36021fc6efd6f4443..2c443e8c63cbef71b1948719ee0caf48309b8a0b 100644 --- a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc +++ b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.cc @@ -49,9 +49,7 @@ HeterListenAndServOp::~HeterListenAndServOp() { Stop(); } void HeterListenAndServOp::Stop() {} -void HeterListenAndServOp::RunAsyncLoop(framework::Executor *executor, - framework::ProgramDesc *program, - framework::Scope *recv_scope) const { +void HeterListenAndServOp::RunAsyncLoop(framework::ProgramDesc *program) const { VLOG(2) << "RunAsyncLoop"; auto message_to_block_id_str = Attr>("message_to_block_id"); @@ -90,28 +88,6 @@ void HeterListenAndServOp::RunAsyncLoop(framework::Executor *executor, for (size_t blkid = 1; blkid < num_blocks; ++blkid) { block_list.push_back(blkid); } - auto optimize_prepared = executor->Prepare(*program, block_list); - // execute global block if needed, block id 1 in the program is global - // block if it's not bind to a grad var for it's update. - if (block_list[0] == 1 && - message_to_block_id.find_value(static_cast(1)) == - message_to_block_id.end()) { - executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope); - } - - std::unordered_map> - message_to_prepared_ctx; - for (size_t i = 0; i < block_list.size(); ++i) { - auto blkid = block_list[i]; - auto it = message_to_block_id.find_value(blkid); - if (it != message_to_block_id.end()) { - message_to_prepared_ctx[it->first] = optimize_prepared[i]; - } - } - - request_send_and_recv_handler_->SetGradToPreparedCtx( - &message_to_prepared_ctx); for (size_t i = 0; i < block_list.size(); ++i) { auto blkid = block_list[i]; @@ -125,7 +101,7 @@ void HeterListenAndServOp::RunAsyncLoop(framework::Executor *executor, } while (true) { - if (rpc_service_->IsExit()) { + if (rpc_service_->IsExit() || rpc_service_->IsStop()) { rpc_service_->Stop(); VLOG(0) << "get exit. rpc_processor stop!"; break; @@ -145,7 +121,6 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope, auto &dev_ctx = *pool.Get(dev_place); VLOG(1) << "HeterListenAndServOp::RunImpl On gpu? " << platform::is_gpu_place(dev_place); - framework::Scope &recv_scope = scope.NewScope(); auto pserver_id = Attr("pserver_id"); auto fan_in = Attr("fanin"); @@ -154,8 +129,8 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE_EQ(rpc_service_, nullptr, platform::errors::PreconditionNotMet( "RPC service has been created unexpectedly.")); - std::string endpoint = Attr("endpoint"); + std::string endpoint = Attr("endpoint"); VLOG(4) << "pserver_id: " << pserver_id << ", end_point:" << endpoint; rpc_service_ = distributed::HeterServer::GetInstance(); @@ -168,15 +143,14 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope, platform::errors::PreconditionNotMet( "optimize blocks is less than 1. Optimize blocks " "should be 1 at least on the pserver side.")); + auto *program = optimize_blocks[0]->Program(); - framework::Executor executor(dev_place); request_send_and_recv_handler_.reset( new distributed::RequestSendAndRecvHandler()); - request_send_and_recv_handler_->SetScope(&recv_scope); + request_send_and_recv_handler_->SetScope(&scope); request_send_and_recv_handler_->SetDevCtx(&dev_ctx); - request_send_and_recv_handler_->SetProgram(program); - request_send_and_recv_handler_->SetExecutor(&executor); + rpc_service_->SetRequestHandler(request_send_and_recv_handler_); VLOG(2) << "RunAsyncLoop"; auto message_to_block_id_str = @@ -186,7 +160,7 @@ void HeterListenAndServOp::RunImpl(const framework::Scope &scope, server_thread_.reset(new std::thread(RunServer, rpc_service_)); VLOG(3) << "wait server thread to become ready..."; rpc_service_->WaitServerReady(); - RunAsyncLoop(&executor, program, &recv_scope); + RunAsyncLoop(program); VLOG(3) << "Wait for Server_thread_ stop"; (server_thread_.get())->join(); VLOG(3) << "Server_thread_ stop"; diff --git a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h index 6b208bf4974ad3d8f8c6839efc601458fc81dff4..f81b45ec05f9174937f5001d950e2f9675df91d0 100644 --- a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h +++ b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h @@ -77,9 +77,7 @@ class HeterListenAndServOp : public framework::OperatorBase { const framework::AttributeMap& attrs); virtual ~HeterListenAndServOp(); - void RunAsyncLoop(framework::Executor* executor, - framework::ProgramDesc* program, - framework::Scope* recv_scope) const; + void RunAsyncLoop(framework::ProgramDesc* program) const; void Stop() override; @@ -89,7 +87,7 @@ class HeterListenAndServOp : public framework::OperatorBase { protected: mutable std::shared_ptr rpc_service_; mutable std::shared_ptr server_thread_; - mutable std::shared_ptr + mutable std::shared_ptr request_send_and_recv_handler_; }; diff --git a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc index 3b005e10d9b98c635c576f6529b34962b98c1ddb..c870e758e96afc1c70a26236b0d20ac05d77aaf1 100644 --- a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc +++ b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc @@ -19,6 +19,7 @@ limitations under the License. */ #include "gtest/gtest.h" #include "paddle/fluid/distributed/service/heter_client.h" +#include "paddle/fluid/distributed/service/heter_server.h" #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/op_registry.h" @@ -76,6 +77,9 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { auto x_var = scope->Var("x"); x_var->GetMutable(); + auto micro_var = scope->Var("microbatch_id"); + micro_var->GetMutable(); + auto res_var = scope->Var("res"); res_var->GetMutable(); } @@ -88,6 +92,32 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, x_var->mutable_data(framework::DDim({1, rows_numel}), *place); for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; + auto micro_id_var = + scope->Var("microbatch_id")->GetMutable(); + float* micro_id_ptr = + micro_id_var->mutable_data(framework::DDim({1}), *place); + micro_id_ptr[0] = 0; + + auto res_var = scope->Var("res")->GetMutable(); + float* res_ptr = + res_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0; +} + +void InitTensorsOnClient2(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; + + auto micro_id_var = + scope->Var("microbatch_id")->GetMutable(); + float* micro_id_ptr = + micro_id_var->mutable_data(framework::DDim({1}), *place); + micro_id_ptr[0] = 1; + auto res_var = scope->Var("res")->GetMutable(); float* res_ptr = res_var->mutable_data(framework::DDim({1, rows_numel}), *place); @@ -121,45 +151,85 @@ TEST(HETER_LISTEN_AND_SERV, CPU) { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); std::string endpoint = "127.0.0.1:19944"; + std::string previous_endpoint = "127.0.0.1:19944"; LOG(INFO) << "before StartSendAndRecvServer"; FLAGS_eager_delete_tensor_gb = -1; std::thread server_thread(StartHeterServer); sleep(1); + auto b_rpc_service = distributed::HeterServer::GetInstance(); + b_rpc_service->WaitServerReady(); + using MicroScope = + std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); + std::shared_ptr micro_scopes(new MicroScope{}); + std::shared_ptr> micro_scope( + new std::vector{}); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; + auto* micro_scope_0 = &(mini_scope->NewScope()); + auto* micro_scope_1 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); + (*micro_scope).push_back(micro_scope_1); + (*micro_scopes)[0] = micro_scope; + b_rpc_service->SetMicroBatchScopes(micro_scopes); + b_rpc_service->SetMiniBatchScopes(mini_scopes); + + using TaskQueue = + std::unordered_map>>>; + using SharedTaskQueue = std::shared_ptr>>>>; + SharedTaskQueue task_queue_(new TaskQueue{}); + (*task_queue_)[0] = std::make_shared< + ::paddle::framework::BlockingQueue>>(); + b_rpc_service->SetTaskQueue(task_queue_); + LOG(INFO) << "before HeterClient::GetInstance"; distributed::HeterClient* rpc_client = - distributed::HeterClient::GetInstance({endpoint}, 0).get(); + distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0) + .get(); PADDLE_ENFORCE_NE(rpc_client, nullptr, platform::errors::InvalidArgument( "Client Start Fail, Check Your Code & Env")); - framework::Scope scope; + framework::Scope* scope = (*micro_scope)[0]; platform::CPUPlace place; platform::CPUDeviceContext ctx(place); // create var on local scope int64_t rows_numel = 10; LOG(INFO) << "before InitTensorsOnClient"; - InitTensorsOnClient(&scope, &place, rows_numel); + InitTensorsOnClient(scope, &place, rows_numel); std::string in_var_name("x"); + std::string micro_var_name("microbatch_id"); std::string out_var_name("res"); - std::vector send_var = {in_var_name}; - std::vector recv_var = {out_var_name}; + std::vector send_var = {in_var_name, micro_var_name}; + std::vector recv_var = {}; LOG(INFO) << "before SendAndRecvAsync"; - rpc_client->SendAndRecvAsync({endpoint}, ctx, scope, in_var_name, send_var, - recv_var); - auto var = scope.Var(out_var_name); - auto value = var->GetMutable(); - auto ptr = value->mutable_data(place); - - LOG(INFO) << "before CHECK"; - for (int64_t i = 0; i < rows_numel; ++i) { - LOG(INFO) << "ptr " << i << " is " << ptr[i]; - EXPECT_EQ(ptr[i], 0.5); - } - LOG(INFO) << "end CHECK"; + rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var, + "forward"); + auto task = (*task_queue_)[0]->Pop(); + PADDLE_ENFORCE_EQ( + task.first, "x", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + + InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel); + LOG(INFO) << "before SendAndRecvAsync 2"; + rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name, send_var, + recv_var, "backward"); + auto task2 = (*task_queue_)[0]->Pop(); + PADDLE_ENFORCE_EQ( + task2.first, "x", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + rpc_client->Stop(); LOG(INFO) << "end server Stop"; server_thread.join(); diff --git a/paddle/fluid/operators/pscore/heter_server_test.cc b/paddle/fluid/operators/pscore/heter_server_test.cc index df2eb70b144e4a3cd14384cd4077f44950f89c92..5029aa0ebdcc0c547c394053ce110dbc9f401a3f 100644 --- a/paddle/fluid/operators/pscore/heter_server_test.cc +++ b/paddle/fluid/operators/pscore/heter_server_test.cc @@ -57,6 +57,9 @@ void CreateVarsOnScope(framework::Scope* scope, platform::CPUPlace* place) { auto out_var = scope->Var("out"); out_var->GetMutable(); + auto micro_var = scope->Var("microbatch_id"); + micro_var->GetMutable(); + auto ids_var = scope->Var("ids"); ids_var->GetMutable(); @@ -75,6 +78,37 @@ void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, ids_var->mutable_data(framework::DDim({rows_numel, 1}), *place); for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; + auto micro_id_var = + scope->Var("microbatch_id")->GetMutable(); + float* micro_id_ptr = + micro_id_var->mutable_data(framework::DDim({1}), *place); + micro_id_ptr[0] = 0; + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; + + auto res_var = scope->Var("res")->GetMutable(); + float* res_ptr = + res_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0; +} + +void InitTensorsOnClient2(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope, place); + auto ids_var = scope->Var("ids")->GetMutable(); + int64_t* ids_ptr = + ids_var->mutable_data(framework::DDim({rows_numel, 1}), *place); + for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; + + auto micro_id_var = + scope->Var("microbatch_id")->GetMutable(); + float* micro_id_ptr = + micro_id_var->mutable_data(framework::DDim({1}), *place); + micro_id_ptr[0] = 1; + auto x_var = scope->Var("x")->GetMutable(); float* x_ptr = x_var->mutable_data(framework::DDim({1, rows_numel}), *place); @@ -114,29 +148,19 @@ void StartSendAndRecvServer(std::string endpoint) { LOG(INFO) << "before AppendSendAndRecvBlock"; auto block = AppendSendAndRecvBlock(&program); std::string in_var_name("x"); + std::string in_var_name2("y"); std::vector prefetch_block_ids{block->ID()}; - auto prepared = exe.Prepare(program, prefetch_block_ids); LOG(INFO) << "before InitTensorsOnServer"; InitTensorsOnServer(&scope, &place, 10); LOG(INFO) << "end InitTensorsOnServer"; - std::unordered_map> - message_to_prepared_ctx; - message_to_prepared_ctx[in_var_name] = prepared[0]; std::shared_ptr b_req_handler; b_req_handler.reset(new distributed::RequestSendAndRecvHandler()); - LOG(INFO) << "before SetProgram"; - b_req_handler->SetProgram(&program); - LOG(INFO) << "before SetGradToPreparedCtx"; - b_req_handler->SetGradToPreparedCtx(&message_to_prepared_ctx); LOG(INFO) << "before SetDevCtx"; b_req_handler->SetDevCtx(&ctx); LOG(INFO) << "before SetScope"; b_req_handler->SetScope(&scope); - LOG(INFO) << "before SetExecutor"; - b_req_handler->SetExecutor(&exe); LOG(INFO) << "before HeterServer::GetInstance"; b_rpc_service = distributed::HeterServer::GetInstance(); b_rpc_service->SetEndPoint(endpoint); @@ -146,7 +170,13 @@ void StartSendAndRecvServer(std::string endpoint) { brpc::Controller* cntl) -> int { return b_req_handler->Handle(request, response, cntl); }); + b_rpc_service->RegisterServiceHandler( + in_var_name2, [&](const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) -> int { + return b_req_handler->Handle(request, response, cntl); + }); + b_rpc_service->SetRequestHandler(b_req_handler); LOG(INFO) << "before HeterServer::RunServer"; std::thread server_thread(std::bind(RunServer, b_rpc_service)); @@ -157,47 +187,89 @@ TEST(SENDANDRECV, CPU) { setenv("http_proxy", "", 1); setenv("https_proxy", "", 1); std::string endpoint = "127.0.0.1:4444"; + std::string previous_endpoint = "127.0.0.1:4444"; LOG(INFO) << "before StartSendAndRecvServer"; b_rpc_service = distributed::HeterServer::GetInstance(); std::thread server_thread(StartSendAndRecvServer, endpoint); b_rpc_service->WaitServerReady(); + using MicroScope = + std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); + std::shared_ptr micro_scopes(new MicroScope{}); + std::shared_ptr> micro_scope( + new std::vector{}); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; + auto* micro_scope_0 = &(mini_scope->NewScope()); + auto* micro_scope_1 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); + (*micro_scope).push_back(micro_scope_1); + (*micro_scopes)[0] = micro_scope; + b_rpc_service->SetMicroBatchScopes(micro_scopes); + b_rpc_service->SetMiniBatchScopes(mini_scopes); + + using TaskQueue = + std::unordered_map>>>; + using SharedTaskQueue = std::shared_ptr>>>>; + SharedTaskQueue task_queue_(new TaskQueue{}); + (*task_queue_)[0] = std::make_shared< + ::paddle::framework::BlockingQueue>>(); + b_rpc_service->SetTaskQueue(task_queue_); LOG(INFO) << "before HeterClient::GetInstance"; distributed::HeterClient* rpc_client = - distributed::HeterClient::GetInstance({endpoint}, 0).get(); + distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0) + .get(); PADDLE_ENFORCE_NE(rpc_client, nullptr, platform::errors::InvalidArgument( "Client Start Fail, Check Your Code & Env")); - framework::Scope scope; + framework::Scope* scope = (*micro_scope)[0]; platform::CPUPlace place; platform::CPUDeviceContext ctx(place); // create var on local scope int64_t rows_numel = 10; LOG(INFO) << "before InitTensorsOnClient"; - InitTensorsOnClient(&scope, &place, rows_numel); + InitTensorsOnClient(scope, &place, rows_numel); std::string in_var_name("x"); + std::string micro_var_name("microbatch_id"); std::string out_var_name("res"); - std::vector send_var = {in_var_name}; - std::vector recv_var = {out_var_name}; + std::vector send_var = {in_var_name, micro_var_name}; + std::vector recv_var = {}; LOG(INFO) << "before SendAndRecvAsync"; - rpc_client->SendAndRecvAsync({endpoint}, ctx, scope, in_var_name, send_var, - recv_var); - auto var = scope.Var(out_var_name); - auto value = var->GetMutable(); - auto ptr = value->mutable_data(place); - - LOG(INFO) << "before CHECK"; - for (int64_t i = 0; i < rows_numel; ++i) { - LOG(INFO) << "ptr " << i << " is " << ptr[i]; - EXPECT_EQ(ptr[i], 0.5); - } - LOG(INFO) << "end CHECK"; + rpc_client->SendAndRecvAsync(ctx, *scope, in_var_name, send_var, recv_var, + "forward"); + + LOG(INFO) << "client wait for Pop"; + auto task = (*task_queue_)[0]->Pop(); + LOG(INFO) << "client get from task queue"; + PADDLE_ENFORCE_EQ( + task.first, "x", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + + InitTensorsOnClient2((*micro_scope)[1], &place, rows_numel); + LOG(INFO) << "before SendAndRecvAsync 2"; + std::string in_var_name2("y"); + rpc_client->SendAndRecvAsync(ctx, *((*micro_scope)[1]), in_var_name2, + send_var, recv_var, "backward"); + LOG(INFO) << "after SendAndRecvAsync 2"; + + auto task2 = (*task_queue_)[0]->Pop(); + PADDLE_ENFORCE_EQ( + task2.first, "y", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + rpc_client->FinalizeWorker(); - // b_rpc_service->Stop(); b_rpc_service->Stop(); LOG(INFO) << "end server Stop"; server_thread.join(); diff --git a/paddle/fluid/operators/pscore/send_and_recv_op.cc b/paddle/fluid/operators/pscore/send_and_recv_op.cc index e096e7ed0177de914e59de1a4acf5c43cde1d578..46f22bcc8b26bc0b4f782ed9459491d471ad219d 100644 --- a/paddle/fluid/operators/pscore/send_and_recv_op.cc +++ b/paddle/fluid/operators/pscore/send_and_recv_op.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/profiler.h" namespace paddle { @@ -34,17 +35,22 @@ class SendAndRecvKernel : public framework::OpKernel { auto message_name = ctx.Attr("message_name"); auto send_var_name = ctx.Attr>("send_var_name"); auto recv_var_name = ctx.Attr>("recv_var_name"); - auto epmap = ctx.Attr>("endpoints"); + auto next_epmap = ctx.Attr>("next_endpoints"); + auto previous_epmap = + ctx.Attr>("previous_endpoints"); auto trainer_id = ctx.Attr("trainer_id"); + auto mode = ctx.Attr("mode"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& context = *pool.Get(place); distributed::HeterClient* rpc_client = - distributed::HeterClient::GetInstance(epmap, trainer_id).get(); + distributed::HeterClient::GetInstance(next_epmap, previous_epmap, + trainer_id) + .get(); VLOG(3) << "SendAndRecvOp message_name: " << message_name; - rpc_client->SendAndRecvAsync(epmap, context, scope, message_name, - send_var_name, recv_var_name); + rpc_client->SendAndRecvAsync(context, scope, message_name, send_var_name, + recv_var_name, mode); } }; @@ -67,11 +73,17 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "Tensor Input variable to be sent").AsDuplicable(); AddOutput("Out", "Tensor Output varibale to be recv").AsDuplicable(); AddAttr("message_name", ""); + AddAttr("mode", "forward or backward").SetDefault("forward"); AddAttr>("send_var_name", "Send Tensor's name"); AddAttr>("recv_var_name", "Recv Tensor's name"); AddAttr("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0); AddAttr>("endpoints", "Server endpoint") .SetDefault({"127.0.0.1:6164"}); + AddAttr>("next_endpoints", "Server endpoint") + .SetDefault({"127.0.0.1:6164"}); + AddAttr>("previous_endpoints", + "Previous Server endpoint") + .SetDefault({"127.0.0.1:6164"}); AddComment(R"DOC( SendAndRecv operator This operator will send variables to listen_and_serve op at the parameter server. @@ -86,7 +98,25 @@ class SendAndRecvOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; REGISTER_OPERATOR(send_and_recv, ops::SendAndRecvOp, ops::SendAndRecvOpMaker); - +REGISTER_OP_CUDA_KERNEL( + send_and_recv, + ops::SendAndRecvKernel, + ops::SendAndRecvKernel, + ops::SendAndRecvKernel, + ops::SendAndRecvKernel); REGISTER_OP_CPU_KERNEL( send_and_recv, - ops::SendAndRecvKernel) + ops::SendAndRecvKernel, + ops::SendAndRecvKernel, + ops::SendAndRecvKernel, + ops::SendAndRecvKernel); + +REGISTER_OP_VERSION(send_and_recv) + .AddCheckpoint( + R"ROC(add new attributes [next_endpoints] [previous_endpoints] and [mode])ROC", + paddle::framework::compatible::OpVersionDesc() + .NewAttr("next_endpoints", "Server endpoint", + std::vector({"127.0.0.1:6164"})) + .NewAttr("previous_endpoints", "Server endpoint", + std::vector({"127.0.0.1:6164"})) + .NewAttr("mode", "forward or backward", "forward")); diff --git a/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc b/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6b1ab77b45d35dfb4439cb4e1927cc928d7ffd4c --- /dev/null +++ b/paddle/fluid/operators/pscore/send_and_recv_op_cpu_test.cc @@ -0,0 +1,282 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if defined PADDLE_WITH_PSCORE +#include +#include +#include +#include // NOLINT + +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/service/heter_client.h" +#include "paddle/fluid/distributed/service/heter_server.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace distributed = paddle::distributed; + +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; + +USE_OP(scale); +USE_OP(send_and_recv); + +std::shared_ptr b_rpc_service; + +framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { + auto root_block = program->MutableBlock(0); + auto* block = program->AppendBlock(*root_block); + + framework::OpDesc* op = block->AppendOp(); + op->SetType("scale"); + op->SetInput("X", {"x"}); + op->SetOutput("Out", {"res"}); + op->SetAttr("scale", 0.5f); + + auto& out = *root_block->Var("res"); + out.SetType(framework::proto::VarType::LOD_TENSOR); + out.SetShape({1, 10}); + + return block; +} + +void CreateVarsOnScope(framework::Scope* scope) { + auto w_var = scope->Var("w"); + w_var->GetMutable(); + + auto out_var = scope->Var("out"); + out_var->GetMutable(); + + auto micro_var = scope->Var("microbatch_id"); + micro_var->GetMutable(); + + auto ids_var = scope->Var("ids"); + ids_var->GetMutable(); + + auto x_var = scope->Var("x"); + x_var->GetMutable(); + + auto res_var = scope->Var("res"); + res_var->GetMutable(); +} + +void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope); + auto w = scope->Var("w")->GetMutable(); + auto w_value = w->mutable_value(); + w_value->Resize({rows_numel, 10}); + for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true); + + auto ptr = w_value->mutable_data(*place); + + for (int64_t i = 0; i < w_value->numel(); ++i) { + ptr[i] = static_cast(i / 10); + } +} + +void InitTensorsOnClient(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope); + auto ids_var = scope->Var("ids")->GetMutable(); + int64_t* ids_ptr = + ids_var->mutable_data(framework::DDim({rows_numel, 1}), *place); + for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; + + auto micro_id_var = + scope->Var("microbatch_id")->GetMutable(); + float* micro_id_ptr = + micro_id_var->mutable_data(framework::DDim({1}), *place); + micro_id_ptr[0] = 0; + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) x_ptr[i] = 1.0; + + auto res_var = scope->Var("res")->GetMutable(); + float* res_ptr = + res_var->mutable_data(framework::DDim({1, rows_numel}), *place); + for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0; +} + +void RunServer(std::shared_ptr service) { + service->StartHeterService(); +} + +void StartSendAndRecvServer(std::string endpoint) { + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + LOG(INFO) << "before AppendSendAndRecvBlock"; + auto block = AppendSendAndRecvBlock(&program); + std::string in_var_name("x"); + // std::string in_var_name2("y"); + std::vector prefetch_block_ids{block->ID()}; + + LOG(INFO) << "before InitTensorsOnServer"; + InitTensorsOnServer(&scope, &place, 10); + LOG(INFO) << "end InitTensorsOnServer"; + + std::shared_ptr b_req_handler; + b_req_handler.reset(new distributed::RequestSendAndRecvHandler()); + LOG(INFO) << "before SetDevCtx"; + b_req_handler->SetDevCtx(&ctx); + LOG(INFO) << "before SetScope"; + b_req_handler->SetScope(&scope); + LOG(INFO) << "before HeterServer::GetInstance"; + b_rpc_service = distributed::HeterServer::GetInstance(); + b_rpc_service->SetEndPoint(endpoint); + LOG(INFO) << "before HeterServer::RegisterServiceHandler"; + b_rpc_service->RegisterServiceHandler( + in_var_name, [&](const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) -> int { + return b_req_handler->Handle(request, response, cntl); + }); + + b_rpc_service->SetRequestHandler(b_req_handler); + LOG(INFO) << "before HeterServer::RunServer"; + std::thread server_thread(std::bind(RunServer, b_rpc_service)); + + server_thread.join(); +} + +TEST(SENDANDRECV, CPU) { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + std::string endpoint = "127.0.0.1:4444"; + std::string previous_endpoint = "127.0.0.1:4444"; + LOG(INFO) << "before StartSendAndRecvServer"; + b_rpc_service = distributed::HeterServer::GetInstance(); + std::thread server_thread(StartSendAndRecvServer, endpoint); + b_rpc_service->WaitServerReady(); + using MicroScope = + std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); + std::shared_ptr micro_scopes(new MicroScope{}); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; + std::shared_ptr> micro_scope( + new std::vector{}); + auto* micro_scope_0 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); + (*micro_scopes)[0] = micro_scope; + b_rpc_service->SetMicroBatchScopes(micro_scopes); + b_rpc_service->SetMiniBatchScopes(mini_scopes); + + using TaskQueue = + std::unordered_map>>>; + using SharedTaskQueue = std::shared_ptr>>>>; + SharedTaskQueue task_queue_(new TaskQueue{}); + (*task_queue_)[0] = std::make_shared< + ::paddle::framework::BlockingQueue>>(); + b_rpc_service->SetTaskQueue(task_queue_); + + LOG(INFO) << "before HeterClient::GetInstance"; + distributed::HeterClient* rpc_client = + distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0) + .get(); + + PADDLE_ENFORCE_NE(rpc_client, nullptr, + platform::errors::InvalidArgument( + "Client Start Fail, Check Your Code & Env")); + + framework::Scope* scope = (*micro_scope)[0]; + platform::CPUPlace place; + platform::CPUDeviceContext ctx(place); + + framework::Executor exe(place); + // create var on local scope + int64_t rows_numel = 10; + LOG(INFO) << "before InitTensorsOnClient"; + InitTensorsOnClient(scope, &place, rows_numel); + std::string in_var_name("x"); + std::string micro_var_name("microbatch_id"); + // std::string out_var_name("res"); + std::vector send_var{in_var_name, micro_var_name}; + std::vector recv_var{}; + std::string mode_str("forward"); + + LOG(INFO) << "add block & op1"; + framework::ProgramDesc program; + auto root_block = program.MutableBlock(0); + // op for forward + framework::OpDesc* op = root_block->AppendOp(); + op->SetType("send_and_recv"); + LOG(INFO) << "op1 set input"; + op->SetInput("X", std::vector({in_var_name})); + op->SetOutput("Out", {}); + op->SetAttr("next_endpoints", std::vector({endpoint})); + op->SetAttr("previous_endpoints", + std::vector({previous_endpoint})); + op->SetAttr("trainer_id", 0); + op->SetAttr("mode", mode_str); + op->SetAttr("message_name", in_var_name); + op->SetAttr("send_var_name", send_var); + op->SetAttr("recv_var_name", recv_var); + + std::string mode_str2("backward"); + // op for backward + LOG(INFO) << "add op2"; + framework::OpDesc* op2 = root_block->AppendOp(); + op2->SetType("send_and_recv"); + LOG(INFO) << "op2 set input"; + op2->SetInput("X", std::vector({in_var_name})); + op2->SetOutput("Out", {}); + op2->SetAttr("next_endpoints", std::vector({endpoint})); + op2->SetAttr("previous_endpoints", + std::vector({previous_endpoint})); + op2->SetAttr("trainer_id", 0); + op2->SetAttr("mode", mode_str2); + op2->SetAttr("message_name", in_var_name); + op2->SetAttr("send_var_name", send_var); + op2->SetAttr("recv_var_name", recv_var); + + LOG(INFO) << "exe before prepare"; + auto prepared = exe.Prepare(program, 0); + LOG(INFO) << "exe after prepare"; + + LOG(INFO) << "before RunPreparedContext"; + exe.RunPreparedContext(prepared.get(), scope, false); + + LOG(INFO) << "client wait for Pop"; + auto task = (*task_queue_)[0]->Pop(); + LOG(INFO) << "client get from task queue"; + PADDLE_ENFORCE_EQ( + task.first, "x", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + + auto task2 = (*task_queue_)[0]->Pop(); + PADDLE_ENFORCE_EQ( + task2.first, "x", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + + rpc_client->FinalizeWorker(); + b_rpc_service->Stop(); + LOG(INFO) << "end server Stop"; + server_thread.join(); + LOG(INFO) << "end server thread join"; +} +#endif diff --git a/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc b/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..dc4bc36d34f22a9e9aaeb9c3ac9e2bb68f4601d7 --- /dev/null +++ b/paddle/fluid/operators/pscore/send_and_recv_op_gpu_test.cc @@ -0,0 +1,306 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSCORE) + +#include +#include +#include +#include // NOLINT + +#include "gtest/gtest.h" +#include "paddle/fluid/distributed/service/heter_client.h" +#include "paddle/fluid/distributed/service/heter_server.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/device_context.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace distributed = paddle::distributed; +namespace memory = paddle::memory; + +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; + +USE_OP(scale); +USE_OP(send_and_recv); + +std::shared_ptr b_rpc_service2; + +framework::BlockDesc* AppendSendAndRecvBlock(framework::ProgramDesc* program) { + auto root_block = program->MutableBlock(0); + auto* block = program->AppendBlock(*root_block); + + framework::OpDesc* op = block->AppendOp(); + op->SetType("scale"); + op->SetInput("X", {"x"}); + op->SetOutput("Out", {"res"}); + op->SetAttr("scale", 0.5f); + + auto& out = *root_block->Var("res"); + out.SetType(framework::proto::VarType::LOD_TENSOR); + out.SetShape({1, 10}); + + return block; +} + +void CreateVarsOnScope(framework::Scope* scope) { + auto w_var = scope->Var("w"); + w_var->GetMutable(); + + auto out_var = scope->Var("out"); + out_var->GetMutable(); + + auto micro_var = scope->Var("microbatch_id"); + micro_var->GetMutable(); + + auto ids_var = scope->Var("ids"); + ids_var->GetMutable(); + + auto x_var = scope->Var("x"); + x_var->GetMutable(); + + auto res_var = scope->Var("res"); + res_var->GetMutable(); +} + +void InitTensorsOnClient(framework::Scope* scope, int64_t rows_numel, + const platform::DeviceContext& ctx) { + CreateVarsOnScope(scope); + const auto place = ctx.GetPlace(); + // auto ids_var = scope->Var("ids")->GetMutable(); + // int64_t* ids_ptr = + // ids_var->mutable_data(framework::DDim({rows_numel, 1}), + // *place); + // for (int64_t i = 0; i < rows_numel; ++i) ids_ptr[i] = i * 2; + auto stream = + reinterpret_cast(ctx).stream(); + + auto micro_id_var = + scope->Var("microbatch_id")->GetMutable(); + float* micro_id_ptr = + micro_id_var->mutable_data(framework::DDim({1}), place); + std::vector temp_vec{0}; + float* temp_ptr = temp_vec.data(); + + memory::Copy( + BOOST_GET_CONST(platform::CUDAPlace, place), + reinterpret_cast(micro_id_ptr), platform::CPUPlace(), + reinterpret_cast(temp_ptr), + micro_id_var->numel() * framework::SizeOfType(micro_id_var->type()), + stream); + + auto x_var = scope->Var("x")->GetMutable(); + float* x_ptr = + x_var->mutable_data(framework::DDim({1, rows_numel}), place); + std::vector x_vec; + for (int64_t i = 0; i < rows_numel; ++i) x_vec.push_back(1.0); + float* x_vec_ptr = x_vec.data(); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, place), + reinterpret_cast(x_ptr), platform::CPUPlace(), + reinterpret_cast(x_vec_ptr), + x_var->numel() * framework::SizeOfType(x_var->type()), stream); + + // auto res_var = scope->Var("res")->GetMutable(); + // float* res_ptr = + // res_var->mutable_data(framework::DDim({1, rows_numel}), place); + // for (int64_t i = 0; i < rows_numel; ++i) res_ptr[i] = 1.0; +} + +void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, + int64_t rows_numel) { + CreateVarsOnScope(scope); + auto w = scope->Var("w")->GetMutable(); + auto w_value = w->mutable_value(); + w_value->Resize({rows_numel, 10}); + for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true); + + auto ptr = w_value->mutable_data(*place); + + for (int64_t i = 0; i < w_value->numel(); ++i) { + ptr[i] = static_cast(i / 10); + } +} + +void RunServer(std::shared_ptr service) { + service->StartHeterService(); +} + +void StartSendAndRecvServer(std::string endpoint) { + framework::ProgramDesc program; + framework::Scope scope; + platform::CPUPlace place; + framework::Executor exe(place); + platform::CPUDeviceContext ctx(place); + LOG(INFO) << "before AppendSendAndRecvBlock"; + auto block = AppendSendAndRecvBlock(&program); + std::string in_var_name("x"); + std::vector prefetch_block_ids{block->ID()}; + + LOG(INFO) << "before InitTensorsOnServer"; + InitTensorsOnServer(&scope, &place, 10); + LOG(INFO) << "end InitTensorsOnServer"; + + std::shared_ptr b_req_handler; + b_req_handler.reset(new distributed::RequestSendAndRecvHandler()); + LOG(INFO) << "before SetDevCtx"; + b_req_handler->SetDevCtx(&ctx); + LOG(INFO) << "before SetScope"; + b_req_handler->SetScope(&scope); + LOG(INFO) << "before HeterServer::GetInstance"; + b_rpc_service2 = distributed::HeterServer::GetInstance(); + b_rpc_service2->SetEndPoint(endpoint); + LOG(INFO) << "before HeterServer::RegisterServiceHandler"; + b_rpc_service2->RegisterServiceHandler( + in_var_name, [&](const MultiVarMsg* request, MultiVarMsg* response, + brpc::Controller* cntl) -> int { + return b_req_handler->Handle(request, response, cntl); + }); + + b_rpc_service2->SetRequestHandler(b_req_handler); + LOG(INFO) << "before HeterServer::RunServer"; + std::thread server_thread(std::bind(RunServer, b_rpc_service2)); + server_thread.join(); +} + +TEST(SENDANDRECV, GPU) { + setenv("http_proxy", "", 1); + setenv("https_proxy", "", 1); + std::string endpoint = "127.0.0.1:4445"; + std::string previous_endpoint = "127.0.0.1:4445"; + LOG(INFO) << "before StartSendAndRecvServer"; + b_rpc_service2 = distributed::HeterServer::GetInstance(); + std::thread server_thread(StartSendAndRecvServer, endpoint); + b_rpc_service2->WaitServerReady(); + using MicroScope = + std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); + std::shared_ptr micro_scopes(new MicroScope{}); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; + std::shared_ptr> micro_scope( + new std::vector{}); + auto* micro_scope_0 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); + (*micro_scopes)[0] = micro_scope; + b_rpc_service2->SetMicroBatchScopes(micro_scopes); + b_rpc_service2->SetMiniBatchScopes(mini_scopes); + + using TaskQueue = + std::unordered_map>>>; + using SharedTaskQueue = std::shared_ptr>>>>; + SharedTaskQueue task_queue_(new TaskQueue{}); + (*task_queue_)[0] = std::make_shared< + ::paddle::framework::BlockingQueue>>(); + b_rpc_service2->SetTaskQueue(task_queue_); + + LOG(INFO) << "before HeterClient::GetInstance"; + distributed::HeterClient* rpc_client = + distributed::HeterClient::GetInstance({endpoint}, {previous_endpoint}, 0) + .get(); + + PADDLE_ENFORCE_NE(rpc_client, nullptr, + platform::errors::InvalidArgument( + "Client Start Fail, Check Your Code & Env")); + + framework::Scope* scope = (*micro_scope)[0]; + platform::CUDAPlace place; + platform::CUDADeviceContext ctx(place); + + framework::Executor exe(place); + // create var on local scope + int64_t rows_numel = 10; + LOG(INFO) << "before InitTensorsOnClient"; + InitTensorsOnClient(scope, rows_numel, ctx); + LOG(INFO) << "after InitTensorsOnClient2"; + std::string in_var_name("x"); + std::string micro_var_name("microbatch_id"); + // std::string out_var_name("res"); + std::vector send_var{in_var_name, micro_var_name}; + std::vector recv_var{}; + std::string mode_str("forward"); + + LOG(INFO) << "add block & op1"; + framework::ProgramDesc program; + auto root_block = program.MutableBlock(0); + // op for forward + framework::OpDesc* op = root_block->AppendOp(); + op->SetType("send_and_recv"); + LOG(INFO) << "op1 set input"; + op->SetInput("X", std::vector({in_var_name})); + op->SetOutput("Out", {}); + op->SetAttr("next_endpoints", std::vector({endpoint})); + op->SetAttr("previous_endpoints", + std::vector({previous_endpoint})); + op->SetAttr("trainer_id", 0); + op->SetAttr("mode", mode_str); + op->SetAttr("message_name", in_var_name); + op->SetAttr("send_var_name", send_var); + op->SetAttr("recv_var_name", recv_var); + op->SetAttr("op_device", std::string("gpu")); + + std::string mode_str2("backward"); + // op for backward + LOG(INFO) << "add op2"; + framework::OpDesc* op2 = root_block->AppendOp(); + + op2->SetType("send_and_recv"); + LOG(INFO) << "op2 set input"; + op2->SetInput("X", std::vector({in_var_name})); + op2->SetOutput("Out", {}); + op2->SetAttr("next_endpoints", std::vector({endpoint})); + op2->SetAttr("previous_endpoints", + std::vector({previous_endpoint})); + op2->SetAttr("trainer_id", 0); + op2->SetAttr("mode", mode_str2); + op2->SetAttr("message_name", in_var_name); + op2->SetAttr("send_var_name", send_var); + op2->SetAttr("recv_var_name", recv_var); + op2->SetAttr("op_device", std::string("gpu")); + + LOG(INFO) << "exe before prepare"; + auto prepared = exe.Prepare(program, 0); + LOG(INFO) << "exe after prepare"; + + LOG(INFO) << "before RunPreparedContext"; + exe.RunPreparedContext(prepared.get(), scope, false); + + LOG(INFO) << "client wait for Pop"; + auto task = (*task_queue_)[0]->Pop(); + LOG(INFO) << "client get from task queue"; + PADDLE_ENFORCE_EQ( + task.first, "x", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + + auto task2 = (*task_queue_)[0]->Pop(); + PADDLE_ENFORCE_EQ( + task2.first, "x", + platform::errors::InvalidArgument( + "Recv message and Send message name not match, Check your Code")); + + rpc_client->FinalizeWorker(); + b_rpc_service2->Stop(); + LOG(INFO) << "end server Stop"; + server_thread.join(); + LOG(INFO) << "end server thread join"; +} +#endif diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 0a39f529387a2581db370a53edeba7b74f6768fc..ed203697357446a2f22b96dada85586059a75a63 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -165,10 +165,12 @@ void BindDistCommunicator(py::module* m) { void BindHeterClient(py::module* m) { py::class_>(*m, "HeterClient") - .def(py::init( - [](const std::vector& endpoint, const int& trainer_id) { - return HeterClient::GetInstance(endpoint, trainer_id); - })) + .def(py::init([](const std::vector& endpoints, + const std::vector& previous_endpoints, + const int& trainer_id) { + return HeterClient::GetInstance(endpoints, previous_endpoints, + trainer_id); + })) .def("stop", &HeterClient::Stop); } diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 2fb7395be34e999c225745d75077d0d769763414..7089ddffa7ceb7a94bf3f6da9da86f35331e0b9d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1966,7 +1966,8 @@ All parameter, weight, gradient are variables in Paddle. return self.GetWorkerScope(thread_id); }, py::return_value_policy::reference) - .def("finalize", &TrainerBase::Finalize); + .def("finalize", &TrainerBase::Finalize) + .def("ResetDataset", &TrainerBase::ResetDataset); m.def("_get_eager_deletion_vars", &framework::GetEagerDeletionCleanVars); diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index c930e1c06aeb22c4e9381cb497278b4987ee84b0..f80f8ffd0f02750210c99885f5ac635f0bbfee47 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -563,8 +563,7 @@ class Fleet(object): fleet.is_server() """ - return self._role_maker._is_server( - ) or self._role_maker._is_heter_worker() + return self._role_maker._is_server() def barrier_worker(self): """ @@ -1525,13 +1524,15 @@ class Fleet(object): else: apply_ir_passes(loss.block.program, startup_program, self) - program = paddle.static.default_main_program() - opt_info = {} - opt_info["mpi_size"] = self.worker_num() - opt_info["mpi_rank"] = self.worker_index() - for k, v in self._user_defined_strategy.trainer_desc_configs.items(): - opt_info[k] = v - program._fleet_opt = opt_info + if not self._role_maker._is_heter_parameter_server_mode: + program = paddle.static.default_main_program() + opt_info = {} + opt_info["mpi_size"] = self.worker_num() + opt_info["mpi_rank"] = self.worker_index() + for k, v in self._user_defined_strategy.trainer_desc_configs.items( + ): + opt_info[k] = v + program._fleet_opt = opt_info if self._runtime_handle is None: self._runtime_handle = RuntimeFactory()._create_runtime(context) diff --git a/python/paddle/distributed/fleet/base/role_maker.py b/python/paddle/distributed/fleet/base/role_maker.py index f89d73416960a8a2d82a1155e1bc3463255a1067..a77f52d788f89eade00dea638bfab20a65511cae 100644 --- a/python/paddle/distributed/fleet/base/role_maker.py +++ b/python/paddle/distributed/fleet/base/role_maker.py @@ -371,11 +371,6 @@ class RoleMakerBase(object): self._role = None self._current_id = -1 - # for heter parameter server mode - self._heter_trainer_endpoints = [] - self._heter_trainer_device = "CPU" - self._is_heter_parameter_server_mode = False - def _is_worker(self): """ return is_worker() of current process @@ -487,56 +482,56 @@ class RoleMakerBase(object): """ print("warning: RoleMakerBase does not have barrier worker.") - def _is_heter_worker(self): - """ - Return is_heter_worker() of current process - """ - warnings.warn("RoleMakerBase does not have function: _is_heter_worker.") - return False - - def _heter_worker_num(self): - """ - Get current total heter-worker number. - - Returns: - int: heter_worker number - """ - warnings.warn( - "RoleMakerBase does not have function: _heter_worker_num.") - return 0 - - def _get_heter_worker_endpoints(self): - """ - Returns: - string: all heter_trainers'endpoints - """ - assert self._heter_trainer_endpoints != [], "Heter Worker Endpoints Not initialized" - return self._heter_trainer_endpoints - - def _get_heter_worker_endpoint(self): - """ - Returns: - int: corresponding heter_trainer's endpoint - - e.g: if we have 4 cpu-trainer(default), 2 gpu-trainer(heter) - then No.0 and No.2 cpu-trainer will work with No.0 gpu-trainer - and No.1 and No.3 cpu-trainer will work with No.1 gpu-trainer - """ - assert self._heter_trainer_endpoints != [], "Heter Worker Endpoints Not initialized" - return self._heter_trainer_endpoints[(self._current_id) % - self._heter_worker_num()] + #def _is_heter_worker(self): + # """ + # Return is_heter_worker() of current process + # """ + # raise NotImplementedError("Please implement this method in child class") + + #def _heter_worker_num(self): + # """ + # Get current total heter-worker number. + # + # Returns: + # int: heter_worker number + # """ + # raise NotImplementedError("Please implement this method in child class") + + #def _get_heter_worker_endpoints(self): + # """ + # Returns: + # string: all heter_trainers'endpoints + # """ + # raise NotImplementedError("Please implement this method in child class") + + #def _get_heter_worker_endpoint(self): + # """ + # Returns: + # int: corresponding heter_trainer's endpoint + # """ + # raise NotImplementedError("Please implement this method in child class") class PaddleCloudRoleMaker(RoleMakerBase): def __init__(self, is_collective=False, **kwargs): super(PaddleCloudRoleMaker, self).__init__() self._is_collective = is_collective - self._non_distributed = False self._kwargs = kwargs self._role_is_generated = False + # for heterps + self._stage_id = 1 + self._stage_num = 1 + self._next_heter_trainer_endpoints = [] + self._previous_heter_trainer_endpoints = [] + self._heter_trainer_endpoints = [] + self._heter_trainer_device = "cpu" + self._heter_trainer_device_type = "cpu" + self._is_heter_parameter_server_mode = False + self._stage_trainers = [] + self._server_endpoints = [] self._worker_endpoints = [] @@ -551,6 +546,46 @@ class PaddleCloudRoleMaker(RoleMakerBase): def _all_reduce(self, input, mode="sum", comm_world="worker"): return self._gloo.all_reduce(input, mode, comm_world) + def _heter_device(self): + """ + return the heter device that current heter worker is using + """ + if not self._role_is_generated: + self._generate_role() + return self._heter_trainer_device + + def _heter_device_type(self): + """ + return the heter device type that current heter worker is using + """ + if not self._role_is_generated: + self._generate_role() + return self._heter_trainer_device_type + + def _get_stage_id(self): + """ + return stage id of current heter worker + """ + if not self._role_is_generated: + self._generate_role() + return self._stage_id + + def _get_stage_trainers(self): + """ + return trainer num of all stages + """ + if not self._role_is_generated: + self._generate_role() + return self._stage_trainers + + def _get_num_stage(self): + """ + return stage num + """ + if not self._role_is_generated: + self._generate_role() + return self._stage_num + def _is_worker(self): """ whether current process is worker @@ -655,6 +690,32 @@ class PaddleCloudRoleMaker(RoleMakerBase): self._generate_role() return self._worker_endpoints + def _get_trainer_endpoint(self): + if not self._role_is_generated: + self._generate_role() + assert self._role == Role.WORKER, "get_trainer_endpoint should be called by trainer" + return self._cur_endpoint + + def _get_heter_worker_endpoints(self): + """ + Returns: + string: all heter_trainers'endpoints + """ + if not self._role_is_generated: + self._generate_role() + assert self._heter_trainer_endpoints != [], "Heter Worker Endpoints Not initialized" + return self._heter_trainer_endpoints + + def _get_heter_worker_endpoint(self): + """ + Returns: + int: corresponding heter_trainer's endpoint + """ + if not self._role_is_generated: + self._generate_role() + assert self._role == Role.HETER_WORKER, "_get_heter_worker_endpoint should be invoked by heter worker" + return self._cur_endpoint + def _get_pserver_endpoints(self): """ get endpoint of all pservers @@ -663,6 +724,28 @@ class PaddleCloudRoleMaker(RoleMakerBase): self._generate_role() return self._server_endpoints + def _get_previous_trainers(self): + """ + invoked by heter worker + """ + if not self._role_is_generated: + self._generate_role() + assert self._role in ( + Role.WORKER, Role.HETER_WORKER + ), "_get_previous_trainers should be invoked by trainer or heter worker" + return self._previous_heter_trainer_endpoints + + def _get_next_trainers(self): + """ + invoked by heter worker + """ + if not self._role_is_generated: + self._generate_role() + assert self._role in ( + Role.WORKER, Role.HETER_WORKER + ), "_get_next_trainers should be invoked by trainer or heter worker" + return self._next_heter_trainer_endpoints + def _is_non_distributed(self): """ Return True if indispensable environment for fleetrun is not found @@ -730,23 +813,67 @@ class PaddleCloudRoleMaker(RoleMakerBase): "TRAINING_ROLE must be PSERVER or TRAINER or HETER_TRAINER, but get {}, please check your environment.". format(training_role)) - # For heter parameter server env setting - heter_trainer_eplist = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST", - "") - if heter_trainer_eplist != "": - try: - heter_trainer_eplist = os.environ[ - "PADDLE_HETER_TRAINER_IP_PORT_LIST"].split(",") - except: - raise ValueError( - "Can not Find PADDLE_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ." - ) + # For Heter Parameter Server env setting + next_heter_trainer_eplist = os.getenv( + "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST", "") + previous_heter_trainer_eplist = os.getenv( + "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST", "") + all_heter_trainer_eplist = os.getenv( + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST", "") + if all_heter_trainer_eplist != "": + self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",") self._is_heter_parameter_server_mode = True - heter_trainers_num = len(heter_trainer_eplist) + self._heter_trainers_num = len(self._heter_trainer_endpoints) + + if previous_heter_trainer_eplist == "": + assert training_role in ( + "TRAINER", "PSERVER" + ), "training_role should be trainer or pserver" + else: + try: + self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist.split( + ",") + except: + raise ValueError( + "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ." + ) + + if next_heter_trainer_eplist == "": + assert training_role in ( + "HETER_TRAINER", "PSERVER" + ), "training_role should be heter trainer or pserver" + else: + try: + self._next_heter_trainer_endpoints = next_heter_trainer_eplist.split( + ",") + except: + raise ValueError( + "Can not Find PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ." + ) + + #self._is_heter_parameter_server_mode = True + #heter_trainers_num = len(all_heter_trainer_eplist.split(",")) + #self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",") else: self._is_heter_parameter_server_mode = False - heter_trainers_num = 0 + self._heter_trainers_num = 0 + + #if previous_heter_trainer_eplist == "": + # self._is_heter_parameter_server_mode = False + # heter_trainers_num = 0 + #else: ## for the last heter worker + # try: + # previous_heter_trainer_eplist = os.environ[ + # "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"].split(",") + # self._previous_heter_trainer_endpoints = previous_heter_trainer_eplist + # except: + # raise ValueError( + # "Can not Find PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST in env or its format doesn't match the requirement: 'IP:PORT,IP:PORT' ." + # ) + # self._is_heter_parameter_server_mode = True + # heter_trainers_num = len(all_heter_trainer_eplist.split(",")) + # self._heter_trainer_endpoints = all_heter_trainer_eplist.split(",") if training_role == "TRAINER": role = Role.WORKER @@ -756,22 +883,85 @@ class PaddleCloudRoleMaker(RoleMakerBase): "Can not find PADDLE_TRAINER_ID, please check your environment." ) current_id = int(current_id) - if len(self._worker_endpoints) > 0: - self._cur_endpoint = self._worker_endpoints[current_id] + if self._is_heter_parameter_server_mode: + self._stage_id = os.getenv("STAGE_ID", None) + if self._stage_id == None: + raise ValueError( + "Can not find STAGE_ID, please check your environment.") + self._stage_id = int(self._stage_id) + self._stage_num = os.getenv("STAGE_NUM", None) + if self._stage_num == None: + raise ValueError( + "Can not find STAGE_NUM, please check your environment.") + self._stage_num = int(self._stage_num) + self._stage_trainers = os.getenv("PADDLE_STAGE_TRAINERS_NUM", + None) + if self._stage_trainers == None: + raise ValueError( + "Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment." + ) + self._stage_trainers = eval(self._stage_trainers) + cur_port = os.getenv("PADDLE_PORT", None) + if cur_port == None: + raise ValueError( + "Can not find PADDLE_PORT, please check your environment.") + cur_ip = os.getenv("POD_IP", None) + if cur_ip == None: + raise ValueError( + "Can not find POD_IP, please check your environment.") + curr_endpoint = ":".join([cur_ip, cur_port]) + self._cur_endpoint = curr_endpoint elif training_role == "PSERVER": role = Role.SERVER - port = os.getenv("PADDLE_PORT", None) - if port == None: + cur_port = os.getenv("PADDLE_PORT", None) + if cur_port == None: raise ValueError( "Can not find PADDLE_PORT, please check your environment.") - ip = os.getenv("POD_IP", None) - if ip == None: + cur_ip = os.getenv("POD_IP", None) + if cur_ip == None: raise ValueError( "Can not find POD_IP, please check your environment.") - self._cur_endpoint = ip + ":" + port + curr_endpoint = ":".join([cur_ip, cur_port]) + self._cur_endpoint = curr_endpoint current_id = self._server_endpoints.index(self._cur_endpoint) elif training_role == "HETER_TRAINER": role = Role.HETER_WORKER + self._stage_id = os.getenv("STAGE_ID", None) + if self._stage_id == None: + raise ValueError( + "Can not find STAGE_ID, please check your environment.") + self._stage_id = int(self._stage_id) + self._stage_num = os.getenv("STAGE_NUM", None) + if self._stage_num == None: + raise ValueError( + "Can not find STAGE_NUM, please check your environment.") + self._stage_num = int(self._stage_num) + + self._stage_trainers = os.getenv("PADDLE_STAGE_TRAINERS_NUM", None) + if self._stage_trainers == None: + raise ValueError( + "Can not find PADDLE_STAGE_TRAINERS_NUM, please check your environment." + ) + self._stage_trainers = eval(self._stage_trainers) + + self._heter_trainer_device_type = os.getenv("HETER_DEVICE_TYPE", + None) + if self._heter_trainer_device_type == None: + raise ValueError( + "Can not find HETER_DEVICE_TYPE, please check your environment." + ) + assert self._heter_trainer_device_type in ( + "cpu", "gpu", "xpu" + ), "HETER_DEVICE_TYPE should be cpu,gpu or xpu" + if self._heter_trainer_device_type == "gpu": + heter_device_id = os.getenv("FLAGS_selected_gpus", "0") + self._heter_trainer_device = ":".join( + (self._heter_trainer_device_type, heter_device_id)) + if self._heter_trainer_device == "xpu": + heter_device_id = os.getenv("FLAGS_selected_xpus", "0") + self._heter_trainer_device = ":".join( + (self._heter_trainer_device_type, heter_device_id)) + cur_port = os.getenv("PADDLE_PORT", None) if cur_port == None: raise ValueError( @@ -781,15 +971,15 @@ class PaddleCloudRoleMaker(RoleMakerBase): raise ValueError( "Can not find POD_IP, please check your environment.") curr_endpoint = ":".join([cur_ip, cur_port]) - current_id = heter_trainer_eplist.index(curr_endpoint) + self._cur_endpoint = curr_endpoint + current_id = all_heter_trainer_eplist.split(",").index( + curr_endpoint) + trainers_num self._trainers_num = trainers_num self._role = role self._current_id = current_id self._nodes_num = len( set([x.split(':')[0] for x in self._worker_endpoints])) - self._heter_trainers_num = heter_trainers_num - self._heter_trainer_endpoints = heter_trainer_eplist def _collective_env(self): self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) diff --git a/python/paddle/distributed/fleet/launch.py b/python/paddle/distributed/fleet/launch.py index b12a392501a000017387032181054111c5fa94b9..dda1c191790b554b0bb066cb8d8e3b8987deb4c1 100644 --- a/python/paddle/distributed/fleet/launch.py +++ b/python/paddle/distributed/fleet/launch.py @@ -173,12 +173,19 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra "--heter_workers", type=str, default="", - help="User defined heter workers ip:port") + help="User defined heter workers in each stage ip1:port1;ip2:port2") + ps_group.add_argument( + "--heter_devices", + type=str, + default="", + help="User defined heter devices in each stage cpu;gpu;cpu") ps_group.add_argument("--worker_num", type=int, help="number of workers") ps_group.add_argument("--server_num", type=int, help="number of servers") ps_group.add_argument( - "--heter_worker_num", type=int, help="number of heter_workers") + "--heter_worker_num", + type=str, + help="number of heter_workers in each stage 1;2;3") ps_group.add_argument("--http_port", type=int, help="Gloo http Port") # parameter elastic mode @@ -323,11 +330,11 @@ def launch_ps(args, distribute_mode): if cloud_flag and distribute_mode == DistributeMode.PS: direct_start(args) return - elif cloud_flag and distribute_mode == DistributeMode.PS_HETER: - cloud_ps_heter_env_set(args) - args.workers = os.getenv("PADDLE_TRAINER_ENDPOINTS") - args.servers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST") - args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST") + #elif cloud_flag and distribute_mode == DistributeMode.PS_HETER: + # cloud_ps_heter_env_set(args) + # args.workers = os.getenv("PADDLE_TRAINER_ENDPOINTS") + # args.servers = os.getenv("PADDLE_PSERVERS_IP_PORT_LIST") + # args.heter_workers = os.getenv("PADDLE_HETER_TRAINER_IP_PORT_LIST") ps_launcher = ParameterServerLauncher(args, distribute_mode) ps_launcher.start_ps() @@ -360,11 +367,11 @@ def which_distributed_mode(args): ps_args = [ '--worker_num', '--server_num', '--heter_worker_num', '--servers', - '--workers', '--heter_workers', '--http_port' + '--workers', '--heter_workers', '--heter_devices', '--http_port' ] collective_args = ['--ips'] - ps_heter_args = ["--heter_worker_num", "--heter_workers"] + ps_heter_args = ["--heter_worker_num", "--heter_workers", "--heter_devices"] has_ps_args = [ ps_arg for ps_arg in ps_args if ps_arg in " ".join(sys.argv[1:-1]) @@ -461,13 +468,15 @@ def launch(): - ``--workers``: User defined workers ip:port, e.g., ``--workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"`` - - ``--heter_workers``: User defined heter workers ip:port, e.g., ``--heter_workers="192.168.0.16:6172,192.168.0.17:6172"`` + - ``--heter_workers``: User defined heter workers ip1:port1;ip2:port2, e.g., ``--heter_workers="192.168.0.16:6172;192.168.0.17:6172"`` - ``--worker_num``: Number of workers (It recommend to set when in the emulated distributed environment using single node) - ``--server_num``: Number of servers (It recommend to set when in the emulated distributed environment using single node) - - ``--heter_worker_num``: Number of heter_workers (It recommend to set when in the emulated distributed environment using single node) + - ``--heter_worker_num``: Number of heter_workers in each stage (It recommend to set when in the emulated distributed environment using single node) + + - ``--heter_devices``: Type of heter_device in each stage - ``--http_port``: Gloo http Port diff --git a/python/paddle/distributed/fleet/launch_utils.py b/python/paddle/distributed/fleet/launch_utils.py index 3aced0ab996cb5328e0ef236b2ff2255edc25e08..251248f18ec39b6a1235efd89d9dcee37bd3cb51 100644 --- a/python/paddle/distributed/fleet/launch_utils.py +++ b/python/paddle/distributed/fleet/launch_utils.py @@ -146,6 +146,7 @@ class Trainer(object): self.accelerators = [] self.endpoint = None self.rank = None + self.stage = None def __str__(self): return "accelerator:{} endpoint:{} rank:{}".format( @@ -765,44 +766,44 @@ def get_custom_endpoints(origin_endpoints, offset=0): return paddle_user_define_endpoints -def cloud_ps_heter_env_set(args): - environs = {} - - paddle_trainer_endpoints = os.getenv("TRAINER_IP_PORT_LIST", "") - assert paddle_trainer_endpoints != None - - paddle_pserver_endpoints = os.getenv("PSERVER_IP_PORT_LIST", "") - assert paddle_pserver_endpoints != None - - # hard code for paddlecloud custom-framework - avilable_ports = os.getenv("TRAINER_PORTS", "").split(",") - assert len( - avilable_ports - ) >= 2, "set paddle_ports_num >= 2 in config.ini for paddlecloud job submit" - - # hard code for paddlecloud custom-framework - trainers_num = len(paddle_pserver_endpoints.split(",")) - assert trainers_num != 0 - environs["PADDLE_TRAINERS_NUM"] = trainers_num - environs["TRAINERS_NUM"] = trainers_num - - # hard code for paddlecloud custom-framework - environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints - environs["PADDLE_PSERVERS_IP_PORT_LIST"] = paddle_pserver_endpoints - environs["PADDLE_TRAINER_ENDPOINTS"] = get_custom_endpoints( - paddle_pserver_endpoints, 1) - heter_worker_num = len(paddle_trainer_endpoints.split(",")) - if (args.heter_worker_num != None) and ( - heter_worker_num != args.heter_worker_num): - warnings.warn( - "Your fleetrun setting: heter_worker_num is {}, but we find {} device can be used, this setting has been changed.". - format(args.heter_worker_num, heter_worker_num)) - args.heter_worker_num = heter_worker_num - - for k, v in environs.items(): - os.environ[k] = str(v) - logger.info("Set heter parameter server env: {}".format( - pretty_print_envs(environs))) +#def cloud_ps_heter_env_set(args): +# environs = {} +# +# paddle_trainer_endpoints = os.getenv("TRAINER_IP_PORT_LIST", "") +# assert paddle_trainer_endpoints != None +# +# paddle_pserver_endpoints = os.getenv("PSERVER_IP_PORT_LIST", "") +# assert paddle_pserver_endpoints != None +# +# # hard code for paddlecloud custom-framework +# avilable_ports = os.getenv("TRAINER_PORTS", "").split(",") +# assert len( +# avilable_ports +# ) >= 2, "set paddle_ports_num >= 2 in config.ini for paddlecloud job submit" +# +# # hard code for paddlecloud custom-framework +# trainers_num = len(paddle_pserver_endpoints.split(",")) +# assert trainers_num != 0 +# environs["PADDLE_TRAINERS_NUM"] = trainers_num +# environs["TRAINERS_NUM"] = trainers_num +# +# # hard code for paddlecloud custom-framework +# environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints +# environs["PADDLE_PSERVERS_IP_PORT_LIST"] = paddle_pserver_endpoints +# environs["PADDLE_TRAINER_ENDPOINTS"] = get_custom_endpoints( +# paddle_pserver_endpoints, 1) +# heter_worker_num = len(paddle_trainer_endpoints.split(",")) +# if (args.heter_worker_num != None) and ( +# heter_worker_num != args.heter_worker_num): +# warnings.warn( +# "Your fleetrun setting: heter_worker_num is {}, but we find {} device can be used, this setting has been changed.". +# format(args.heter_worker_num, heter_worker_num)) +# args.heter_worker_num = heter_worker_num +# +# for k, v in environs.items(): +# os.environ[k] = str(v) +# logger.info("Set heter parameter server env: {}".format( +# pretty_print_envs(environs))) class ParameterServerLauncher(object): @@ -828,10 +829,15 @@ class ParameterServerLauncher(object): self.is_local = True self.current_node_ip = "" + self.stage_trainer_num = [] + self.stage_heter_map = {} + self.stage_list = [] + self.stage_device_map = {} + self.stage_num = 0 + self.get_role_endpoints(args) def get_role_endpoints(self, args): - # get server envs if args.server_num: self.server_num = args.server_num if args.servers: @@ -888,35 +894,140 @@ class ParameterServerLauncher(object): else: self.worker_endpoints = args.workers - # get http_port - if args.http_port: - self.http_port = args.http_port - else: - http_port = get_ports(1, self.server_num + self.worker_num) - http_ip = self.server_endpoints.split(",")[0].split(":")[0] - self.http_port = http_ip + ":" + str(http_port[0]) - # get heter worker envs if self.distribute_mode == DistributeMode.PS_HETER: + assert args.heter_devices != "", "The setting of Parameter-Server heter mode must has heter_devices." + self.stage_device_map[1] = "cpu" # for cpu trainer + heter_devices_list = args.heter_devices.split(";") + for i in range(len(heter_devices_list)): + self.stage_device_map[i + 2] = heter_devices_list[i] + + self.stage_heter_map[1] = self.worker_endpoints if args.heter_worker_num: - self.heter_worker_num = args.heter_worker_num + self.stage_heter_trainer_num = args.heter_worker_num.split(";") + self.stage_heter_trainer_num = [ + int(trainer_num) + for trainer_num in self.stage_heter_trainer_num + ] + if args.heter_workers: - assert len( - args.heter_workers.split(",") - ) == self.heter_worker_num, "The heter_worker_num and heter_workers doesn't match. Expect heter_workers endpoints num epual to heter_worker_num, but received heter_workers enpoint num: {} and heter_worker_num {}".format( - len(args.heter_workers.split(",")), - self.heter_worker_num) - self.heter_worker_endpoints = args.heter_workers + assert len(args.heter_workers.split(";")) == len( + self.stage_heter_trainer_num + ), "The stage_num and heter_workers doesn't match. Expect heter_workers endpoints stage num epual to heter_worker_num stage, but received heter_workers enpoint stage num: {} and heter_worker_num stage {}".format( + len(args.heter_workers.split(";")), + len(self.stage_heter_trainer_num)) + heter_worker_endpoints_list = args.heter_workers.split(";") + self.heter_worker_endpoints = "" + for i in range(len(self.stage_heter_trainer_num)): + if self.heter_worker_endpoints != "": + self.heter_worker_endpoints += "," + heter_worker_endpoints = heter_worker_endpoints_list[ + i].split(",") + assert len( + heter_worker_endpoints + ) == self.stage_heter_trainer_num[ + i], "The heter trainer num in stage {} is not equal in args.heter_worker_num and args.heter_workers".format( + i) + + heter_worker_endpoints_ips = [ + x.strip().split(":")[0] + for x in heter_worker_endpoints + ] + heter_worker_endpoints_len = [ + len(x.strip().split(":")) + for x in heter_worker_endpoints + ] + + if 1 in heter_worker_endpoints_len: + # if no port value in heter_worker_endpoint, will set default port values. + heter_worker_endpoints_port = get_ports( + len(heter_worker_endpoints_ips), self.worker_num + + self.server_num + self.heter_worker_num) + new_heter_worker_endpoints = [] + for j in range(len(heter_worker_endpoints_ips)): + new_heter_worker_endpoints.append(":".join(( + heter_worker_endpoints_ips[j], str( + heter_worker_endpoints_port[j])))) + ip_port_list = ",".join(new_heter_worker_endpoints) + else: + ip_port_list = ",".join(heter_worker_endpoints) + + self.stage_heter_map[i + 2] = ip_port_list + self.stage_list.extend([i + 2] * + len(ip_port_list.split(','))) + + self.heter_worker_num += self.stage_heter_trainer_num[i] + self.heter_worker_endpoints += ip_port_list else: - ports = get_ports(self.heter_worker_num, - self.server_num + self.worker_num) - self.heter_worker_endpoints = ",".join( - ["127.0.0.1:" + str(x) for x in ports]) + for i in range(len(self.stage_heter_trainer_num)): + heter_trainer_num = self.stage_heter_trainer_num[i] + ports = get_ports(heter_trainer_num, + self.server_num + self.worker_num + + self.heter_worker_num) + ip_port_list = ",".join( + ["127.0.0.1:" + str(x) for x in ports]) + self.stage_heter_map[i + 2] = ip_port_list + self.stage_list.extend([i + 2] * + len(ip_port_list.split(','))) + self.heter_worker_num += heter_trainer_num + if self.heter_worker_endpoints != "": + self.heter_worker_endpoints += "," + self.heter_worker_endpoints += ip_port_list else: assert args.heter_workers != "", "The setting of Parameter-Server heter mode must has heter_worker_num or heter_workers." - self.heter_worker_endpoints = args.heter_workers - self.heter_worker_num = len( - self.heter_worker_endpoints.split(",")) + self.stage_heter_trainer_num = [] + heter_worker_endpoints_list = args.heter_workers.split(";") + self.heter_worker_endpoints = "" + for i in range(len(heter_worker_endpoints_list)): + if self.heter_worker_endpoints != "": + self.heter_worker_endpoints += "," + heter_worker_endpoints = heter_worker_endpoints_list[ + i].split(",") + self.stage_heter_trainer_num.append( + len(heter_worker_endpoints)) + heter_worker_endpoints_ips = [ + x.strip().split(":")[0] for x in heter_worker_endpoints + ] + heter_worker_endpoints_len = [ + len(x.strip().split(":")) + for x in heter_worker_endpoints + ] + if 1 in heter_worker_endpoints_len: + # if no port value in heter_worker_endpoint, will set default port values. + heter_worker_endpoints_port = get_ports( + len(heter_worker_endpoints_ips), self.worker_num + + self.server_num + self.heter_worker_num) + + new_heter_worker_endpoints = [] + for j in range(len(heter_worker_endpoints_ips)): + new_heter_worker_endpoints.append(":".join(( + heter_worker_endpoints_ips[j], str( + heter_worker_endpoints_port[j])))) + ip_port_list = ",".join(new_heter_worker_endpoints) + else: + ip_port_list = ",".join(heter_worker_endpoints) + + self.stage_heter_map[i + 2] = ip_port_list + self.stage_list.extend([i + 2] * + len(ip_port_list.split(','))) + + self.heter_worker_num += self.stage_heter_trainer_num[-1] + if self.heter_worker_endpoints != "": + self.heter_worker_endpoints += "," + self.heter_worker_endpoints += ip_port_list + + self.stage_trainer_num = [self.worker_num + ] + self.stage_heter_trainer_num + self.stage_num = len(self.stage_trainer_num) + + # get http_port + if args.http_port: + self.http_port = args.http_port + else: + http_port = get_ports( + 1, self.server_num + self.worker_num + self.heter_worker_num) + http_ip = self.server_endpoints.split(",")[0].split(":")[0] + self.http_port = http_ip + ":" + str(http_port[0]) # check local or user define self.server_endpoints_ips = [ @@ -931,8 +1042,14 @@ class ParameterServerLauncher(object): self.worker_endpoints_port = [ x.strip().split(":")[1] for x in self.worker_endpoints.split(",") ] - self.node_ips = list( - set(self.server_endpoints_ips + self.worker_endpoints_ips)) + self.node_ips = [] + for ip in self.server_endpoints_ips: + if ip not in self.node_ips: + self.node_ips.append(ip) + for ip in self.worker_endpoints_ips: + if ip not in self.node_ips: + self.node_ips.append(ip) + if self.distribute_mode == DistributeMode.PS_HETER: self.heter_worker_endpoints_ips = [ x.strip().split(":")[0] @@ -942,8 +1059,9 @@ class ParameterServerLauncher(object): x.strip().split(":")[1] for x in self.heter_worker_endpoints.split(",") ] - self.node_ips = list( - set(self.node_ips + self.heter_worker_endpoints_ips)) + for ip in self.heter_worker_endpoints_ips: + if ip not in self.node_ips: + self.node_ips.append(ip) if len(set(self.node_ips)) == 1: self.is_local = True @@ -968,7 +1086,6 @@ class ParameterServerLauncher(object): server_rank = 0 worker_rank = 0 heter_worker_rank = 0 - for node_rank, ip in enumerate(self.node_ips): pod = Pod() pod.rank = node_rank @@ -987,6 +1104,7 @@ class ParameterServerLauncher(object): worker.endpoint = "%s:%s" % (ip, self.worker_endpoints_port[j]) worker.rank = worker_rank + worker.stage = 1 worker_rank += 1 pod.workers.append(worker) for k in range(len(self.heter_worker_endpoints_ips)): @@ -995,6 +1113,7 @@ class ParameterServerLauncher(object): heter_worker.endpoint = "%s:%s" % ( ip, self.heter_worker_endpoints_port[k]) heter_worker.rank = heter_worker_rank + heter_worker.stage = self.stage_list[k] heter_worker_rank += 1 pod.heter_workers.append(heter_worker) @@ -1060,20 +1179,36 @@ class ParameterServerLauncher(object): current_env.pop("http_proxy", None) current_env.pop("https_proxy", None) for idx, cur_server in enumerate(pod.servers): - proc_env = { - "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, - "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, - "PADDLE_HETER_TRAINER_IP_PORT_LIST": - self.heter_worker_endpoints, - "PADDLE_PORT": cur_server.endpoint.split(":")[1], - "TRAINING_ROLE": "PSERVER", - "PADDLE_TRAINERS_NUM": str(self.worker_num), - "POD_IP": cur_server.endpoint.split(":")[0], - "PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")), - "PADDLE_GLOO_RENDEZVOUS": "3", - "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, - "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port - } + if self.distribute_mode == DistributeMode.PS_HETER: + proc_env = { + "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, + "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST": + self.heter_worker_endpoints, + "PADDLE_PORT": cur_server.endpoint.split(":")[1], + "TRAINING_ROLE": "PSERVER", + "PADDLE_TRAINERS_NUM": str(self.worker_num), + "POD_IP": cur_server.endpoint.split(":")[0], + "PADDLE_WITH_GLOO": + str(os.getenv("PADDLE_WITH_GLOO", "0")), + "PADDLE_GLOO_RENDEZVOUS": "3", + "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, + "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port + } + else: + proc_env = { + "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, + "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, + "PADDLE_PORT": cur_server.endpoint.split(":")[1], + "TRAINING_ROLE": "PSERVER", + "PADDLE_TRAINERS_NUM": str(self.worker_num), + "POD_IP": cur_server.endpoint.split(":")[0], + "PADDLE_WITH_GLOO": + str(os.getenv("PADDLE_WITH_GLOO", "0")), + "PADDLE_GLOO_RENDEZVOUS": "3", + "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, + "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port + } current_env.update(proc_env) cmd = [sys.executable, "-u", args.training_script @@ -1123,27 +1258,57 @@ class ParameterServerLauncher(object): device_list = [str(x) for x in range(0, heter_device_num)] for idx, cur_worker in enumerate(pod.workers): - device_id = "0" if heter_device_num == 0 else str(device_list[ - idx % heter_device_num]) - proc_env = { - "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, - "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, - "PADDLE_TRAINERS_NUM": str(self.worker_num), - "PADDLE_HETER_TRAINER_IP_PORT_LIST": - self.heter_worker_endpoints, - "TRAINING_ROLE": "TRAINER", - "PADDLE_TRAINER_ID": str(cur_worker.rank), - "PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")), - "PADDLE_GLOO_RENDEZVOUS": "3", - "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, - "FLAGS_selected_gpus": "0", - "FLAGS_selected_xpus": "0", - "CUDA_VISIBLE_DEVICES": device_id, - "XPU_VISIBLE_DEVICES": device_id, - "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port - } - current_env.update(proc_env) + device_id = "0" if heter_device_num == 0 else str(device_list[( + idx) % heter_device_num]) + if self.distribute_mode == DistributeMode.PS_HETER: + proc_env = { + "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, + "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, + "PADDLE_TRAINERS_NUM": str(self.worker_num), + "PADDLE_STAGE_TRAINERS_NUM": str(self.stage_trainer_num), + "STAGE_ID": "1", + "STAGE_NUM": str(self.stage_num), + "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST": "", + "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST": + self.stage_heter_map[2], + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST": + self.heter_worker_endpoints, + "HETER_DEVICE_TYPE": self.stage_device_map[1], + "TRAINING_ROLE": "TRAINER", + "POD_IP": cur_worker.endpoint.split(":")[0], + "PADDLE_PORT": cur_worker.endpoint.split(":")[1], + "PADDLE_TRAINER_ID": str(cur_worker.rank), + "PADDLE_WITH_GLOO": + str(os.getenv("PADDLE_WITH_GLOO", "0")), + "PADDLE_GLOO_RENDEZVOUS": "3", + "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, + "FLAGS_selected_gpus": "0", + "FLAGS_selected_xpus": "0", + "CUDA_VISIBLE_DEVICES": device_id, + "XPU_VISIBLE_DEVICES": device_id, + "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port + } + else: + proc_env = { + "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, + "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, + "PADDLE_TRAINERS_NUM": str(self.worker_num), + "TRAINING_ROLE": "TRAINER", + "POD_IP": cur_worker.endpoint.split(":")[0], + "PADDLE_PORT": cur_worker.endpoint.split(":")[1], + "PADDLE_TRAINER_ID": str(cur_worker.rank), + "PADDLE_WITH_GLOO": + str(os.getenv("PADDLE_WITH_GLOO", "0")), + "PADDLE_GLOO_RENDEZVOUS": "3", + "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, + "FLAGS_selected_gpus": "0", + "FLAGS_selected_xpus": "0", + "CUDA_VISIBLE_DEVICES": device_id, + "XPU_VISIBLE_DEVICES": device_id, + "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port + } + current_env.update(proc_env) cmd = [sys.executable, "-u", args.training_script ] + args.training_script_args self.cmds["worker"].append(cmd) @@ -1189,19 +1354,28 @@ class ParameterServerLauncher(object): elif fluid.core.is_compiled_with_xpu(): heter_device_num = fluid.core.get_xpu_device_count() device_list = [str(x) for x in range(0, heter_device_num)] - if heter_device_num == 0: - return for idx, cur_heter_worker in enumerate(pod.heter_workers): - device_id = str(device_list[idx % heter_device_num]) + device_id = "0" if heter_device_num == 0 else str(device_list[( + idx) % heter_device_num]) + stage_id = cur_heter_worker.stage proc_env = { "PADDLE_PSERVERS_IP_PORT_LIST": self.server_endpoints, "PADDLE_TRAINER_ENDPOINTS": self.worker_endpoints, - "PADDLE_HETER_TRAINER_IP_PORT_LIST": + "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST": + self.stage_heter_map[stage_id + 1] + if stage_id <= self.stage_num - 1 else "", + "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST": + self.stage_heter_map[stage_id - 1], + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST": self.heter_worker_endpoints, + "HETER_DEVICE_TYPE": self.stage_device_map[stage_id], + "STAGE_ID": str(stage_id), + "STAGE_NUM": str(self.stage_num), "PADDLE_PORT": cur_heter_worker.endpoint.split(":")[1], "TRAINING_ROLE": "HETER_TRAINER", "PADDLE_TRAINERS_NUM": str(self.worker_num), + "PADDLE_STAGE_TRAINERS_NUM": str(self.stage_trainer_num), "POD_IP": cur_heter_worker.endpoint.split(":")[0], "PADDLE_WITH_GLOO": str(os.getenv("PADDLE_WITH_GLOO", "0")), "PADDLE_GLOO_RENDEZVOUS": "3", diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py index 88180221ff4ff550ba8ff0b1b7af153c06c8c272..aec2436522300fde51951ce9347da813b7952eff 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py @@ -30,6 +30,16 @@ class ParameterServerOptimizer(MetaOptimizerBase): # we do not allow meta optimizer to be inner optimizer currently self.meta_optimizers_white_list = [] + def _set_basic_info(self, loss, role_maker, user_defined_optimizer, + user_defined_strategy): + super(ParameterServerOptimizer, self)._set_basic_info( + loss, role_maker, user_defined_optimizer, user_defined_strategy) + + #self.micro_batch_size = user_defined_strategy.pipeline_configs[ + # 'micro_batch_size'] + self.num_microbatches = user_defined_strategy.pipeline_configs[ + 'accumulate_steps'] + def _is_graph_out(self): return False @@ -97,8 +107,8 @@ class ParameterServerOptimizer(MetaOptimizerBase): if not use_ps_gpu: _main = worker.delete_optimizer_pass(_main, compiled_config) _main = worker.append_send_ops_pass(_main, compiled_config) - _startup = worker.delet_extra_optimizes_pass(_startup, - compiled_config) + _startup = worker.delete_extra_optimizes_pass(_startup, + compiled_config) # for startup program _startup = worker.fake_init_ops_pass(_startup, compiled_config) @@ -122,15 +132,14 @@ class ParameterServerOptimizer(MetaOptimizerBase): from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker if self.role_maker._is_heter_worker(): # for heter worker + stage_id = self.role_maker._get_stage_id() + device = self.role_maker._heter_device_type().lower() _main = heter_worker.split_heter_worker_ops_pass( - _main, compiled_config) + _main, compiled_config, stage_id, device) else: # for default worker _main = heter_worker.split_trainer_ops_pass(_main, compiled_config) - # for startup change - _startup = heter_worker.delete_startup_useless_ops_var_pass( - _startup, _main, compiled_config) else: _main = worker.append_send_ops_pass(_main, compiled_config) _startup = _startup @@ -319,22 +328,56 @@ class ParameterServerOptimizer(MetaOptimizerBase): if self.role_maker._is_worker() or self.role_maker._is_heter_worker(): main_program, startup_program = self._build_trainer_programs( compiled_config) + if self.role_maker._is_heter_parameter_server_mode: + _origin_startup_program._heter_pipeline_opt = { + "startup_program": startup_program, + "pipeline_stage": int(self.role_maker._get_stage_id()) - 1, + "heter_place": self.role_maker._heter_device(), + } + + loss.block.program._heter_pipeline_opt = { + "trainer": "HeterPipelineTrainer", + "device_worker": "HeterSection", + "trainers": self.role_maker._get_stage_trainers( + ), ## trainer num in each stage + "trainer_id": int(self.role_maker._role_id()), + "pipeline_stage": int(self.role_maker._get_stage_id()) - 1, + "num_pipeline_stages": + int(self.role_maker._get_num_stage()), + "section_program": main_program, + "num_microbatches": self.num_microbatches, + "heter_place": self.role_maker._heter_device(), + } + else: + loss.block.program = main_program + fluid.framework.switch_startup_program(startup_program) + elif self.role_maker._is_server(): main_program, startup_program = self._build_pserver_programs( compiled_config) - - loss.block.program = main_program - fluid.framework.switch_startup_program(startup_program) - + loss.block.program = main_program + fluid.framework.switch_startup_program(startup_program) return None, None def _disable_strategy(self, dist_strategy): + #if self.role_maker._is_heter_parameter_server_mode: + # dist_strategy.pipeline = False + # dist_strategy.pipeline_configs = { + # "micro_batch_size": 1, + # "accumulate_steps": 1, + # } dist_strategy.a_sync = False a_sync_configs = dist_strategy.a_sync_configs a_sync_configs["k_steps"] = -1 dist_strategy.a_sync_configs = a_sync_configs def _enable_strategy(self, dist_strategy, context): + #if self.role_maker._is_heter_parameter_server_mode: + # dist_strategy.pipeline = True + # dist_strategy.pipeline_configs = { + # "micro_batch_size": 1, + # "accumulate_steps": 1, + # } a_sync_configs = dist_strategy.a_sync_configs if a_sync_configs["k_steps"] >= 0: return diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index d17546475c75d172fabe8c4c991a685f5faa539d..13aad87f2c7e15dca3cb75453551fd4ac6c15600 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -528,6 +528,7 @@ class TheOnePSRuntime(RuntimeBase): split_dense_table=self.role_maker._is_heter_parameter_server_mode) send_ctx = self.compiled_strategy.get_the_one_send_context( split_dense_table=self.role_maker._is_heter_parameter_server_mode, + use_origin_program=self.role_maker._is_heter_parameter_server_mode, ep_list=endpoints) trainer_config = self.async_strategy.get_trainer_runtime_config() @@ -545,8 +546,8 @@ class TheOnePSRuntime(RuntimeBase): kwargs['need_global_step'] = "0" kwargs["trainer_id"] = self.role_maker._role_id() kwargs["trainers"] = self.role_maker._worker_num() - if self.role_maker._is_heter_worker(): - kwargs["trainer_id"] += kwargs["trainers"] + #if self.role_maker._is_heter_worker(): + # kwargs["trainer_id"] += kwargs["trainers"] for table in server.servers[0].tables: if table.table_class == "BarrierTable": @@ -589,15 +590,19 @@ class TheOnePSRuntime(RuntimeBase): if launch_barrier and launch_barrier_flag: # for trainer wait server ready wait_server_ready(self.role_maker._get_pserver_endpoints()) - - # for ps-heter mode, wait heter worker ready - if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker( - ): - wait_server_ready(self.role_maker._get_heter_worker_endpoints()) - - self._heter_client = HeterClient( - self.role_maker._get_heter_worker_endpoints(), - self.role_maker._role_id()) + if self.role_maker._is_heter_parameter_server_mode and self.role_maker._get_next_trainers( + ) != []: + wait_server_ready(self.role_maker._get_next_trainers()) + if self.role_maker._is_heter_parameter_server_mode: + previous_trainers = [] + if self.role_maker._get_previous_trainers() != []: + previous_trainers = self.role_maker._get_previous_trainers() + next_trainers = [] + if self.role_maker._get_next_trainers() != []: + next_trainers = self.role_maker._get_next_trainers() + self._heter_client = HeterClient(next_trainers, + previous_trainers, + self.role_maker._role_id()) def _push_sparse_param(self, var_name, @@ -608,18 +613,16 @@ class TheOnePSRuntime(RuntimeBase): def _get_executor(self): executor = fluid.Executor(fluid.CPUPlace()) if self.role_maker._is_heter_parameter_server_mode: - heter_worker_device_guard = self.context[ - "valid_strategy"].a_sync_configs[ - "heter_worker_device_guard"].upper() - if heter_worker_device_guard not in ["GPU", "XPU", "CPU"]: - raise ValueError("Heter Worker Not Support Device {}".format( - heter_worker_device_guard)) if self.role_maker._is_heter_worker(): - if heter_worker_device_guard == "GPU": + heter_device_type = self.role_maker._heter_device_type().upper() + if heter_device_type not in ["GPU", "XPU", "CPU"]: + raise ValueError("Heter Worker Not Support Device {}". + format(device_type)) + if heter_device_type == "GPU": executor = Executor( fluid.CUDAPlace( int(os.getenv("FLAGS_selected_gpus", "0")))) - elif heter_worker_device_guard == "XPU": + elif heter_device_type == "XPU": executor = Executor( fluid.XPUPlace( int(os.getenv("FLAGS_selected_xpus", "0")))) @@ -813,14 +816,12 @@ class TheOnePSRuntime(RuntimeBase): return worker def _init_server(self, dirname=None, var_names=None, **kwargs): - if self.role_maker._is_heter_worker(): - self._init_heter_worker() - return role_id = self.compiled_strategy.get_role_id() endpoints = self.compiled_strategy.get_ps_endpoints() is_sync = self.compiled_strategy.is_sync_mode() trainers = self.compiled_strategy.get_trainers() - + if self.role_maker._is_heter_parameter_server_mode: + trainers += len(self.role_maker._get_heter_worker_endpoints()) server = self._get_fleet_proto(is_server=True, is_sync=is_sync) proto_txt = str(server) @@ -875,30 +876,17 @@ class TheOnePSRuntime(RuntimeBase): self._server.load_sparse(dirname, "0", table_id) def _run_server(self): - if self.role_maker._is_heter_worker(): - self._run_heter_worker() - return - ep = self.compiled_strategy.get_ps_endpoint() host, port = ep.split(":") self._server.run_server(host, int(port)) - def _init_heter_worker(self): - executor = self._get_executor() - executor.run(fluid.default_startup_program()) - self._init_worker() - - def _run_heter_worker(self): - executor = self._get_executor() - executor.run(fluid.default_main_program()) - def _stop_worker(self): self._communicator.stop() - if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker( - ): + if self.role_maker._is_heter_parameter_server_mode: + assert self._heter_client != None, "heter client should not be None in heterps mode" self._heter_client.stop() - executor = self._get_executor() - executor.close() + #executor = self._get_executor() + #executor.close() @staticmethod def __exclude_vars(exclude_var_names=[]): diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 30e3f9dec973c7f1be9d78616ef43e264a1d0e37..9175e706d3a40fbcaf83e0e4984930e3415ac7ef 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -93,7 +93,7 @@ from .dygraph.varbase_patch_methods import monkey_patch_varbase from . import generator from .core import _cuda_synchronize from .generator import Generator -from .trainer_desc import TrainerDesc, DistMultiTrainer, PipelineTrainer, MultiTrainer, HeterXpuTrainer +from .trainer_desc import TrainerDesc, DistMultiTrainer, PipelineTrainer, HeterPipelineTrainer, MultiTrainer, HeterXpuTrainer from .transpiler import HashName, RoundRobin from .backward import append_backward diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py index 9a75ef8c58edfcb7c748fc8a796d1f820c169159..eb8739b15b4bfdc0ca4e4ab6175326d7a7ce24c1 100644 --- a/python/paddle/fluid/communicator.py +++ b/python/paddle/fluid/communicator.py @@ -191,8 +191,9 @@ class LargeScaleKV(object): class HeterClient(object): - def __init__(self, endpoint, trainer_id): - self.heter_client_ = core.HeterClient(endpoint, trainer_id) + def __init__(self, endpoint, previous_endpoint, trainer_id): + self.heter_client_ = core.HeterClient(endpoint, previous_endpoint, + trainer_id) def stop(self): self.heter_client_.stop() diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index a246474e21e20677937ea539a603a5b9c43da950..20d44a772ba9369672668cde084fa0c164b7080b 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -16,7 +16,8 @@ from __future__ import print_function __all__ = [ - 'DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section', 'DownpourSGDOPT' + 'DeviceWorker', 'Hogwild', 'DownpourSGD', 'Section', 'DownpourSGDOPT', + 'HeterSection' ] @@ -444,6 +445,36 @@ class Section(DeviceWorker): cfg.place_id = place_id +class HeterSection(DeviceWorker): + """HeterSectionWorker.""" + + def __init__(self): + """Init.""" + super(HeterSection, self).__init__() + + def _gen_worker_desc(self, trainer_desc): + """ + Generator worker desc, which device worker is HeterSectionWorker. + Args: + trainer_desc(TrainerDesc): a TrainerDesc object + """ + from google.protobuf import text_format + from . import core + trainer_desc.device_worker_name = "HeterSectionWorker" + heter_pipeline_opt = self._program._heter_pipeline_opt + heter_section_param = trainer_desc.heter_section_param + heter_section_param.num_microbatches = heter_pipeline_opt[ + "num_microbatches"] + heter_section_param.pipeline_stage = heter_pipeline_opt[ + "pipeline_stage"] + heter_section_param.num_pipeline_stages = heter_pipeline_opt[ + "num_pipeline_stages"] + cfg = heter_section_param.section_config + program = heter_pipeline_opt["section_program"] + cfg.program_desc.ParseFromString(program._get_desc() + .serialize_to_string()) + + class DeviceWorkerFactory(object): def _create_device_worker(self, worker_type): classname = worker_type.capitalize() diff --git a/python/paddle/fluid/distributed/node.py b/python/paddle/fluid/distributed/node.py index 5a1e9362c2fbcbbe85fc4eb360ec98e5951d9975..6fc1c51e06a0f971e8b9082936f3d8bfb9f007a8 100644 --- a/python/paddle/fluid/distributed/node.py +++ b/python/paddle/fluid/distributed/node.py @@ -49,7 +49,6 @@ class DownpourServer(Server): self.server_.downpour_server_param.service_param.server_class = "DownpourBrpcPsServer" self.server_.downpour_server_param.service_param.client_class = "DownpourBrpcPsClient" self.server_.downpour_server_param.service_param.service_class = "DownpourPsService" - self.server_.downpour_server_param.service_param.start_server_port = 0 self.server_.downpour_server_param.service_param.server_thread_num = 12 def add_sparse_table(self, table_id, learning_rate, slot_key_vars, diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 4d424fc09f14a3483d7e1cfbfe1fc3def237e553..4efea86591373bdb240407f84f2f7f354b655e8c 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -681,6 +681,7 @@ class Executor(object): self.place = framework._get_paddle_place(place) self.program_caches = dict() self.ctx_caches = dict() + self.trainer_caches = dict() self.scope_caches = dict() self.var_caches = dict() self.pruned_program_caches = dict() @@ -704,6 +705,9 @@ class Executor(object): def _get_ctx_cache(self, program_cache_key): return self.ctx_caches.get(program_cache_key, None) + def _get_trainer_cache(self, program_cache_key): + return self.trainer_caches.get(program_cache_key, None) + def _get_program_cache(self, program_cache_key): return self.program_caches.get(program_cache_key, None) @@ -725,6 +729,9 @@ class Executor(object): def _add_ctx_cache(self, ctx_cache_key, ctx): self.ctx_caches[ctx_cache_key] = ctx + def _add_trainer_cache(self, trainer_cache_key, ctx): + self.trainer_caches[trainer_cache_key] = ctx + def _add_scope_cache(self, scope_cache_key, scope): self.scope_caches[scope_cache_key] = scope @@ -986,8 +993,11 @@ class Executor(object): exe.close() """ if not self._closed: - self._default_executor.close() self._closed = True + for k, trainer_instance in self.trainer_caches.items(): + self._default_executor.release_trainer(trainer_instance) + del trainer_instance + self._default_executor.close() def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, return_numpy, return_merged): @@ -1271,6 +1281,18 @@ class Executor(object): program, fetch_list=fetch_list, use_program_cache=use_program_cache) + + if isinstance(program, Program) and program._heter_pipeline_opt: + ## change default executor + heter_place = program._heter_pipeline_opt["heter_place"] + heter_place = framework._get_paddle_place(heter_place) + p = core.Place() + p.set_place(heter_place) + self._default_executor = core.Executor(p) + # TODO(zhangminxu): support heterps pipeline training using exe.run + if "startup_program" in program._heter_pipeline_opt: + program = program._heter_pipeline_opt["startup_program"] + if isinstance(program, Program) and \ len(program.global_block().ops) == 0: if use_default_main_program: @@ -1569,6 +1591,9 @@ class Executor(object): if program._pipeline_opt: trainer = TrainerFactory()._create_trainer( program._pipeline_opt) + elif program._heter_pipeline_opt: + trainer = TrainerFactory()._create_trainer( + program._heter_pipeline_opt) else: trainer = TrainerFactory()._create_trainer(program._fleet_opt) trainer._set_thread_barrier(program._is_distributed) @@ -1579,6 +1604,9 @@ class Executor(object): if program._pipeline_opt: trainer = TrainerFactory()._create_trainer( program.program._pipeline_opt) + elif program._heter_pipeline_opt: + trainer = TrainerFactory()._create_trainer( + program.program._heter_pipeline_opt) else: trainer = TrainerFactory()._create_trainer( program.program._fleet_opt) @@ -1631,6 +1659,39 @@ class Executor(object): dataset.set_thread(1) dataset.set_filelist(['None']) dataset.set_use_var(data_vars) + elif program._heter_pipeline_opt is not None: + stage_id = program._heter_pipeline_opt["pipeline_stage"] + heter_place = program._heter_pipeline_opt["heter_place"] + if stage_id != 0: + import paddle + if dataset is not None: + raise RuntimeError( + "dataset should be None for heter pipeline mode") + # The following fake dataset is created to call + # the _prepare_trainer api, and it is meaningless. + data_vars = [] + for var in program.global_block().vars.values(): + if var.is_data: + data_vars.append(var) + if core.is_compiled_with_npu(): + dataset = paddle.fluid.DatasetFactory().create_dataset( + 'InMemoryDataset') + else: + dataset = paddle.fluid.DatasetFactory().create_dataset( + 'FileInstantDataset') + dataset.set_batch_size(1) + dataset.set_thread(1) + dataset.set_filelist(['None']) + dataset.set_use_var(data_vars) + else: + if dataset is None: + raise RuntimeError( + "dataset is need and should be initialized") + ## change default executor + heter_place = framework._get_paddle_place(heter_place) + p = core.Place() + p.set_place(heter_place) + self._default_executor = core.Executor(p) else: if dataset is None: raise RuntimeError("dataset is need and should be initialized") @@ -1662,7 +1723,6 @@ class Executor(object): 'op_role', core.op_proto_and_checker_maker.OpRole.Optimize) fetch_list = None - scope, trainer = self._prepare_trainer( program=program, dataset=dataset, @@ -1677,14 +1737,28 @@ class Executor(object): trainer._gen_trainer_desc() if program._pipeline_opt is None: - self._dump_debug_info(program=program, trainer=trainer) + if program._heter_pipeline_opt is None: + self._dump_debug_info(program=program, trainer=trainer) # in case of calling _set_use_ps_gpu explicitly if dataset.use_ps_gpu is False: dataset._set_use_ps_gpu(trainer.proto_desc.use_ps_gpu) dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num) - trainer_instance = self._default_executor.init_for_dataset( - program.desc, trainer._desc(), scope, dataset.dataset) + if program._heter_pipeline_opt is None: + trainer_instance = self._default_executor.init_for_dataset( + program.desc, trainer._desc(), scope, dataset.dataset) + else: + # cache trainer instance for heterps pipeline training + if fetch_list == None: + fetch_list = [] + cache_key = _get_strong_program_cache_key(program, None, fetch_list) + trainer_instance = self._get_trainer_cache(cache_key) + if trainer_instance is None: + trainer_instance = self._default_executor.init_for_dataset( + program.desc, trainer._desc(), scope, dataset.dataset) + self._add_trainer_cache(cache_key, trainer_instance) + else: + trainer_instance.ResetDataset(dataset.dataset) if fetch_handler is not None: scope0 = trainer_instance.get_worker_scope(0) @@ -1692,11 +1766,12 @@ class Executor(object): fetch_monitor.start() self._default_executor.run_from_dataset(trainer_instance) fetch_monitor.stop() - self._default_executor.release_trainer(trainer_instance) + if program._heter_pipeline_opt is None: + self._default_executor.release_trainer(trainer_instance) else: - self._default_executor.run_from_dataset(trainer_instance) - self._default_executor.release_trainer(trainer_instance) + if program._heter_pipeline_opt is None: + self._default_executor.release_trainer(trainer_instance) dataset._dynamic_adjust_after_train() dataset._finish_to_run() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 59c6cecb19c48ffa750129e200f845a43542c30f..cc10ded6923758bd7aba3fb084e002d713fb1a0e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -4376,6 +4376,9 @@ class Program(object): # assigned if this program has been parsed by a pipeline optimizer self._pipeline_opt = None + # assigned if this program has been parsed by a heter pipeline parameter server optimizer + self._heter_pipeline_opt = None + # appending gradients times self._appending_grad_times = 0 diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py index e8668e39bd4e2e9724d79352f805aa6e6d68e5c4..ebf9395361ce18f030da0e921824409d86e3001a 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/heter_trainer_pass.py @@ -20,6 +20,7 @@ import paddle.fluid.framework as framework from paddle.fluid.transpiler.details.program_utils import delete_ops from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_heter_ops +from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import union_forward_gradient_op from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_heter_program from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import create_trainer_program from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_block_joints @@ -27,7 +28,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import find_op from paddle.fluid.incubate.fleet.parameter_server.ir.trainer_pass import get_vars_name_in_block -def split_heter_worker_ops_pass(program, config): +def split_heter_worker_ops_pass(program, config, stage_id, device): """ split heter worker program from origin-program 1. find heter op (located on different device) @@ -43,19 +44,15 @@ def split_heter_worker_ops_pass(program, config): ) return program - current_device = "gpu" - if current_device not in heter_ops: - raise ValueError("Op which run on device {} not exist.".format( - current_device)) - + program_block_ops = union_forward_gradient_op(program_block_ops) block_vars_detail = find_block_joints(program, program_block_ops, heter_ops) heter_program = framework.Program() - create_heter_program(program, config, heter_program, heter_ops, - block_vars_detail, current_device) + create_heter_program(program, config, heter_program, program_block_ops, + heter_ops, block_vars_detail, device, stage_id) return heter_program -def split_trainer_ops_pass(program, config): +def split_trainer_ops_pass(program, config, default_device="cpu"): """ split cpu-trainer program from origin-program 1. find heter op (located on different device) @@ -63,38 +60,13 @@ def split_trainer_ops_pass(program, config): 3. create cpu-trainer program, add send&recv op """ # Todo: support user define default_device (MrChengmo) - default_deveice = "cpu" - program, heter_ops, _, program_block_ops = find_heter_ops(program, - default_deveice) - block_vars_detail = find_block_joints(program, program_block_ops, heter_ops) - create_trainer_program(program, config, heter_ops, block_vars_detail) - return program - + default_device_ = default_device + program, heter_ops, default_ops, program_block_ops = find_heter_ops( + program, default_device_) + program_block_ops = union_forward_gradient_op(program_block_ops) -def delete_startup_useless_ops_var_pass(startup_program, main_program, config): - """ - delete variable which not used in current main_program - """ - # find all op and its var - vars_in_main_program = get_vars_name_in_block(main_program.global_block()) - - block_nums = startup_program.num_blocks - for block_index in range(1, block_nums): - current_block = startup_program.block(block_index) - # delete useless op - need_delete_op = [] - for op in current_block.ops: - inputs, outputs = find_op_input_output(startup_program, - current_block, op) - inputs += outputs - # Todo: delete some concat op - if list(set(inputs) & set(vars_in_main_program)) == None: - need_delete_op.append(op) - delete_ops(current_block, need_delete_op) - - # delete useless var - for var in current_block.vars: - if var.name not in vars_in_main_program: - startup_program._remove_var(var.name) - - return startup_program + block_vars_detail = find_block_joints(program, program_block_ops, heter_ops) + trainer_program = program.clone() + create_trainer_program(trainer_program, program, config, program_block_ops, + block_vars_detail) + return trainer_program diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py index 9246b8e44840c19b2a1415c56323d6cc9a386328..4b8c7ccbb69cfcbea8c4db99b191b84b0d4ac1fd 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py @@ -216,12 +216,36 @@ class CompileTimeStrategy(object): except Exception: return self.role_maker.get_heter_worker_endpoints() + def get_next_stage_trainers(self): + try: + return self.role_maker._get_next_trainers() + except Exception: + return self.role_maker.get_next_trainers() + def get_heter_worker_endpoint(self): try: return self.role_maker._get_heter_worker_endpoint() except Exception: return self.role_maker.get_heter_worker_endpoint() + def get_trainer_endpoints(self): + try: + return self.role_maker._get_trainer_endpoints() + except Exception: + return self.role_maker.get_trainer_endpoints() + + def get_trainer_endpoint(self): + try: + return self.role_maker._get_trainer_endpoint() + except Exception: + return self.role_maker.get_trainer_endpoint() + + def get_previous_stage_trainers(self): + try: + return self.role_maker._get_previous_trainers() + except Exception: + return self.role_maker.get_previous_trainers() + def get_origin_programs(self): return self.origin_main_program, self.origin_startup_program diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py index 89b2a8237dc65ab8ebd6b145c878e9da5501946d..a0832d886d6c6a21464b888222d072b54e567707 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/trainer_pass.py @@ -105,6 +105,9 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): if op.type in SPARSE_OP_TYPE_DICT.keys() \ and op.attr('remote_prefetch') is True: param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0] + if config.is_heter_ps_mode: + # trick for matchnet, need to modify + param_name += op.input("Ids")[0][0] ops = pull_sparse_ops.get(param_name, []) ops.append(op) pull_sparse_ops[param_name] = ops @@ -114,6 +117,9 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): for param, ops in pull_sparse_ops.items(): all_ops = program.global_block().ops op_idxs = [all_ops.index(op) for op in ops] + op_device = "" + if config.is_heter_ps_mode: + op_device = ops[0].attr("op_device") inputs = [ program.global_block().vars[op.input("Ids")[0]] for op in ops ] @@ -158,7 +164,11 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): outputs_idxs[out_id] = idx if min(outputs_idxs) - max(inputs_idxs) >= 1: - distributed_idx = max(inputs_idxs) + 1 + + if max(inputs_idxs) == -1: + distributed_idx = min(op_idxs) + else: + distributed_idx = max(inputs_idxs) + 1 if use_ps_gpu: program.global_block()._insert_op( @@ -183,7 +193,8 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): "is_distributed": is_distributed, "padding_idx": padding_idx, "table_id": table_id, - "lookup_table_version": op_type + "lookup_table_version": op_type, + "op_device": op_device }) else: for i in range(len(inputs_idxs)): @@ -199,7 +210,8 @@ def distributed_ops_pass(program, config, use_ps_gpu=False): "is_distributed": is_distributed, "padding_idx": padding_idx, "table_id": table_id, - "lookup_table_version": op_type + "lookup_table_version": op_type, + "op_device": op_device }) pull_sparse_ops = _get_pull_sparse_ops(program) @@ -404,7 +416,7 @@ def ps_gpu_pass(program): return program -def delet_extra_optimizes_pass(program, config): +def delete_extra_optimizes_pass(program, config): optimize_vars = [] optimize_op_role_vars = [] optimize_need_delete_vars = [] @@ -416,7 +428,6 @@ def delet_extra_optimizes_pass(program, config): optimize_vars = list(set(optimize_vars)) optimize_op_role_vars = list(set(optimize_op_role_vars)) - for var in optimize_vars: if var not in optimize_op_role_vars: optimize_need_delete_vars.append(var) @@ -453,7 +464,7 @@ def find_heter_ops(program, default_device="cpu"): elif op_type in COMMUNICATE_OPS_TYPE and current_heter_device != default_device: # for distributed communciate ops: send & recv & barrier etc. # Todo: need update this method - op._set_attr('op_device', current_heter_device) + #op._set_attr('op_device', current_heter_device) return True elif op_device == None or op_device == default_device: op._set_attr('op_device', default_device) @@ -474,6 +485,138 @@ def find_heter_ops(program, default_device="cpu"): heter_ops[op_device] = {} current_heter_block_ops.append(op) + origin_porgram = program.clone() + block = program.global_block() + ''' + re-place sum op to fix bug for union forward backward op + ''' + var2idx = {} + op_list = list(block.ops) + op_size = len(op_list) + + for i in range(op_size - 1, -1, -1): + op_list = list(block.ops) + op = op_list[i] + if "_grad" in op.type: + forward_op_type = op.type.split("_grad")[0] + if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \ + and op.attr('remote_prefetch') is True: + param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0] + if param_name in var2idx: + ## insert sum op & remove sum op from var2idx and origin place + op_list = list(block.ops) + sum_op = op_list[var2idx[param_name]] + sum_op_inputs = { + sum_op.input_names[0]: [ + block.vars[input] + for input in sum_op.input_arg_names + ] + } + sum_op_outputs = { + sum_op.output_names[0]: [ + block.vars[output] + for output in sum_op.output_arg_names + ] + } + block._insert_op( + index=i + 1, + type=sum_op.type, + inputs=sum_op_inputs, + outputs=sum_op_outputs, + attrs=sum_op.all_attrs()) + block._remove_op(var2idx[param_name] + 1) + var2idx.pop(param_name) + for var_ in var2idx: + var2idx[var_] += 1 + elif forward_op_type == "elementwise_mul": + """ + get output varname of pre op + + """ + output_vars_no_grad = [] + for key in pre_op.output_names: + for varname in op.output(key): + if varname == "@EMPTY@": + continue + if "lod_tensor_blocking_queue" in varname: + continue + output_vars_no_grad.append(varname.split("@GRAD")[0]) + for no_grad_var in output_vars_no_grad: + if no_grad_var in var2idx: + """ + insert sum op & remove sum op from var2idx and origin place + + """ + op_list = list(block.ops) + sum_op = op_list[var2idx[no_grad_var]] + sum_op_inputs = { + sum_op.input_names[0]: [ + block.vars[input] + for input in sum_op.input_arg_names + ] + } + sum_op_outputs = { + sum_op.output_names[0]: [ + block.vars[output] + for output in sum_op.output_arg_names + ] + } + block._insert_op( + index=i + 1, + type=sum_op.type, + inputs=sum_op_inputs, + outputs=sum_op_outputs, + attrs=sum_op.all_attrs()) + block._remove_op(var2idx[no_grad_var] + 1) + var2idx.pop(no_grad_var) + for var_ in var2idx: + var2idx[var_] += 1 + else: + if op.type == "sum": + var = op.output("Out")[0] + if "@GRAD" in var: + origin_var = var.split("@GRAD")[0] + pre_op = op_list[i - 1] + if "_grad" in pre_op.type: + forward_op_type = pre_op.type.split("_grad")[0] + if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \ + and pre_op.attr('remote_prefetch') is True: + param_name = pre_op.input(SPARSE_OP_TYPE_DICT[ + forward_op_type])[0] + if param_name == origin_var and op.attr( + "op_device") == pre_op.attr("op_device"): + continue + else: + var2idx[origin_var] = i + elif forward_op_type == "elementwise_mul": + output_vars = [] + for key in pre_op.output_names: + for varname in pre_op.output(key): + if varname == "@EMPTY@": + continue + if "lod_tensor_blocking_queue" in varname: + continue + output_vars.append(varname) + input_vars = [] + for key in op.input_names: + for varname in op.input(key): + if varname == "@EMPTY@": + continue + if "lod_tensor_blocking_queue" in varname: + continue + input_vars.append(varname) + is_match = False + for varname in output_vars: + if varname in input_vars: + is_match = True + break + if is_match: + continue + else: + var2idx[origin_var] = i + else: + var2idx[origin_var] = i + origin_porgram = program.clone() block = program.global_block() @@ -481,7 +624,6 @@ def find_heter_ops(program, default_device="cpu"): default_ops = {default_device: {}} heter_ops = {} block_index = 0 - # heter_ops: {"gpu": {1:[op1, op2, ...], 2:[op1, op2, ...] }; "xpu": {3:[op1, op2, ...], 4:[op1, op2, ...] }} current_heter_block_ops = [] current_default_block_ops = [] @@ -552,12 +694,12 @@ def find_heter_ops(program, default_device="cpu"): print( "There are {} OPs in your main_program, and contains {} heter-OPs which is made up of {} heter-blocks.". format(len(block.ops), total_heter_ops, heter_blocks)) - return origin_porgram, heter_ops, default_ops, program_block_ops + return origin_porgram, heter_ops, default_ops, program_block_ops -def create_heter_program(program, config, heter_program, heter_ops, - block_var_detail, current_device): +def create_heter_program(program, config, heter_program, program_block_ops_list, + heter_ops, block_var_detail, current_device, stage_id): # This function mainly includes the following contents: # 1. For every heter block: # a) copy heter device op from origin program @@ -571,7 +713,7 @@ def create_heter_program(program, config, heter_program, heter_ops, # d) copy send op from origin program for var@grad which loacted in current heter block # e) re-check every op in current blcok if its device is not current heter devie # 2. Create send op for step counter in last heter-block - # 3. Create Listen&Serv OP for distributed training + # 3. Create Listen&Serv OP and Send&Recv OP for distributed training # 4. update CompileTimeStrategy for heter_program optimizer_block = [] @@ -579,33 +721,84 @@ def create_heter_program(program, config, heter_program, heter_ops, send_grad_var_list = [] pre_block_idx = heter_program.num_blocks - 1 - for index, heter_block_ops in heter_ops[current_device].items(): - heter_block = heter_program._create_block(pre_block_idx) - optimizer_block.append(heter_block) - for _, op in enumerate(heter_block_ops): + stage_id = int(stage_id) + print("stage id", stage_id) + heter_block_ops_forward = program_block_ops_list[stage_id - 1]["forward"] + + heter_block_ops_backward = program_block_ops_list[stage_id - 1]["backward"] + + heter_block = heter_program._create_block(pre_block_idx) + optimizer_block.append(heter_block) + for _, op in enumerate(heter_block_ops_forward): + block_append_op(heter_program, program, heter_block, op) + + entrance_vars = block_var_detail[stage_id - 1]["forward"]["entrance"] + add_vars_by_var_list(entrance_vars, program, heter_program, heter_block) + exit_vars = block_var_detail[stage_id - 1]["forward"]["exit"] + add_vars_by_var_list(exit_vars, program, heter_program, heter_block) + + first_op_index_fp = len(heter_block.ops) + + if stage_id < len(program_block_ops_list): + + heter_block_bp = heter_program._create_block(pre_block_idx) + optimizer_block.append(heter_block_bp) + + for _, op in enumerate(heter_block_ops_backward): + block_append_op(heter_program, program, heter_block_bp, op) + + bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][ + "entrance"] + add_vars_by_var_list(bp_entrance_vars, program, heter_program, + heter_block_bp) + bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"] + add_vars_by_var_list(bp_exit_vars, program, heter_program, + heter_block_bp) + backward_comm_info = get_communicate_var_info( + program, stage_id, bp_entrance_vars, type="backward") + + grad_to_block_id.append(backward_comm_info["block_input_var_name"] + ":" + + str(heter_block_bp.idx)) + + else: + for _, op in enumerate(heter_block_ops_backward): block_append_op(heter_program, program, heter_block, op) - entrance_vars = block_var_detail[index]["entrance"] - add_vars_by_var_list(entrance_vars, program, heter_program, heter_block) - exit_vars = block_var_detail[index]["exit"] - add_vars_by_var_list(exit_vars, program, heter_program, heter_block) + bp_entrance_vars = block_var_detail[stage_id - 1]["backward"][ + "entrance"] + add_vars_by_var_list(bp_entrance_vars, program, heter_program, + heter_block) + bp_exit_vars = block_var_detail[stage_id - 1]["backward"]["exit"] + add_vars_by_var_list(bp_exit_vars, program, heter_program, heter_block) + + heter_block_bp = heter_block - comm_info = get_communicate_var_info(program, index, entrance_vars, - exit_vars) + forward_comm_info = get_communicate_var_info( + program, stage_id, entrance_vars, type="forward") - grad_to_block_id.append(comm_info["block_input_var_name"] + ":" + str( - heter_block.idx)) + grad_to_block_id.append(forward_comm_info["block_input_var_name"] + ":" + + str(heter_block.idx)) - first_op_index = 0 + first_op_index_bp = len(heter_block_bp.ops) - # add send op - send_grad_var_list = send_grad_var_list + add_heter_send_op( - program, heter_program, heter_block, block_var_detail[index]) + if stage_id <= len(block_var_detail) - 1: + static_var = insert_communicate_op(program, config, heter_block, + stage_id, first_op_index_fp, + block_var_detail, current_device) + static_var_bp = insert_communicate_op( + program, config, heter_block_bp, stage_id, first_op_index_bp, + block_var_detail, current_device, False) + # add send op + send_grad_var_list = add_heter_send_op( + program, heter_program, heter_block_bp, block_var_detail[stage_id - 1]) + + # --------------- # add step conter send_input_vars = [] dummy_output = [] pserver_endpoints = config.get_ps_endpoints() + # optimizer_block[-1].append_op( # type="send", # inputs={"X": send_input_vars}, @@ -619,14 +812,18 @@ def create_heter_program(program, config, heter_program, heter_ops, # add info in listen&serv attrs = { + #"mode": "sync", + #"trainers": config.get_trainers(), + #"trainer_id": config.get_role_id() + config.get_trainers(), "message_to_block_id": grad_to_block_id, "optimize_blocks": optimizer_block, # runtime attribute "endpoint": config.get_heter_worker_endpoint(), - "fanin": config.get_trainers(), + "fanin": len(config.get_previous_stage_trainers()), "pserver_id": config.get_role_id(), "distributed_mode": config.get_distributed_mode(), - "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)) + "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)), + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE } # append the listen_and_serv op heter_program.global_block().append_op( @@ -648,7 +845,8 @@ def check_heter_compile_time_strategy(program, config, send_grad_var_list): config.remove_var_pair_by_grad(useless_grad_var) -def create_trainer_program(program, config, heter_ops, block_var_detail): +def create_trainer_program(program, origin_program, config, + program_block_ops_list, block_var_detail): # This function mainly includes the following contents: # 1. For every heter block in origin program # a) delete heter op and related variables @@ -660,17 +858,127 @@ def create_trainer_program(program, config, heter_ops, block_var_detail): # d) remove send op which related var@grad is not in trainer program # 2. check every op's device static_var = [] - for device in heter_ops.keys(): - for heter_block_index in sorted(heter_ops[device]): - static_var += replace_ops_by_communicate_op( - program, config, heter_block_index, - heter_ops[device][heter_block_index], block_var_detail) - remove_trainer_send_op(program, config, heter_block_index, - block_var_detail) - deleter_trainer_useless_var(config, program, static_var) + for heter_block_index in range(1, len(program_block_ops_list)): + ops_list = program_block_ops_list[heter_block_index][ + "forward"] + program_block_ops_list[heter_block_index]["backward"] + static_var += replace_ops_by_communicate_op( + program, config, heter_block_index, ops_list, block_var_detail) + remove_trainer_send_op(program, config, heter_block_index, + block_var_detail) + + optimizer_block = [] + grad_to_block_id = [] + + bp_ops_list = program_block_ops_list[0]["backward"] + delete_same_ops(program.global_block(), bp_ops_list) + delete_trainer_useless_var(config, program, static_var) + backward_block = create_backward_block(program, origin_program, config, + bp_ops_list, block_var_detail) + + bp_entrance_vars = block_var_detail[0]["backward"]["entrance"] + backward_comm_info = get_communicate_var_info( + origin_program, 1, bp_entrance_vars, type="backward") + + grad_to_block_id.append(backward_comm_info["block_input_var_name"] + ":" + + str(backward_block.idx)) + optimizer_block.append(backward_block) + + attrs = { + #"mode": "sync", + #"trainers": config.get_trainers(), + #"trainer_id": config.get_role_id(), + "message_to_block_id": grad_to_block_id, + "optimize_blocks": optimizer_block, + # runtime attribute + "endpoint": config.get_trainer_endpoint(), ## get trainer endpoint + "fanin": 0, ## get heter worker + "pserver_id": config.get_role_id(), + "distributed_mode": config.get_distributed_mode(), + "rpc_exec_thread_num": int(os.getenv("CPU_NUM", 32)), + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + } + # append the listen_and_serv op + program.global_block()._insert_op( + index=0, + type="heter_listen_and_serv", + inputs={'X': []}, + outputs={}, + attrs=attrs) + + ## TODO add check for bp block check_op_device(program.global_block(), DEFAULT_DEVICE) +def insert_communicate_op(orign_program, + config, + heter_block, + stage_id, + first_op_index, + block_var_detail, + device, + is_forward=True): + + if is_forward: + next_heter_worker_endpoints = config.get_next_stage_trainers() + previous_heter_worker_endpoints = config.get_previous_stage_trainers() + entrance_var = block_var_detail[stage_id]["forward"]["entrance"] + comm_info = get_communicate_var_info(orign_program, stage_id + 1, + entrance_var) + + else: + next_heter_worker_endpoints = config.get_next_stage_trainers() + #if next_heter_worker_endpoints == "": + # next_heter_worker_endpoints = [] + previous_heter_worker_endpoints = config.get_previous_stage_trainers() + entrance_var = block_var_detail[stage_id - 1]["backward"]["exit"] + comm_info = get_communicate_var_info(orign_program, stage_id - 1, + entrance_var, "backward") + + heter_block._insert_op( + index=first_op_index, + type="send_and_recv", + inputs={"X": heter_block.vars[entrance_var[0]]}, + outputs={"Out": []}, + attrs={ + "mode": "forward" if is_forward else "backward", + "send_var_name": entrance_var + ["microbatch_id"], + "recv_var_name": [], + "message_name": comm_info["block_input_var_name"], + "next_endpoints": next_heter_worker_endpoints, + "previous_endpoints": previous_heter_worker_endpoints, + "trainer_id": config.get_role_id(), + "op_device": device, + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) + + return entrance_var + + +def create_backward_block(program, origin_program, config, bp_ops_list, + block_var_detail): + pre_block_idx = program.num_blocks - 1 + heter_block = program._create_block(pre_block_idx) + + for _, op in enumerate(bp_ops_list): + if op.type == "send": + send_varnames = op.attr('send_varnames') + is_skip = False + for varname in send_varnames: + if varname not in program.global_block( + ).vars and varname not in heter_block.vars: + is_skip = True + break + if is_skip == True: + continue + block_append_op(program, origin_program, heter_block, op) + + entrance_vars = block_var_detail[0]["backward"]["entrance"] + add_vars_by_var_list(entrance_vars, origin_program, program, heter_block) + exit_vars = block_var_detail[0]["backward"]["exit"] + add_vars_by_var_list(exit_vars, origin_program, program, heter_block) + return heter_block + + def replace_ops_by_communicate_op(program, config, heter_block_index, ops_list, block_var_detail): all_op = program.global_block().ops @@ -683,37 +991,44 @@ def replace_ops_by_communicate_op(program, config, heter_block_index, ops_list, assert first_op_idx != -1 delete_same_ops(program.global_block(), ops_list) - mode = config.get_distributed_mode() - heter_worker_endpoint = config.get_heter_worker_endpoints() - entrance_var = block_var_detail[heter_block_index]["entrance"] - exit_var = block_var_detail[heter_block_index]["exit"] + entrance_var = [] - comm_info = get_communicate_var_info(program, heter_block_index, - entrance_var, exit_var) + if heter_block_index == 1: + mode = config.get_distributed_mode() + next_heter_worker_endpoints = config.get_next_stage_trainers() - program.global_block()._insert_op( - index=first_op_idx, - type="send_and_recv", - inputs={"X": program.global_block().vars[entrance_var[0]]}, - outputs={"Out": program.global_block().vars[exit_var[0]]}, - attrs={ - "send_var_name": entrance_var, - "recv_var_name": exit_var, - "message_name": comm_info["block_input_var_name"], - "endpoints": heter_worker_endpoint, - "trainer_id": config.get_role_id(), - RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE - }) + entrance_var = block_var_detail[heter_block_index]["forward"][ + "entrance"] + + comm_info = get_communicate_var_info(program, heter_block_index + 1, + entrance_var) + program.global_block()._insert_op( + index=first_op_idx, + type="send_and_recv", + inputs={"X": program.global_block().vars[entrance_var[0]]}, + outputs={"Out": []}, + attrs={ + "mode": "forward", + "send_var_name": entrance_var + ["microbatch_id"], + "recv_var_name": [], + "message_name": comm_info["block_input_var_name"], + "next_endpoints": next_heter_worker_endpoints, + "previous_endpoints": [], + "trainer_id": config.get_role_id(), + RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE + }) - return entrance_var + exit_var + return entrance_var def remove_trainer_send_op(program, config, heter_block_index, - block_var_detaile): + block_var_detail): + # if trainer do FF->BP->SEND, it has follow vars: var, var@GRAD # if trainer only do SEND, it has one var: var@GRAD # Delete Send op ,if trainer doesn't has pair var (var<->var@GRAD) - persistables = block_var_detaile[heter_block_index]["persistables"] + persistables = block_var_detail[heter_block_index]["forward"]["persistables"] + \ + block_var_detail[heter_block_index]["backward"]["persistables"] need_remove_send_op = [] need_remove_grad_var = [] for op in find_send_op(program): @@ -749,7 +1064,7 @@ def add_heter_send_op(program, heter_program, block, block_var_detail): send_grad_var_list = [] send_op_dict = _get_send_op_dict() table_dict = {} - for persistable_var in block_var_detail["persistables"]: + for persistable_var in block_var_detail["backward"]["persistables"]: # check var_name == var@GRAD if "@GRAD" not in persistable_var: continue @@ -798,18 +1113,21 @@ def find_send_op(program): return send_op_list -def get_communicate_var_info(program, block_index, entrance_var_list, - exit_var_list): +def get_communicate_var_info(program, + block_index, + entrance_var_list, + type="forward"): input_var_reshape_dim = [] input_var_reshape_name = [] - block_input_var_name = "joint_{}_{}@Heter".format(block_index - 1, - block_index) - output_var_reshape_dim = [] - output_var_reshape_name = [] - block_output_var_name = "joint_{}_{}@Heter".format(block_index, - block_index + 1) + + if type == "forward": + block_input_var_name = "forward_joint_{}_{}@Heter".format( + block_index - 1, block_index) + else: + block_input_var_name = "backward_joint_{}_{}@Heter".format( + block_index + 1, block_index) + entrance_var_list.sort() - exit_var_list.sort() # input # Heter_SERVER_BLOCK_index@JOINT_VAR -> slice -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> reshape -> var for name in entrance_var_list: @@ -825,30 +1143,95 @@ def get_communicate_var_info(program, block_index, entrance_var_list, # output # var -> reshape -> var@Heter_SERVER_BLOCK@INPUT_RESHAPE_VAR -> concat -> Heter_SERVER_BLOCK_index@JOINT_VAR - for var_name in exit_var_list: - var = program.global_block().vars[var_name] - shape = var.shape - # if len(shape) < 2 or shape[0] != -1: - # raise ValueError( - # "Variable {} not support heter training. its shape is {}". - # format(var_name, shape)) - send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape) - output_var_reshape_dim.append(send_reshape_dim) - output_var_reshape_name.append("{}.output_reshape@Heter".format( - var_name)) + #for var_name in exit_var_list: + # var = program.global_block().vars[var_name] + # shape = var.shape + # # if len(shape) < 2 or shape[0] != -1: + # # raise ValueError( + # # "Variable {} not support heter training. its shape is {}". + # # format(var_name, shape)) + # send_reshape_dim = -1 * reduce(lambda x, y: x * y, shape) + # output_var_reshape_dim.append(send_reshape_dim) + # output_var_reshape_name.append("{}.output_reshape@Heter".format( + # var_name)) info = { "input_var_reshape_dim": input_var_reshape_dim, "input_var_reshape_name": input_var_reshape_name, "block_input_var_name": block_input_var_name, - "output_var_reshape_dim": output_var_reshape_dim, - "output_var_reshape_name": output_var_reshape_name, - "block_output_var_name": block_output_var_name + # "output_var_reshape_dim": output_var_reshape_dim, + # "output_var_reshape_name": output_var_reshape_name, + # "block_output_var_name": block_output_var_name } return info +def union_forward_gradient_op(program_block_ops_list): + """ + before analyzing the input & output of each block in program_block_list, we should + union the forward op and corresponding gradient op to elimincate the uneccessary variable + transmit + """ + """ + fix for 2emb model, re-place sum op + + """ + block_length = len(program_block_ops_list) + ''' + ## get the final part + final_part_idx = -1 + for i in range(block_length): + op_list = program_block_ops_list[i] + for op in op_list: + if "_grad" in op.type: + final_part_idx = i + break + if final_part_idx != -1: + break + + ## eliminate wrong partition because of sum op + ## lookup_table_v2_grad + ## every looup_table_v2_grad op block should follow a sum op + var2idx = {} + + for i in range(final_part_idx, block_length): + op_list = program_block_ops_list[i] + for j in range(len(op_list) - 1, -1, -1): + op = op_list[j] + #if op.type == "lookup_table_v2_grad": + # if j < len(op_list) - 1): + # else: + # ## get var and record place + if _grad in op.type: + forward_op_type = op.type.split("_grad")[0] + if forward_op_type in SPARSE_OP_TYPE_DICT.keys() \ + and op.attr('remote_prefetch') is True: + param_name = op.input(SPARSE_OP_TYPE_DICT[forward_op_type])[0] + + var2idx[] = [i,j] ## + + ''' + + union_program_block_ops_list = [] + assert block_length % 2 != 0, "the length of program_block_ops_list should be odd" + for i in range(0, block_length // 2): + block_op_list = {"forward": program_block_ops_list[i]} + block_op_list.update({ + "backward": program_block_ops_list[block_length - 1 - i] + }) + union_program_block_ops_list.append(block_op_list) + + block_op_list = {"forward": [], "backward": []} + for op in program_block_ops_list[block_length // 2]: + if not "_grad" in op.type and not (op.type == "sum"): + block_op_list["forward"].append(op) + else: + block_op_list["backward"].append(op) + union_program_block_ops_list.append(block_op_list) + return union_program_block_ops_list + + def find_block_joints(program, program_block_ops_list, heter_ops): block_var_detail = find_entrance_exit_private(program, program_block_ops_list) @@ -856,6 +1239,7 @@ def find_block_joints(program, program_block_ops_list, heter_ops): block_var_detail, heter_ops) block_var_detail = delete_block_useless_exit( program, program_block_ops_list, block_var_detail) + return block_var_detail @@ -863,8 +1247,9 @@ def find_entrance_exit_private(program, program_block_ops_list): block_var_detail = [] persistables = [] for index, block_op_list in enumerate(program_block_ops_list): - block_input, block_output = find_ops_list_input_output(program, - block_op_list) + ## forward + block_input, block_output = find_ops_list_input_output( + program, block_op_list["forward"]) persistables = screen_persistables( program, block_input) + screen_persistables(program, block_output) # find entrance & exit @@ -872,11 +1257,33 @@ def find_entrance_exit_private(program, program_block_ops_list): block_entrance = list(set(block_input) - set(block_private_vars)) block_exit = list(set(block_output) - set(block_private_vars)) detail = { - "entrance": block_entrance, - "exit": block_exit, - "private": block_private_vars, - "persistables": persistables + "forward": { + "entrance": block_entrance, + "exit": block_exit, + "private": block_private_vars, + "persistables": persistables + } } + + ## backward + bp_block_input, bp_block_output = find_ops_list_input_output( + program, block_op_list["backward"]) + bp_persistables = screen_persistables( + program, bp_block_input) + screen_persistables(program, + bp_block_output) + # find entrance & exit + bp_block_private_vars = list(set(bp_block_input) & set(bp_block_output)) + bp_block_entrance = list( + set(bp_block_input) - set(bp_block_private_vars)) + bp_block_exit = list(set(bp_block_output) - set(bp_block_private_vars)) + detail.update({ + "backward": { + "entrance": bp_block_entrance, + "exit": bp_block_exit, + "private": bp_block_private_vars, + "persistables": bp_persistables + } + }) block_var_detail.append(detail) return block_var_detail @@ -886,20 +1293,64 @@ def entrance_exit_check(program, program_block_ops_list, block_var_detail, for index in range(len(block_var_detail) - 1, -1, -1): if index - 1 < 0: break - previous_block_exit = block_var_detail[index - 1]["exit"] + previous_block_exit = block_var_detail[index - 1]["forward"]["exit"] previous_block_exit.sort() - current_block_entrance = block_var_detail[index]["entrance"] + current_block_entrance = block_var_detail[index]["forward"]["entrance"] + + backward_entrance = block_var_detail[index]["backward"]["entrance"] + + forward_all = block_var_detail[index]["forward"][ + "entrance"] + block_var_detail[index]["forward"][ + "private"] + block_var_detail[index]["forward"]["exit"] + + for var in backward_entrance: + if not ("@GRAD" in var) and not (var in forward_all): + current_block_entrance.append(var) + current_block_entrance.sort() + if previous_block_exit == current_block_entrance: continue exist_vars = list( set(previous_block_exit) & set(current_block_entrance)) need_add_vars = list(set(current_block_entrance) - set(exist_vars)) - need_add_vars = find_need_var_from_previous_block( - need_add_vars, block_var_detail, index, heter_ops) + # var in different stage should not be ignored, since they are not placed in the same program & device + #need_add_vars = find_need_var_from_previous_block( + # need_add_vars, block_var_detail, index, heter_ops) + + previous_block_private = block_var_detail[index - 1]["forward"][ + "private"] + previous_block_entrance = block_var_detail[index - 1]["forward"][ + "entrance"] + for var in need_add_vars: + if var not in previous_block_private and var not in previous_block_entrance: + previous_block_entrance.append(var) + previous_block_exit.append(var) + if not var in current_block_entrance: + current_block_entrance.append(var) - previous_block_private = block_var_detail[index - 1]["private"] - previous_block_entrance = block_var_detail[index - 1]["entrance"] + for index in range(0, len(block_var_detail) - 1, 1): + previous_block_exit = block_var_detail[index + 1]["backward"]["exit"] + previous_block_exit.sort() + current_block_entrance = block_var_detail[index]["backward"]["entrance"] + + current_block_entrance.sort() + + if previous_block_exit == current_block_entrance: + continue + exist_vars = list( + set(previous_block_exit) & set(current_block_entrance)) + need_add_vars = list(set(current_block_entrance) - set(exist_vars)) + need_ignore_vars = [] + for var in need_add_vars: + if not "@GRAD" in var: + need_ignore_vars.append(var) + need_add_vars = list( + set(need_add_vars).difference(set(need_ignore_vars))) + previous_block_private = block_var_detail[index + 1]["backward"][ + "private"] + previous_block_entrance = block_var_detail[index + 1]["backward"][ + "entrance"] for var in need_add_vars: if var not in previous_block_private and var not in previous_block_entrance: previous_block_entrance.append(var) @@ -915,7 +1366,8 @@ def find_need_var_from_previous_block(need_add_vars, block_var_detail, index_device_map[index] = DEFAULT_DEVICE for device in heter_ops: for index in heter_ops[device].keys(): - index_device_map[index] = device + if index < len(block_var_detail): + index_device_map[index] = device pre_index = current_index - 1 need_ignore_var = [] @@ -941,11 +1393,12 @@ def find_need_var_from_previous_block(need_add_vars, block_var_detail, def delete_block_useless_exit(program, program_block_ops_list, block_var_detail): + ## forward for index in range(len(block_var_detail)): if index == len(block_var_detail) - 1: break - current_block_exit = block_var_detail[index]["exit"] - next_block_entrance = block_var_detail[index + 1]["entrance"] + current_block_exit = block_var_detail[index]["forward"]["exit"] + next_block_entrance = block_var_detail[index + 1]["forward"]["entrance"] need_delete_var = [] for var in current_block_exit: if var not in next_block_entrance: @@ -953,6 +1406,19 @@ def delete_block_useless_exit(program, program_block_ops_list, for var in need_delete_var: current_block_exit.remove(var) + ## backward + for index in range(len(block_var_detail) - 1, -1, -1): + if index - 1 < 0: + break + current_block_exit = block_var_detail[index]["backward"]["exit"] + next_block_entrance = block_var_detail[index - 1]["backward"][ + "entrance"] + need_delete_var = [] + for var in current_block_exit: + if var not in next_block_entrance: + need_delete_var.append(var) + for var in need_delete_var: + current_block_exit.remove(var) return block_var_detail @@ -966,6 +1432,8 @@ def screen_persistables(program, var_list): need_remove = [] for var_name in var_list: if "@GRAD" in var_name: + if "GRAD" != var_name.split("@")[-1]: + continue origin_var_name = var_name.split("@GRAD")[0] var = program.global_block().vars[origin_var_name] else: @@ -1070,27 +1538,40 @@ def insert_recv_slice_op(program, block, index, var_name, var_shape, dtype, index += 1 -def deleter_trainer_useless_var(config, program, static_var): - if config.role_maker._is_first_worker(): - return [] +def add_heter_trainer_useful_vars(config, program, heter_program, heter_block, + static_var): static_var = list(set(static_var)) - porgram_useful_var_list = [] + for var_name in static_var: + if var_name not in heter_program.global_block( + ).vars and var_name not in heter_block.vars: + var = program.global_block().vars[var_name] + if var.persistable: + heter_program.global_block()._clone_variable( + var, force_persistable=False) + else: + heter_block._clone_variable(var, force_persistable=False) + + +def delete_trainer_useless_var(config, program, static_var): + static_var = list(set(static_var)) + program_useful_var_list = [] for op in program.global_block().ops: input_var_list, output_var_list = find_op_input_output( program, program.global_block(), op) op_var_list = list(set(input_var_list).union(set(output_var_list))) - porgram_useful_var_list = list( - set(porgram_useful_var_list).union(set(op_var_list))) - porgram_useful_var_list += static_var + program_useful_var_list = list( + set(program_useful_var_list).union(set(op_var_list))) + program_useful_var_list += static_var program_useless_var_list = list( set(get_vars_name_in_block(program.global_block())).difference( - set(porgram_useful_var_list))) + set(program_useful_var_list))) for var in program_useless_var_list: program.global_block()._remove_var(var) return program_useless_var_list def block_append_op(program, origin_program, block, op): + merge_ordereddict = origin_program.global_block().vars.copy() merge_ordereddict.update(block.vars) inputs = _get_input_map_from_op(merge_ordereddict, op) @@ -1144,7 +1625,8 @@ def block_append_op(program, origin_program, block, op): def add_vars_by_var_list(var_name_list, origin_program, program, block): for var_name in var_name_list: - if var_name not in program.global_block().vars: + if var_name not in program.global_block( + ).vars and var_name not in block.vars: var = origin_program.global_block().vars[var_name] if var.persistable: program.global_block()._clone_variable( diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_heter_pipeline_ctr.py similarity index 51% rename from python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py rename to python/paddle/fluid/tests/unittests/dist_fleet_heter_pipeline_ctr.py index 26b43f46ac6610c20fb9e5502cdbcd22755d2357..c6c2537b42c18ad37805f41a27897459a4a529c7 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_heter_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_heter_pipeline_ctr.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ fluid.default_startup_program().random_seed = 1 fluid.default_main_program().random_seed = 1 -class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): +class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase): """ For test CTR model, using Fleet api """ @@ -54,58 +54,53 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): """ dnn_input_dim, lr_input_dim = int(1e5), int(1e5) - dnn_data = fluid.layers.data( - name="dnn_data", - shape=[-1, 1], - dtype="int64", - lod_level=1, - append_batch_size=False) - lr_data = fluid.layers.data( - name="lr_data", - shape=[-1, 1], - dtype="int64", - lod_level=1, - append_batch_size=False) - label = fluid.layers.data( - name="click", - shape=[-1, 1], - dtype="float32", - lod_level=0, - append_batch_size=False) - - datas = [dnn_data, lr_data, label] - - if args.reader == "pyreader": - self.reader = fluid.io.PyReader( - feed_list=datas, - capacity=64, - iterable=False, - use_double_buffer=False) - - # build dnn model - dnn_layer_dims = [128, 64, 32, 1] - dnn_embedding = fluid.layers.embedding( - is_distributed=False, - input=dnn_data, - size=[dnn_input_dim, dnn_layer_dims[0]], - param_attr=fluid.ParamAttr( - name="deep_embedding", - initializer=fluid.initializer.Constant(value=0.01)), - is_sparse=True) - dnn_pool = fluid.layers.sequence_pool( - input=dnn_embedding, pool_type="sum") - dnn_out = dnn_pool - - # build lr model - lr_embbding = fluid.layers.embedding( - is_distributed=False, - input=lr_data, - size=[lr_input_dim, 1], - param_attr=fluid.ParamAttr( - name="wide_embedding", - initializer=fluid.initializer.Constant(value=0.01)), - is_sparse=True) - lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum") + with fluid.device_guard("cpu"): + dnn_data = fluid.layers.data( + name="dnn_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + lr_data = fluid.layers.data( + name="lr_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + label = fluid.layers.data( + name="click", + shape=[-1, 1], + dtype="float32", + lod_level=0, + append_batch_size=False) + + datas = [dnn_data, lr_data, label] + + # build dnn model + dnn_layer_dims = [128, 64, 32, 1] + dnn_embedding = fluid.layers.embedding( + is_distributed=False, + input=dnn_data, + size=[dnn_input_dim, dnn_layer_dims[0]], + param_attr=fluid.ParamAttr( + name="deep_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=True) + dnn_pool = fluid.layers.sequence_pool( + input=dnn_embedding, pool_type="sum") + dnn_out = dnn_pool + + # build lr model + lr_embbding = fluid.layers.embedding( + is_distributed=False, + input=lr_data, + size=[lr_input_dim, 1], + param_attr=fluid.ParamAttr( + name="wide_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=True) + lr_pool = fluid.layers.sequence_pool( + input=lr_embbding, pool_type="sum") with fluid.device_guard("gpu"): for i, dim in enumerate(dnn_layer_dims[1:]): @@ -118,6 +113,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): name='dnn-fc-%d' % i) dnn_out = fc + with fluid.device_guard("cpu"): merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1) label = fluid.layers.cast(label, dtype="int64") predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax') @@ -143,59 +139,33 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): with open(os.path.join(dirname, "__model__.proto"), "w") as wn: wn.write(str(program)) - def do_pyreader_training(self, fleet): - """ - do training using dataset, using fetch handler to catch variable - Args: - fleet(Fleet api): the fleet object of Parameter Server, define distribute training role - """ - - exe = fluid.Executor(fluid.CPUPlace()) - exe.run(fluid.default_startup_program()) - fleet.init_worker() - - batch_size = 4 - train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) - self.reader.decorate_sample_list_generator(train_reader) - - for epoch_id in range(1): - self.reader.start() - try: - pass_start = time.time() - while True: - exe.run(program=fluid.default_main_program()) - - pass_time = time.time() - pass_start - except fluid.core.EOFException: - self.reader.reset() - - if fleet.is_first_worker(): - model_path = tempfile.mkdtemp() - fleet.save_persistables(executor=exe, dirname=model_path) - shutil.rmtree(model_path) - def do_dataset_training(self, fleet): + train_file_list = ctr_dataset_reader.prepare_fake_data() exe = fluid.Executor(fluid.CPUPlace()) + real_program = fluid.default_main_program()._heter_pipeline_opt[ + "section_program"] + print(real_program) exe.run(fluid.default_startup_program()) fleet.init_worker() thread_num = int(os.getenv("CPU_NUM", 2)) batch_size = 128 + filelist = fleet.util.get_file_shard(train_file_list) print("filelist: {}".format(filelist)) # config dataset - dataset = paddle.distributed.QueueDataset() - dataset._set_batch_size(batch_size) - dataset._set_use_var(self.feeds) - pipe_command = 'python ctr_dataset_reader.py' - dataset._set_pipe_command(pipe_command) + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_batch_size(batch_size) + dataset.set_use_var(self.feeds) + pipe_command = 'python3 ctr_dataset_reader.py' + dataset.set_pipe_command(pipe_command) dataset.set_filelist(filelist) - dataset._set_thread(thread_num) + dataset.set_thread(thread_num) for epoch_id in range(1): pass_start = time.time() @@ -209,7 +179,44 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): debug=int(os.getenv("Debug", "0"))) pass_time = time.time() - pass_start print("do_dataset_training done. using time {}".format(pass_time)) + exe.close() + + def do_dataset_heter_training(self, fleet): + + exe = fluid.Executor() + exe.run(fluid.default_startup_program()) + fleet.init_worker() + real_program = fluid.default_main_program()._heter_pipeline_opt[ + "section_program"] + print(real_program) + + thread_num = int(os.getenv("CPU_NUM", 2)) + batch_size = 128 + + pass_start = time.time() + exe.train_from_dataset( + program=fluid.default_main_program(), + fetch_list=[self.avg_cost], + fetch_info=["cost"], + print_period=2, + debug=int(os.getenv("Debug", "0"))) + exe.close() + pass_time = time.time() - pass_start + print("do_dataset_heter_training done. using time {}".format(pass_time)) + + #for epoch_id in range(1): + # pass_start = time.time() + # dataset.set_filelist(filelist) + # exe.train_from_dataset( + # program=fluid.default_main_program(), + # dataset=dataset, + # fetch_list=[self.avg_cost], + # fetch_info=["cost"], + # print_period=2, + # debug=int(os.getenv("Debug", "0"))) + # pass_time = time.time() - pass_start + # print("do_dataset_heter_training done. using time {}".format(pass_time)) if __name__ == "__main__": - runtime_main(TestHeterPsCTR2x2) + runtime_main(TestHeterPipelinePsCTR2x2) diff --git a/python/paddle/fluid/tests/unittests/fleet_heter_ps_training.py b/python/paddle/fluid/tests/unittests/fleet_heter_ps_training.py new file mode 100644 index 0000000000000000000000000000000000000000..2b4ae3d60dd78a4fef8a72be6f7bd6bf22b959f7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/fleet_heter_ps_training.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.fluid as fluid +import paddle.distributed.fleet as fleet +import paddle.distributed.fleet.base.role_maker as role_maker + +fluid.disable_dygraph() + + +def get_dataset(inputs): + dataset = fluid.DatasetFactory().create_dataset() + dataset.set_use_var(inputs) + dataset.set_batch_size(1) + dataset.set_filelist([]) + dataset.set_thread(1) + return dataset + + +def net(batch_size=4, lr=0.01): + """ + network definition + + Args: + batch_size(int): the size of mini-batch for training + lr(float): learning rate of training + Returns: + avg_cost: LoDTensor of cost. + """ + dnn_input_dim, lr_input_dim = int(2), int(2) + + with fluid.device_guard("cpu"): + dnn_data = fluid.layers.data( + name="dnn_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + lr_data = fluid.layers.data( + name="lr_data", + shape=[-1, 1], + dtype="int64", + lod_level=1, + append_batch_size=False) + label = fluid.layers.data( + name="click", + shape=[-1, 1], + dtype="float32", + lod_level=0, + append_batch_size=False) + + datas = [dnn_data, lr_data, label] + + # build dnn model + dnn_layer_dims = [2, 1] + dnn_embedding = fluid.layers.embedding( + is_distributed=False, + input=dnn_data, + size=[dnn_input_dim, dnn_layer_dims[0]], + param_attr=fluid.ParamAttr( + name="deep_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=True) + dnn_pool = fluid.layers.sequence_pool( + input=dnn_embedding, pool_type="sum") + dnn_out = dnn_pool + + # build lr model + lr_embbding = fluid.layers.embedding( + is_distributed=False, + input=lr_data, + size=[lr_input_dim, 1], + param_attr=fluid.ParamAttr( + name="wide_embedding", + initializer=fluid.initializer.Constant(value=0.01)), + is_sparse=True) + lr_pool = fluid.layers.sequence_pool(input=lr_embbding, pool_type="sum") + + with fluid.device_guard("gpu"): + for i, dim in enumerate(dnn_layer_dims[1:]): + fc = fluid.layers.fc( + input=dnn_out, + size=dim, + act="relu", + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(value=0.01)), + name='dnn-fc-%d' % i) + dnn_out = fc + + merge_layer = fluid.layers.concat(input=[dnn_out, lr_pool], axis=1) + label = fluid.layers.cast(label, dtype="int64") + predict = fluid.layers.fc(input=merge_layer, size=2, act='softmax') + + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(x=cost) + return datas, avg_cost + + +''' +optimizer = fluid.optimizer.Adam(learning_rate=0.01) + +role = role_maker.PaddleCloudRoleMaker() +fleet.init(role) + +strategy = paddle.distributed.fleet.DistributedStrategy() +strategy.a_sync = True +strategy.a_sync_configs = {"heter_worker_device_guard": 'gpu'} + +strategy.pipeline = True +strategy.pipeline_configs = {"accumulate_steps": 1, "micro_batch_size": 2048} +feeds, avg_cost = net() +optimizer = fleet.distributed_optimizer(optimizer, strategy) +optimizer.minimize(avg_cost) +dataset = get_dataset(feeds) +''' + +if fleet.is_server(): + pass + #fleet.init_server() + #fleet.run_server() +elif fleet.is_heter_worker(): + pass + #fleet.init_heter_worker() + #fleet.run_heter_worker(dataset=dataset) + fleet.stop_worker() +elif fleet.is_worker(): + pass + #place = fluid.CPUPlace() + #exe = fluid.Executor(place) + #exe.run(fluid.default_startup_program()) + #fleet.init_worker() + #step = 1 + #for i in range(step): + # exe.train_from_dataset( + # program=fluid.default_main_program(), dataset=dataset, debug=False) + #exe.close() + #fleet.stop_worker() diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py index b77cfb095f063c976f77b7730f8ed29e7ee6bb4f..6111d40c7d640b2c8bf89f2c7bdb37ff03d41d9c 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_base.py @@ -52,27 +52,74 @@ class FleetDistHeterRunnerBase(object): def build_role(self, args): environs = {} - environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints - environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints - environs[ - "PADDLE_HETER_TRAINER_IP_PORT_LIST"] = args.heter_trainer_endpoints - environs["PADDLE_HETER_TRAINER_DEVICE"] = args.heter_trainer_device - environs["TRAINING_ROLE"] = args.role.upper() - environs["PADDLE_TRAINERS_NUM"] = args.trainers - environs["PADDLE_TRAINER_ID"] = args.current_id + heter_trainer_endpoints = args.heter_trainer_endpoints.split(";") + all_heter_trainer_endpoints = ",".join(heter_trainer_endpoints) if args.role.upper() == "PSERVER": + environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints + environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints + environs[ + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = all_heter_trainer_endpoints environs["POD_IP"] = args.endpoints.split(",")[int( args.current_id)].split(":")[0] environs["PADDLE_PORT"] = args.endpoints.split(",")[int( args.current_id)].split(":")[1] + environs["TRAINING_ROLE"] = args.role.upper() + environs["PADDLE_TRAINERS_NUM"] = args.trainers elif args.role.upper() == "HETER_TRAINER": - environs["POD_IP"] = args.heter_trainer_endpoints.split(",")[int( + previous_endpoints = args.trainer_endpoints if args.stage_id == 2 else heter_trainer_endpoints[ + 0] + next_endpoints = heter_trainer_endpoints[ + 1] if args.stage_id == 2 else "" + heter_device = args.heter_trainer_device.split(";")[args.stage_id - + 2] + environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints + environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints + environs["PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST"] = next_endpoints + environs[ + "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"] = previous_endpoints + environs[ + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = all_heter_trainer_endpoints + environs["HETER_DEVICE_TYPE"] = heter_device + environs["TRAINING_ROLE"] = args.role.upper() + environs["POD_IP"] = all_heter_trainer_endpoints.split(",")[int( args.current_id)].split(":")[0] - environs["PADDLE_PORT"] = args.heter_trainer_endpoints.split(",")[ + environs["PADDLE_PORT"] = all_heter_trainer_endpoints.split(",")[ int(args.current_id)].split(":")[1] - environs["FLAGS_selected_gpus"] = args.current_id + environs["PADDLE_TRAINERS_NUM"] = args.trainers + environs["PADDLE_STAGE_TRAINERS_NUM"] = [2, 2, 2] + environs["FLAGS_selected_gpus"] = 0 + environs["FLAGS_selected_xpus"] = 0 + environs["CUDA_VISIBLE_DEVICES"] = 0 + environs["XPU_VISIBLE_DEVICES"] = 0 + environs["STAGE_ID"] = args.stage_id + environs["STAGE_NUM"] = 3 + elif args.role.upper() == "TRAINER": + environs["PADDLE_PSERVERS_IP_PORT_LIST"] = args.endpoints + environs["PADDLE_TRAINER_ENDPOINTS"] = args.trainer_endpoints + environs[ + "PADDLE_NEXT_HETER_TRAINER_IP_PORT_LIST"] = heter_trainer_endpoints[ + 0] + environs["PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"] = "" + environs[ + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = all_heter_trainer_endpoints + environs["HETER_DEVICE_TYPE"] = "cpu" + environs["TRAINING_ROLE"] = args.role.upper() + environs["PADDLE_TRAINER_ID"] = args.current_id + environs["POD_IP"] = args.trainer_endpoints.split(",")[int( + args.current_id)].split(":")[0] + environs["PADDLE_PORT"] = args.trainer_endpoints.split(",")[int( + args.current_id)].split(":")[1] + environs["PADDLE_TRAINERS_NUM"] = args.trainers + environs["PADDLE_STAGE_TRAINERS_NUM"] = [2, 2, 2] + environs["FLAGS_selected_gpus"] = 0 + environs["FLAGS_selected_xpus"] = 0 + environs["CUDA_VISIBLE_DEVICES"] = 0 + environs["XPU_VISIBLE_DEVICES"] = 0 + environs["STAGE_ID"] = 1 + environs["STAGE_NUM"] = 3 for k, v in environs.items(): + print(k, v) os.environ[k] = str(v) self.role = role_maker.PaddleCloudRoleMaker() @@ -85,6 +132,11 @@ class FleetDistHeterRunnerBase(object): "launch_barrier": True, "heter_worker_device_guard": 'gpu' } + self.strategy.pipeline = True + self.strategy.pipeline_configs = { + "accumulate_steps": 1, + "micro_batch_size": 2048 + } return self.strategy def build_optimizer(self, avg_cost, strategy): @@ -96,12 +148,12 @@ class FleetDistHeterRunnerBase(object): fleet.init_server() fleet.run_server() + def run_dataset_heter_trainer(self, args): + out = self.do_dataset_heter_training(fleet) + def run_dataset_trainer(self, args): out = self.do_dataset_training(fleet) - def run_pyreader_trainer(self, args): - out = self.do_pyreader_training(fleet) - def net(self, args, batch_size=4, lr=0.01): raise NotImplementedError( "get_model should be implemented by child classes.") @@ -110,9 +162,9 @@ class FleetDistHeterRunnerBase(object): raise NotImplementedError( "do_dataset_training should be implemented by child classes.") - def do_pyreader_training(self, fleet): + def do_dataset_heter_training(self, fleet): raise NotImplementedError( - "do_pyreader_training should be implemented by child classes.") + "do_dataset_heter_training should be implemented by child classes.") class TestFleetHeterBase(unittest.TestCase): @@ -132,12 +184,12 @@ class TestFleetHeterBase(unittest.TestCase): self.startTime = time.time() self._mode = "async" - self._reader = "pyreader" + self._reader = "dataset" self._trainers = 2 self._pservers = 2 self._port_set = set() - self._heter_device = "gpu" + self._heter_device = "gpu;cpu" global DIST_UT_PORT if DIST_UT_PORT == 0 and os.getenv("PADDLE_DIST_UT_PORT"): @@ -151,7 +203,9 @@ class TestFleetHeterBase(unittest.TestCase): DIST_UT_PORT + 2, DIST_UT_PORT + 3) self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( DIST_UT_PORT + 4, DIST_UT_PORT + 5) - DIST_UT_PORT += 6 + self._heter_endpoints_2 = "127.0.0.1:%s,127.0.0.1:%s" % ( + DIST_UT_PORT + 6, DIST_UT_PORT + 7) + DIST_UT_PORT += 8 else: self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._find_free_port(), self._find_free_port()) @@ -159,6 +213,8 @@ class TestFleetHeterBase(unittest.TestCase): self._find_free_port(), self._find_free_port()) self._heter_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._find_free_port(), self._find_free_port()) + self._heter_endpoints_2 = "127.0.0.1:%s,127.0.0.1:%s" % ( + self._find_free_port(), self._find_free_port()) self._python_interp = sys.executable self._geo_sgd_need_push_nums = 5 @@ -219,12 +275,17 @@ class TestFleetHeterBase(unittest.TestCase): return tr0_proc, tr1_proc, tr0_pipe, tr1_pipe def _start_heter_trainer(self, cmd, required_envs): - heter0_cmd, heter1_cmd = cmd.format(0), cmd.format(1) + heter0_cmd, heter1_cmd, heter2_cmd, heter3_cmd = cmd.format( + 0, 2), cmd.format(1, 2), cmd.format(2, 3), cmd.format(3, 3) heter0_pipe = open(tempfile.gettempdir() + "/heter0_err.log", "wb+") heter1_pipe = open(tempfile.gettempdir() + "/heter1_err.log", "wb+") + heter2_pipe = open(tempfile.gettempdir() + "/heter2_err.log", "wb+") + heter3_pipe = open(tempfile.gettempdir() + "/heter3_err.log", "wb+") heter0_out = open(tempfile.gettempdir() + "/heter0_out.log", "wb+") heter1_out = open(tempfile.gettempdir() + "/heter1_out.log", "wb+") + heter2_out = open(tempfile.gettempdir() + "/heter2_out.log", "wb+") + heter3_out = open(tempfile.gettempdir() + "/heter3_out.log", "wb+") heter0_proc = subprocess.Popen( heter0_cmd.strip().split(" "), @@ -236,8 +297,18 @@ class TestFleetHeterBase(unittest.TestCase): stdout=heter1_out, stderr=heter1_pipe, env=required_envs) + heter2_proc = subprocess.Popen( + heter2_cmd.strip().split(" "), + stdout=heter2_out, + stderr=heter2_pipe, + env=required_envs) + heter3_proc = subprocess.Popen( + heter3_cmd.strip().split(" "), + stdout=heter3_out, + stderr=heter3_pipe, + env=required_envs) - return heter0_proc, heter1_proc, heter0_pipe, heter1_pipe + return heter0_proc, heter1_proc, heter2_proc, heter3_proc, heter0_pipe, heter1_pipe, heter2_pipe, heter3_pipe def _run_cluster(self, model, envs): env = { @@ -251,26 +322,31 @@ class TestFleetHeterBase(unittest.TestCase): envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') python_path += " -m coverage run --branch -p" env.update(envs) + self._all_heter_endpoints = ";".join( + (self._heter_endpoints, self._heter_endpoints_2)) tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format( python_path, model, self._ps_endpoints, self._tr_endpoints, self._trainers, self._mode, self._geo_sgd_need_push_nums, - self._reader, gloo_path, self._heter_endpoints, self._heter_device) + self._reader, gloo_path, self._all_heter_endpoints, + self._heter_device) ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format( python_path, model, self._ps_endpoints, self._tr_endpoints, self._trainers, self._mode, self._geo_sgd_need_push_nums, - self._reader, gloo_path, self._heter_endpoints, self._heter_device) + self._reader, gloo_path, self._all_heter_endpoints, + self._heter_device) - heter_cmd = "{0} {1} --role heter_trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format( + heter_cmd = "{0} {1} --role heter_trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --stage_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8} --heter_trainer_endpoints {9} --heter_trainer_device {10}".format( python_path, model, self._ps_endpoints, self._tr_endpoints, self._trainers, self._mode, self._geo_sgd_need_push_nums, - self._reader, gloo_path, self._heter_endpoints, self._heter_device) + self._reader, gloo_path, self._all_heter_endpoints, + self._heter_device) # Run dist train to compare with local results ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env) tr0, tr1, tr0_pipe, tr1_pipe = self._start_trainer(tr_cmd, env) - heter0, heter1, heter0_pipe, heter1_pipe = self._start_heter_trainer( + heter0, heter1, heter2, heter3, heter0_pipe, heter1_pipe, heter2_pipe, heter3_pipe = self._start_heter_trainer( heter_cmd, env) # Wait until trainer process terminate @@ -300,11 +376,15 @@ class TestFleetHeterBase(unittest.TestCase): ps1_pipe.close() heter0_pipe.close() heter1_pipe.close() + heter2_pipe.close() + heter3_pipe.close() ps0.terminate() ps1.terminate() heter0.terminate() heter1.terminate() + heter2.terminate() + heter3.terminate() self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check") self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check") shutil.rmtree(gloo_path) @@ -349,6 +429,7 @@ def runtime_main(test_class): parser.add_argument('--gloo_path', type=str, required=False, default="") parser.add_argument('--current_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) + parser.add_argument('--stage_id', type=int, required=False, default=1) parser.add_argument('--mode', type=str, required=False, default='async') parser.add_argument( '--geo_sgd_need_push_nums', type=int, required=False, default=2) @@ -362,11 +443,11 @@ def runtime_main(test_class): avg_cost = model.net(args) model.build_optimizer(avg_cost, strategy) - if args.role == "pserver" or args.role == "heter_trainer": + if args.role == "pserver": model.run_pserver(args) + elif args.role == "heter_trainer": + model.run_dataset_heter_trainer(args) + fleet.stop_worker() else: - if args.reader == "dataset": - model.run_dataset_trainer(args) - else: - model.run_pyreader_trainer(args) + model.run_dataset_trainer(args) fleet.stop_worker() diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py index 5f7d7b21d7ff8da8699c2f55adcde954c1c0156d..2ed331c62842407157a228f6f9e336b86463d91f 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_ctr.py @@ -23,10 +23,10 @@ import paddle paddle.enable_static() -class TestDistHeterPyreaderAsync2x2(TestFleetHeterBase): +class TestDistHeterDatasetAsync2x2(TestFleetHeterBase): def _setup_config(self): self._mode = "async" - self._reader = "pyreader" + self._reader = "dataset" def check_with_place(self, model_file, @@ -39,20 +39,22 @@ class TestDistHeterPyreaderAsync2x2(TestFleetHeterBase): "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""), "FLAGS_rpc_deadline": "5000", # 5sec to fail fast "http_proxy": "", - "CPU_NUM": "3" + "CPU_NUM": "2" } required_envs.update(need_envs) if check_error_log: - required_envs["GLOG_v"] = "3" + required_envs["GLOG_v"] = "4" required_envs["GLOG_logtostderr"] = "1" tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs) def test_dist_train(self): self.check_with_place( - "dist_fleet_heter_ctr.py", delta=1e-5, check_error_log=True) + "dist_fleet_heter_pipeline_ctr.py", + delta=1e-5, + check_error_log=True) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py index eed8d5f1a496ead4712cf792dec879612d167825..61f15e7dffff29e8ba9e7e7946bd4824ccf7f7ce 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_heter_program.py @@ -32,9 +32,15 @@ class TestDistFleetHeterProgram(unittest.TestCase): "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36012,127.0.0.1:36013" environs["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36014,127.0.0.1:36015" environs[ - "PADDLE_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36016,127.0.0.1:36017" + "PADDLE_ALL_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36016,127.0.0.1:36017" + environs[ + "PADDLE_PREVIOUS_HETER_TRAINER_IP_PORT_LIST"] = "127.0.0.1:36014,127.0.0.1:36015" environs["PADDLE_HETER_TRAINER_DEVICE"] = "gpu" environs["TRAINING_ROLE"] = "HETER_TRAINER" + environs["STAGE_ID"] = 2 + environs["STAGE_NUM"] = 2 + environs["HETER_DEVICE_TYPE"] = "gpu" + environs["PADDLE_STAGE_TRAINERS_NUM"] = [2, 2] environs["PADDLE_TRAINERS_NUM"] = 2 environs["PADDLE_TRAINER_ID"] = 0 environs["POD_IP"] = "127.0.0.1" diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py index 64b8744472d395821643b6bb3c51f559ce8779e5..88e5ea20446bbe8e9fceb4d5dc38f2f99cac0fa6 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_2.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_2.py @@ -23,6 +23,7 @@ import paddle.fluid as fluid class TestFleetBase(unittest.TestCase): def setUp(self): os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_PORT"] = "36000" os.environ["PADDLE_TRAINERS_NUM"] = "2" os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ "127.0.0.1:36001,127.0.0.2:36001" diff --git a/python/paddle/fluid/tests/unittests/test_fleet_launch_ps.sh b/python/paddle/fluid/tests/unittests/test_fleet_launch_ps.sh index 0f28be614c085e46f99d94124cb9755b73948117..bfbaf258c86b40dbc30e4737302a0a4edb72aadb 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_launch_ps.sh +++ b/python/paddle/fluid/tests/unittests/test_fleet_launch_ps.sh @@ -16,7 +16,7 @@ set -e -server_port_00=${PADDLE_DIST_UT_PORT} +server_port_00=$(( PADDLE_DIST_UT_PORT )) server_port_10=$(( PADDLE_DIST_UT_PORT + 1 )) worker_port_00=$(( PADDLE_DIST_UT_PORT + 2 )) worker_port_10=$(( PADDLE_DIST_UT_PORT + 3 )) @@ -30,12 +30,11 @@ heter_worker_port_0=$(( PADDLE_DIST_UT_PORT + 8 )) heter_worker_port_1=$(( PADDLE_DIST_UT_PORT + 9 )) function test_launch_ps(){ - python -m paddle.distributed.fleet.launch \ --servers="127.0.0.1:${server_port_00},127.0.0.1:${server_port_10}" \ --workers="127.0.0.1:${worker_port_00},127.0.0.1:${worker_port_10}" \ - fleet_ps_training.py 2> ut.elog - if grep -q "server are killed" ut.elog; then + fleet_ps_training.py 2> ut1.elog + if grep -q "server are killed" ut1.elog; then echo "test pserver launch succeed" else echo "test pserver launch failed" @@ -48,11 +47,13 @@ function test_launch_ps_heter(){ --servers="127.0.0.1:${server_port_01},127.0.0.1:${server_port_11}" \ --workers="127.0.0.1:${worker_port_01},127.0.0.1:${worker_port_11}" \ --heter_workers="127.0.0.1:${heter_worker_port_0},127.0.0.1:${heter_worker_port_1}" \ - fleet_ps_training.py 2> ut.elog - if grep -q "server are killed" ut.elog; then - echo "test heter pserver launch succeed" + --heter_devices="gpu" \ + --heter_worker_num="2" \ + fleet_heter_ps_training.py 2> ut2.elog + if grep -q "server are killed" ut2.elog; then + echo "test heter trainer launch succeed" else - echo "test pserver launch failed" + echo "test heter trainer launch failed" exit -1 fi } diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index 6152bce55ce9f23e3e8ba7fee1fd851c71bb592b..39320f5c0acf3b4bbad9c2e46c4789d8adc50504 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -17,7 +17,7 @@ import sys import os __all__ = [ 'TrainerDesc', 'MultiTrainer', 'DistMultiTrainer', 'PipelineTrainer', - 'HeterXpuTrainer' + 'HeterXpuTrainer', 'HeterPipelineTrainer' ] @@ -118,6 +118,13 @@ class TrainerDesc(object): def _set_program(self, program): self._program = program + def _set_trainer_id(self, trainer_id): + self.proto_desc.trainer_id = trainer_id + + def _set_trainers(self, trainers): + for trainer_num in trainers: + self.proto_desc.trainers.append(trainer_num) + def _set_use_cvm(self, use_cvm=False): self.proto_desc.use_cvm = use_cvm @@ -374,6 +381,30 @@ class PSGPUTrainer(TrainerDesc): self._device_worker._gen_worker_desc(self.proto_desc) +class HeterPipelineTrainer(TrainerDesc): + """ + Implement of HeterPipelineTrainer. + It's for HeterPS Pipeline training. + """ + + def __init__(self): + super(HeterPipelineTrainer, self).__init__() + pass + + def _set_program(self, program): + super(HeterPipelineTrainer, self)._set_program(program) + self._program = program + + def _gen_trainer_desc(self): + super(HeterPipelineTrainer, self)._gen_trainer_desc() + self.proto_desc.class_name = "HeterPipelineTrainer" + if self._program == None: + raise RuntimeError("None Program") + self._device_worker._set_infer(self._infer) + self._device_worker._set_program(self._program) + self._device_worker._gen_worker_desc(self.proto_desc) + + class PipelineTrainer(TrainerDesc): """ Implement of PipelineTrainer. diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index ed10bee2e063a7afb18a12adea604832163d26f7..1252676f844a70dfd242305ff54689706ccaf9c7 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -22,8 +22,8 @@ from paddle.fluid.log_helper import get_logger local_logger = get_logger( __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') -from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, PSGPUTrainer -from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT +from .trainer_desc import MultiTrainer, DistMultiTrainer, PipelineTrainer, HeterXpuTrainer, PSGPUTrainer, HeterPipelineTrainer +from .device_worker import Hogwild, DownpourSGD, Section, DownpourSGDOPT, HeterSection from .framework import Variable from multiprocessing import Process, Manager @@ -56,6 +56,10 @@ class TrainerFactory(object): # for debug tools if opt_info is not None: + if opt_info.get("trainers") is not None: + trainer._set_trainers(opt_info["trainers"]) + if opt_info.get("trainer_id") is not None: + trainer._set_trainer_id(opt_info["trainer_id"]) if opt_info.get("dump_slot") is not None: trainer._set_dump_slot(opt_info["dump_slot"]) if opt_info.get("mpi_rank") is not None: