diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h index 8acad1591d46c8196a081e022651a817f1ab92d9..76717a4674add8acef7a97f1dbef57b8adfd37a6 100644 --- a/paddle/fluid/distributed/service/heter_server.h +++ b/paddle/fluid/distributed/service/heter_server.h @@ -192,13 +192,24 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { virtual ~RequestSendAndRecvHandler() {} - // void SetMiniScopes(SharedMiniScope mini_scopes) { - // mini_scopes_ = mini_scopes; - // num_minibatch_ = mini_scopes_->size(); - //} + void SetMiniScopes(SharedMiniScope mini_scopes) { + mini_scopes_ = mini_scopes; + num_minibatch_ = mini_scopes_->size(); + } + void SetMicroScopes(SharedMicroScope micro_scopes) { micro_scopes_ = micro_scopes; - num_microbatch_ = micro_scopes_->size(); + for (auto& scope_pair : (*micro_scopes_)) { + // auto mini_idx = scope_pair.first; + auto& micro_scopes = scope_pair.second; + num_microbatch_ = micro_scopes->size(); + break; + } + } + + int GetThreadNum() { + std::unique_lock lk(scope_mutex_); + return (*task_queue_).size(); } void SetTaskQueue(SharedTaskQueue task_queue) { task_queue_ = task_queue; } @@ -235,25 +246,43 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { int minibatch_index = micro_id / 10; int microbatch_index = micro_id % 10; - // PADDLE_ENFORCE_EQ( - // (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1, - // platform::errors::InvalidArgument( - // "minibatch index should in current trainer")); - PADDLE_ENFORCE_EQ( - (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1, - platform::errors::InvalidArgument( - "minibatch index should in current trainer")); + // check minibatch_index is in mini_scopes_ + std::unique_lock lk(scope_mutex_); + if ((*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end()) { + lk.unlock(); + // PADDLE_ENFORCE_EQ( + // (*mini_scopes_).find(minibatch_index) != (*mini_scopes_).end(), 1, + // platform::errors::InvalidArgument( + // "minibatch index should in current trainer")); + PADDLE_ENFORCE_EQ( + (*micro_scopes_).find(minibatch_index) != (*micro_scopes_).end(), 1, + platform::errors::InvalidArgument( + "minibatch index should in current trainer")); + + } else { + // create mini scope & micro scopes + auto* minibatch_scope = &(scope_->NewScope()); + (*mini_scopes_)[minibatch_index] = minibatch_scope; + (*micro_scopes_)[minibatch_index].reset( + new std::vector{}); + for (int i = 0; i < num_microbatch_; i++) { + auto* micro_scope = &(minibatch_scope->NewScope()); + (*((*micro_scopes_)[minibatch_index])).push_back(micro_scope); + } + (*task_queue_)[minibatch_index].reset( + new ::paddle::framework::BlockingQueue< + std::pair>()); + lk.unlock(); + } auto* micro_scope = (*((*micro_scopes_)[minibatch_index]))[microbatch_index]; distributed::DeserializeFromMultiVarMsgAndIOBuf( *request, &request_io_buffer, *dev_ctx_, micro_scope); - // blocking queue handles multi thread (*task_queue_)[minibatch_index]->Push( std::make_pair(message_name, microbatch_index)); - auto response_var_nums = request->recv_var_names_size(); std::vector response_var_names(response_var_nums), empty_var_names{}; @@ -269,11 +298,12 @@ class RequestSendAndRecvHandler final : public HeterRequestHandler { private: // share with HeterPipelineTrainer - // SharedMiniScope mini_scopes_{nullptr}; + SharedMiniScope mini_scopes_{nullptr}; SharedMicroScope micro_scopes_{nullptr}; int num_microbatch_; int num_minibatch_; + std::mutex scope_mutex_; bool is_first_stage_ = false; bool is_last_stage_ = false; @@ -321,14 +351,16 @@ class HeterServer { request_handler_ = request_handler; } - // void SetMiniBatchScopes(SharedMiniScope mini_scopes) { - // request_handler_->SetMiniScopes(mini_scopes); - //} + void SetMiniBatchScopes(SharedMiniScope mini_scopes) { + request_handler_->SetMiniScopes(mini_scopes); + } void SetMicroBatchScopes(SharedMicroScope micro_scopes) { request_handler_->SetMicroScopes(micro_scopes); } + int GetThreadNum() { return request_handler_->GetThreadNum(); } + void SetTaskQueue(SharedTaskQueue task_queue) { request_handler_->SetTaskQueue(task_queue); } diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index f9b9f43cdae5c6cdecee305159ae104c754c7597..600d75db53c7e791b3bf84ca787c90022598f729 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -631,10 +631,17 @@ class HeterSectionWorker : public DeviceWorker { std::shared_ptr> GetMicrobatchScopes() { return microbatch_scopes_; } + void SetMicrobatchScopes( + std::shared_ptr> microbatch_scopes) { + microbatch_scopes_ = microbatch_scopes; + } using SHARED_THREAD_QUEUE = std::shared_ptr< ::paddle::framework::BlockingQueue>>; SHARED_THREAD_QUEUE GetThreadQueue() { return thread_queue_; } + void SetThreadQueue(SHARED_THREAD_QUEUE thread_queue) { + thread_queue_ = thread_queue; + } void CopyParameters(int microbatch_id, const ProgramDesc& program, const platform::Place& place); void SetMinibatchScope(Scope* scope) { minibatch_scope_ = scope; } diff --git a/paddle/fluid/framework/heter_pipeline_trainer.cc b/paddle/fluid/framework/heter_pipeline_trainer.cc index 559b4178f0f4b33bda16c18da46b8000177a65fc..cb939f38ff3d9678e09e5cae433317031a47d78f 100644 --- a/paddle/fluid/framework/heter_pipeline_trainer.cc +++ b/paddle/fluid/framework/heter_pipeline_trainer.cc @@ -77,34 +77,51 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc, trainers_.push_back(trainer_num); } int cpu_trainer_num = trainers_[0]; - int cur_stage_trainer_num = trainers_[pipeline_stage_]; - int global_thread_num = cpu_trainer_num * thread_num_; - int previous_trainers = 0; - for (int i = 0; i < pipeline_stage_; i++) previous_trainers += trainers_[i]; - int stage_trainer_id = - trainer_id_ - previous_trainers; // trainer id in current stage - int cnt = -1; - for (int i = stage_trainer_id; i < global_thread_num; - i += cur_stage_trainer_num) { - cnt++; - workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( + // int cur_stage_trainer_num = trainers_[pipeline_stage_]; + // int global_thread_num = cpu_trainer_num * thread_num_; + // int previous_trainers = 0; + // for (int i = 0; i < pipeline_stage_; i++) previous_trainers += + // trainers_[i]; + // int stage_trainer_id = + // trainer_id_ - previous_trainers; // trainer id in current stage + + if (pipeline_stage_ == 0) { // for cpu trainer + int cnt = -1; + int real_thread_id = trainer_id_; + for (int i = 0; i < thread_num_; i++) { + cnt++; + workers_[real_thread_id] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc.device_worker_name()); + auto this_worker = + std::dynamic_pointer_cast( + workers_[real_thread_id]); + this_worker->SetDebug(debug_); + this_worker->SetNeedDumpField(need_dump_field_); + this_worker->SetNeedDumpParam(need_dump_param_); + this_worker->SetDumpFieldVector(dump_fields_); + this_worker->SetDumpParamVector(dump_param_); + this_worker->InitRandomDumpConfig(trainer_desc); + this_worker->SetDeviceIndex(real_thread_id); + real_thread_id += cpu_trainer_num; + // if (pipeline_stage_ == 0) { + this_worker->SetDataFeed(readers[cnt]); + //} + this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); + } + } else { // for heter_trainer + // heter trainer with thread_id == -1 is not for + // real training + workers_[-1] = DeviceWorkerFactory::CreateDeviceWorker( trainer_desc.device_worker_name()); auto this_worker = std::dynamic_pointer_cast( - workers_[i]); - this_worker->SetDebug(debug_); - this_worker->SetNeedDumpField(need_dump_field_); - this_worker->SetNeedDumpParam(need_dump_param_); - this_worker->SetDumpFieldVector(dump_fields_); - this_worker->SetDumpParamVector(dump_param_); - this_worker->InitRandomDumpConfig(trainer_desc); - this_worker->SetDeviceIndex(i); - if (pipeline_stage_ == 0) { - this_worker->SetDataFeed(readers[cnt]); - } + workers_[-1]); this_worker->SetMicrobatchNum(num_microbatches_); this_worker->SetPipelineStageNum(num_pipeline_stages_); this_worker->SetPipelineStage(pipeline_stage_); + this_worker->SetDeviceIndex(-1); } } @@ -177,7 +194,7 @@ void HeterPipelineTrainer::Run() { } auto heter_server = paddle::distributed::HeterServer::GetInstance(); heter_server->WaitServerReady(); - // heter_server->SetMiniBatchScopes(mini_scopes_); + heter_server->SetMiniBatchScopes(mini_scopes_); heter_server->SetMicroBatchScopes(micro_scopes_); heter_server->SetTaskQueue(task_queue_); // main training logic @@ -193,6 +210,7 @@ void HeterPipelineTrainer::Run() { } } } else { // for heter worker + // start thread_worker with thread_id = -1 for (auto& worker_pair : workers_) { auto device_worker = worker_pair.second; if (!debug_) { @@ -203,6 +221,60 @@ void HeterPipelineTrainer::Run() { device_worker.get())); } } + bool epoch_finish = false; + auto heter_server = paddle::distributed::HeterServer::GetInstance(); + while (!epoch_finish) { + if (heter_server->IsStop()) { + epoch_finish = true; + continue; + } + // create new thread_worker + // size_t thread_num = (*micro_scopes_).size(); + // size_t thread_num = (*task_queue_).size(); + size_t thread_num = heter_server->GetThreadNum(); + while (thread_num > threads_.size()) { + for (auto& worker_pair : (*micro_scopes_)) { + auto worker_index = worker_pair.first; + if (workers_.find(worker_index) != workers_.end()) continue; + workers_[worker_index] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc_.device_worker_name()); + auto this_worker = + std::dynamic_pointer_cast( + workers_[worker_index]); + this_worker->SetDebug(debug_); + this_worker->SetNeedDumpField(need_dump_field_); + this_worker->SetNeedDumpParam(need_dump_param_); + this_worker->SetDumpFieldVector(dump_fields_); + this_worker->SetDumpParamVector(dump_param_); + this_worker->InitRandomDumpConfig(trainer_desc_); + this_worker->SetDeviceIndex(worker_index); + this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); + this_worker->SetPlace(place_); + this_worker->Initialize(trainer_desc_); + this_worker->SetRootScope(root_scope_); + + // generate mini_batch scope for every worker + // auto* minibatch_scope = &root_scope_->NewScope(); + auto* minibatch_scope = (*mini_scopes_)[worker_index]; + // (*mini_scopes_)[worker_index] = minibatch_scope; + this_worker->SetMinibatchScope(minibatch_scope); + // after set micro num & mini batch scope + this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]); + this_worker->CreateMicrobatchScopes(); + // this_worker->SetMicrobatchScopes((*micro_scopes_)[worker_index]); + this_worker->SetThreadQueue((*task_queue_)[worker_index]); + if (!debug_) { + threads_.push_back( + std::thread(&DeviceWorker::TrainFiles, this_worker.get())); + } else { + threads_.push_back(std::thread( + &DeviceWorker::TrainFilesWithProfiler, this_worker.get())); + } + } + } + } } for (auto& th : threads_) { th.join(); @@ -228,7 +300,11 @@ void HeterPipelineTrainer::Finalize() { } Scope* HeterPipelineTrainer::GetWorkerScope(int thread_id) { - return workers_[thread_id]->GetThreadScope(); + if (workers_.find(thread_id) != workers_.end()) { + return workers_[thread_id]->GetThreadScope(); + } else { + return nullptr; + } } } // end namespace framework diff --git a/paddle/fluid/framework/heter_section_worker.cc b/paddle/fluid/framework/heter_section_worker.cc index 0249782cd5b3cd60bfc50a170534bcb0d6bbb72e..ace6ac49255c85c7d370ef80db36f75b805eb379 100644 --- a/paddle/fluid/framework/heter_section_worker.cc +++ b/paddle/fluid/framework/heter_section_worker.cc @@ -218,15 +218,21 @@ void HeterSectionWorker::CreateMicrobatchScopes() { minibatch_scope_, platform::errors::InvalidArgument( "minibatch_scope_ can not be nullptr when create MicroBatch Scopes")); - microbatch_scopes_.reset(new std::vector{}); - (*microbatch_scopes_).resize(num_microbatches_); - VLOG(3) << "Create microbatch scopes..."; - std::shared_ptr program; - program.reset(new ProgramDesc( - trainer_desc_.heter_section_param().section_config().program_desc())); - for (int j = 0; j < num_microbatches_; ++j) { - (*microbatch_scopes_)[j] = &minibatch_scope_->NewScope(); - CopyParameters(j, *program, place_); + if (microbatch_scopes_.get() == nullptr) { + microbatch_scopes_.reset(new std::vector{}); + (*microbatch_scopes_).resize(num_microbatches_); + VLOG(3) << "Create microbatch scopes..."; + for (int j = 0; j < num_microbatches_; ++j) { + (*microbatch_scopes_)[j] = &minibatch_scope_->NewScope(); + } + } + if (thread_id_ >= 0) { + std::shared_ptr program; + program.reset(new ProgramDesc( + trainer_desc_.heter_section_param().section_config().program_desc())); + for (int j = 0; j < num_microbatches_; ++j) { + CopyParameters(j, *program, place_); + } } } @@ -258,6 +264,8 @@ void HeterSectionWorker::CopyParameters(int microbatch_id, VLOG(5) << "Create persistable var: " << var->Name() << ", which pointer is " << ptr; } else if (!var->Persistable()) { + if ((*microbatch_scopes_)[microbatch_id]->FindVar(var->Name()) != nullptr) + continue; auto* ptr = (*microbatch_scopes_)[microbatch_id]->Var(var->Name()); VLOG(5) << "Create variable " << var->Name() << " for microbatch " << microbatch_id << ", which pointer is " << ptr; @@ -359,22 +367,25 @@ void HeterSectionWorker::BatchPostProcess() { } void HeterSectionWorker::TrainFiles() { - total_ins_num_ = 0; - batch_num_ = 0; - platform::SetNumThreads(1); - timeline_.Start(); - VLOG(3) << "begin section_worker TrainFiles"; - epoch_finish_ = false; - if (pipeline_stage_ == 0) { - device_reader_->Start(); - } - while (!epoch_finish_) { - Run(); - dev_ctx_->Wait(); + if (thread_id_ >= 0) { + total_ins_num_ = 0; + batch_num_ = 0; + platform::SetNumThreads(1); + timeline_.Start(); + VLOG(3) << "begin section_worker TrainFiles"; + epoch_finish_ = false; + if (pipeline_stage_ == 0) { + device_reader_->Start(); + } + while (!epoch_finish_) { + Run(); + dev_ctx_->Wait(); + } + timeline_.Pause(); + VLOG(3) << "worker " << thread_id_ << " train cost " + << timeline_.ElapsedSec() + << " seconds, ins_num: " << total_ins_num_; } - timeline_.Pause(); - VLOG(3) << "worker " << thread_id_ << " train cost " << timeline_.ElapsedSec() - << " seconds, ins_num: " << total_ins_num_; } void HeterSectionWorker::PrintFetchVars() { @@ -406,22 +417,24 @@ void HeterSectionWorker::PrintFetchVars() { } void HeterSectionWorker::TrainFilesWithProfiler() { - VLOG(3) << "begin section_worker TrainFilesWithProfiler"; - batch_num_ = 0; - epoch_finish_ = false; - total_ins_num_ = 0; - op_name_.clear(); - op_total_time_.clear(); - if (pipeline_stage_ == 0) { - device_reader_->Start(); - } - while (!epoch_finish_) { - Run(); - dev_ctx_->Wait(); - if (epoch_finish_) { - // dump param for debug - if (need_dump_field_ || need_dump_param_) { - writer_.Flush(); + if (thread_id_ >= 0) { + VLOG(3) << "begin section_worker TrainFilesWithProfiler"; + batch_num_ = 0; + epoch_finish_ = false; + total_ins_num_ = 0; + op_name_.clear(); + op_total_time_.clear(); + if (pipeline_stage_ == 0) { + device_reader_->Start(); + } + while (!epoch_finish_) { + Run(); + dev_ctx_->Wait(); + if (epoch_finish_) { + // dump param for debug + if (need_dump_field_ || need_dump_param_) { + writer_.Flush(); + } } } } diff --git a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc index 3284103af53662b48315fbaca25c93e750c1d5b1..c870e758e96afc1c70a26236b0d20ac05d77aaf1 100644 --- a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc +++ b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc @@ -161,13 +161,20 @@ TEST(HETER_LISTEN_AND_SERV, CPU) { b_rpc_service->WaitServerReady(); using MicroScope = std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); std::shared_ptr micro_scopes(new MicroScope{}); std::shared_ptr> micro_scope( new std::vector{}); - (*micro_scope).push_back(new framework::Scope()); - (*micro_scope).push_back(new framework::Scope()); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; + auto* micro_scope_0 = &(mini_scope->NewScope()); + auto* micro_scope_1 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); + (*micro_scope).push_back(micro_scope_1); (*micro_scopes)[0] = micro_scope; b_rpc_service->SetMicroBatchScopes(micro_scopes); + b_rpc_service->SetMiniBatchScopes(mini_scopes); using TaskQueue = std::unordered_mapWaitServerReady(); using MicroScope = std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); std::shared_ptr micro_scopes(new MicroScope{}); std::shared_ptr> micro_scope( new std::vector{}); - (*micro_scope).push_back(new framework::Scope()); - (*micro_scope).push_back(new framework::Scope()); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; + auto* micro_scope_0 = &(mini_scope->NewScope()); + auto* micro_scope_1 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); + (*micro_scope).push_back(micro_scope_1); (*micro_scopes)[0] = micro_scope; b_rpc_service->SetMicroBatchScopes(micro_scopes); + b_rpc_service->SetMiniBatchScopes(mini_scopes); using TaskQueue = std::unordered_mapWaitServerReady(); using MicroScope = std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); std::shared_ptr micro_scopes(new MicroScope{}); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; std::shared_ptr> micro_scope( new std::vector{}); - (*micro_scope).push_back(new framework::Scope()); + auto* micro_scope_0 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); (*micro_scopes)[0] = micro_scope; b_rpc_service->SetMicroBatchScopes(micro_scopes); + b_rpc_service->SetMiniBatchScopes(mini_scopes); using TaskQueue = std::unordered_mapWaitServerReady(); using MicroScope = std::unordered_map>>; + using MiniScope = std::unordered_map; + std::shared_ptr mini_scopes(new MiniScope{}); std::shared_ptr micro_scopes(new MicroScope{}); + auto* mini_scope = new framework::Scope(); + (*mini_scopes)[0] = mini_scope; std::shared_ptr> micro_scope( new std::vector{}); - (*micro_scope).push_back(new framework::Scope()); + auto* micro_scope_0 = &(mini_scope->NewScope()); + (*micro_scope).push_back(micro_scope_0); (*micro_scopes)[0] = micro_scope; b_rpc_service2->SetMicroBatchScopes(micro_scopes); + b_rpc_service2->SetMiniBatchScopes(mini_scopes); using TaskQueue = std::unordered_map