diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 2b4751691bbdd33d204c2b41a4c37a24b6aef37c..d6d53a8858030734812587f6bbd03a108c5cf8ce 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -51,10 +51,6 @@ bool CheckValidOutput(LoDTensor* tensor, size_t batch_size); class FleetWrapper; -#define SEC_LOG \ - VLOG(3) << "[s" << section_id_ << "p" << pipeline_id_ << "t" << thread_id_ \ - << "]: " - class PullDenseWorker { public: virtual ~PullDenseWorker() {} @@ -311,40 +307,9 @@ class DownpourWorkerOpt : public DownpourWorker { }; #if defined(PADDLE_WITH_NCCL) -using ScopeQueue = operators::reader::BlockingQueue; - -class SyncFunctor { - public: - SyncFunctor(int rank_id, int rank_num, int sync_steps); - virtual ~SyncFunctor() {} - - void SetSyncParam(const std::vector& sync_param) { - sync_param_ = &sync_param; - } - void SetNcclCtxMap(platform::NCCLContextMap* nccl_ctx_map) { - nccl_ctx_map_ = nccl_ctx_map; - } - - int operator()(Scope* scope); - static std::vector pipeline_scopes_; - static uint64_t sync_flag_; - - protected: - const int rank_id_; - const int rank_num_; - const std::vector* sync_param_ = nullptr; - platform::NCCLContextMap* nccl_ctx_map_ = nullptr; - - uint64_t sync_signal_; - const int sync_steps_; - int counter_; - - void Synchronize(); -}; - class SectionWorker : public DeviceWorker { public: - SectionWorker() {} + SectionWorker() { local_batch_id_ = 0; } ~SectionWorker() override {} void Initialize(const TrainerDesc& desc) override; @@ -360,50 +325,39 @@ class SectionWorker : public DeviceWorker { const platform::Place& place() const { return place_; } void SetSectionIndex(int section_id) { section_id_ = section_id; } - void SetDeviceIndex(int tid) override { pipeline_id_ = tid; } + void SetDeviceIndex(int tid) override {} void SetThreadIndex(int thread_id) { thread_id_ = thread_id; } - void SetVarNames(const std::vector& in_var_names, - const std::vector& out_var_names) { - in_var_names_ = &in_var_names; - out_var_names_ = &out_var_names; - } - void SetScopeQueue(ScopeQueue* in_scope_queue, ScopeQueue* out_scope_queue) { - in_scope_queue_ = in_scope_queue; - out_scope_queue_ = out_scope_queue; + void SetMicrobatchNum(int num) { num_microbatches_ = num; } + void SetMicrobatchScopes(const std::vector& scope) { + microbatch_scopes_ = scope; } - void SetCountMutex(std::mutex* mutex) { worker_count_mutex_ = mutex; } - void SetWorkerCount(int* worker_count) { worker_count_ = worker_count; } - void SetSectionNum(int section_num) { section_num_ = section_num; } - void SetPipelineNum(int pipeline_num) { pipeline_num_ = pipeline_num; } - void SetNextSectionPlace(const paddle::platform::Place& place) { - next_section_place_ = place; + void SetMinibatchScope(const Scope* scope) { minibatch_scope_ = scope; } + void SetSkipVars(const std::vector& skip_vars) { + skip_vars_ = skip_vars; } - SyncFunctor* sync_func_ = nullptr; - void SetSyncFunctor(SyncFunctor* sync_func) { sync_func_ = sync_func; } static std::atomic cpu_id_; protected: void AutoSetCPUAffinity(bool reuse); int section_id_; - int pipeline_id_; - int section_num_; - int pipeline_num_; int thread_id_; - // This worker will consume scope from in_scope_queue_ - // and produce scope to out_scope_queue_ - ScopeQueue* in_scope_queue_ = nullptr; - ScopeQueue* out_scope_queue_ = nullptr; - const std::vector* in_var_names_ = nullptr; - const std::vector* out_var_names_ = nullptr; - std::mutex* worker_count_mutex_ = nullptr; - int* worker_count_ = nullptr; - paddle::platform::Place next_section_place_; + int num_microbatches_; + std::vector microbatch_scopes_; + std::vector skip_vars_; + const Scope* minibatch_scope_; std::vector> ops_; + static std::mutex thread_mutex; + static std::condition_variable thread_condition; + static bool threads_completed; + std::shared_ptr program_; + static uint64_t batch_id_; + uint64_t local_batch_id_; platform::DeviceContext* dev_ctx_ = nullptr; }; #endif + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index 47e962a4825369020535905dab2859fd9be0398b..379892ecfd1161fd5e5003552bc48b1153b2c412 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -23,8 +23,13 @@ namespace framework { void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { - pipeline_num_ = trainer_desc.thread_num(); - VLOG(3) << "pipeline num: " << pipeline_num_; + const auto& section_params = trainer_desc.section_param(); + num_microbatches_ = section_params.num_microbatches(); + VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; + section_num_ = section_params.section_config_size(); + VLOG(3) << "Number of program sections: " << section_num_; + trainer_desc_ = trainer_desc; + start_cpu_core_id_ = section_params.start_cpu_core_id(); SetDataset(dataset); ParseDumpConfig(trainer_desc); @@ -32,96 +37,62 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, const std::vector readers = dataset->GetReaders(); VLOG(3) << "readers num: " << readers.size(); - - pipeline_config_ = trainer_desc.section_param(); - scope_queue_size_ = pipeline_config_.queue_size(); - sync_steps_ = pipeline_config_.sync_steps(); - section_num_ = pipeline_config_.section_config_size(); - - VLOG(3) << "scope_queue_size: " << scope_queue_size_; - VLOG(3) << "section num: " << section_num_; - VLOG(3) << "sync_steps: " << sync_steps_; + int num_readers = readers.size(); + PADDLE_ENFORCE_EQ(num_readers, 1, + platform::errors::InvalidArgument( + "Number of dataset readers for pipeline " + "must be 1 now, but the value you give is %d.", + num_readers)); + auto* reader = readers[0]; + feed_var_names_ = reader->GetUseSlotAlias(); workers_.resize(section_num_); - in_var_names_.resize(section_num_); - out_var_names_.resize(section_num_); - worker_count_.resize(section_num_); - worker_count_mutex_.resize(section_num_); - param_need_sync_.reset(new std::vector); - - int reader_index = 0; for (int i = 0; i < section_num_; ++i) { - const auto& section_config = pipeline_config_.section_config(i); - int concurrency = section_config.concurrency(); - VLOG(3) << "the thread num of each pipeline in section " << i - << " is: " << concurrency; - in_var_names_[i].reset(new std::vector( - section_config.section_in_var_names().begin(), - section_config.section_in_var_names().end())); - out_var_names_[i].reset(new std::vector( - section_config.section_out_var_names().begin(), - section_config.section_out_var_names().end())); - worker_count_[i].resize(pipeline_num_); - worker_count_mutex_[i].resize(pipeline_num_); - for (int j = 0; j < pipeline_num_; ++j) { - worker_count_[i][j] = new int(concurrency); - worker_count_mutex_[i][j].reset(new std::mutex); - } - + const auto& section_config = section_params.section_config(i); platform::Place place; - workers_[i].resize(pipeline_num_); - for (int j = 0; j < pipeline_num_; ++j) { - workers_[i][j].resize(concurrency); - - switch (section_config.place()) { - case SectionConfig::CPUPlace: - place = platform::CPUPlace(); - break; - case SectionConfig::CUDAPlace: - // Note that one section has at most one GPU place in one pipeline - place = platform::CUDAPlace(j); - break; - case SectionConfig::CUDAPinnedPlace: - place = platform::CUDAPinnedPlace(); - break; - default: - PADDLE_ENFORCE(false, "Unkown place type in SectionConfig: %d", - section_config.place()); - } + int place_id = section_config.place_id(); + switch (section_config.place()) { + case SectionConfig::CPUPlace: + place = platform::CPUPlace(); + break; + case SectionConfig::CUDAPlace: + // Note that one section has at most one GPU place in one pipeline + PADDLE_ENFORCE_GE( + place_id, 0, + platform::errors::InvalidArgument( + "The place_id value for CUDAPlace shoud be greater " + "than or equal to 0, but the value you give is %d.", + place_id)); + place = platform::CUDAPlace(place_id); + break; + case SectionConfig::CUDAPinnedPlace: + place = platform::CUDAPinnedPlace(); + break; + default: + PADDLE_ENFORCE_NOT_NULL(nullptr, + platform::errors::InvalidArgument( + "Unkown place type in SectionConfig: %d", + section_config.place())); + } + places_.emplace_back(place); + VLOG(3) << "Device worker place: " << place << ", device id: " << place_id + << ", section: " << i; - for (int k = 0; k < concurrency; ++k) { - workers_[i][j][k] = DeviceWorkerFactory::CreateDeviceWorker( - trainer_desc.device_worker_name()); - auto this_worker = - std::dynamic_pointer_cast( - workers_[i][j][k]); - this_worker->SetSectionIndex(i); - this_worker->SetDeviceIndex(j); - this_worker->SetThreadIndex(k); - this_worker->SetSectionNum(section_num_); - this_worker->SetPipelineNum(pipeline_num_); - if (i == 0) { - this_worker->SetDataFeed(readers[reader_index++]); - this_worker->SetReaderPlace(place); - } - if (i == section_num_ - 1) { - this_worker->SetNeedDumpField(need_dump_field_); - this_worker->SetNeedDumpParam(need_dump_param_); - this_worker->SetDumpFieldVector(dump_fields_); - this_worker->SetDumpParamVector(dump_param_); - } - this_worker->SetPlace(place); - this_worker->Initialize(trainer_desc); - this_worker->InitRandomDumpConfig(trainer_desc); - } + workers_[i] = DeviceWorkerFactory::CreateDeviceWorker( + trainer_desc.device_worker_name()); + auto this_worker = + std::dynamic_pointer_cast( + workers_[i]); + if (i == 0) { + // we only set reader for the first section + this_worker->SetDataFeed(reader); + this_worker->SetReaderPlace(place); } - } - param_need_sync_.reset( - new std::vector(pipeline_config_.param_need_sync().begin(), - pipeline_config_.param_need_sync().end())); - VLOG(3) << "param_need_sync_ have: "; - for (const std::string& name : *param_need_sync_) { - VLOG(3) << name; + this_worker->SetThreadIndex(i); + this_worker->SetSectionIndex(i); + this_worker->SetPlace(place); + this_worker->Initialize(trainer_desc); + this_worker->SetMicrobatchNum(num_microbatches_); } // set debug here SetDebug(trainer_desc.debug()); @@ -140,13 +111,7 @@ std::string PipelineTrainer::GetDumpPath(int tid) { void PipelineTrainer::InitDumpEnv() { queue_ = paddle::framework::MakeChannel(); - // Only set dump channel on the last section - for (int j = 0; j < pipeline_num_; ++j) { - for (size_t k = 0; k < workers_[section_num_ - 1][j].size(); ++k) { - workers_[section_num_ - 1][j][k]->SetChannelWriter(queue_.get()); - } - } - // TODO(hutuxian): should make it as a config + // TODO(sandyhouse): should make it as a config dump_thread_num_ = 1; for (int i = 0; i < dump_thread_num_; i++) { dump_thread_.push_back( @@ -154,150 +119,105 @@ void PipelineTrainer::InitDumpEnv() { } } -void PipelineTrainer::InitFirstScopeQueue(ScopeQueue* scope_queue, - int pipeline_id, - const ProgramDesc& main_program, - const Scope& root_scope) { - for (int i = 0; i < scope_queue_size_; ++i) { - Scope* scope = &pipeline_scopes_[pipeline_id]->NewScope(); - for (auto& var : main_program.Block(0).AllVars()) { - if (!var->Persistable()) { - auto* ptr = scope->Var(var->Name()); - InitializeVariable(ptr, var->GetType()); +void PipelineTrainer::CopyParameters(int section_id, int microbatch_id, + const ProgramDesc& program, + const platform::Place& place) { + auto& global_block = program.Block(0); + for (auto& var : global_block.AllVars()) { + int is_feed_var = + std::count(feed_var_names_.begin(), feed_var_names_.end(), var->Name()); + if ((var->Persistable() || is_feed_var) && microbatch_id == 0) { + if (is_feed_var) { + auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name()); + VLOG(3) << "data name: " << var->Name() << ", ptr: " << new_ptr; + InitializeVariable(new_ptr, var->GetType()); } else { - if (section_num_ == 1) { // Means only one section and it must be - // CUDAPlace, so copy all persistable vars to - // pipeline scope - const LoDTensor& root_tensor = - root_scope.FindVar(var->Name())->Get(); - LoDTensor* gpu_tensor = pipeline_scopes_[pipeline_id] - ->Var(var->Name()) - ->GetMutable(); - platform::Place place = platform::CUDAPlace(pipeline_id); - TensorCopy(*static_cast(&root_tensor), place, - static_cast(gpu_tensor)); - } + auto* ptr = root_scope_->FindVar(var->Name()); + auto* new_ptr = minibatch_scopes_[section_id]->Var(var->Name()); + VLOG(3) << "Create persistable var " << var->Name() << " for minibatch " + << section_id << ", which pointer is " << new_ptr; + InitializeVariable(new_ptr, var->GetType()); + const LoDTensor& root_tensor = ptr->Get(); + LoDTensor* minibatch_tensor = new_ptr->GetMutable(); + TensorCopy(*static_cast(&root_tensor), place, + static_cast(minibatch_tensor)); } + } else if (!var->Persistable() && !is_feed_var) { + auto* ptr = + microbatch_scopes_[section_id][microbatch_id]->Var(var->Name()); + VLOG(3) << "Create variable " << var->Name() << " for section " + << section_id << " microbatch " << microbatch_id + << ", which pointer is " << ptr; + InitializeVariable(ptr, var->GetType()); } - scope_queue->Send(scope); } } -void PipelineTrainer::CopyParameters(const Scope& root_scope, int pipeline_id) { - for (const std::string& name : *param_need_sync_) { - const LoDTensor& root_tensor = root_scope.FindVar(name)->Get(); - - // TODO(hutxian): check a new var of the same name is created in - // pipeline_scope - LoDTensor* gpu_tensor = - pipeline_scopes_[pipeline_id]->Var(name)->GetMutable(); - platform::Place place = platform::CUDAPlace(pipeline_id); - TensorCopy(*static_cast(&root_tensor), place, - static_cast(gpu_tensor)); +void PipelineTrainer::GetSkipVars(int section_id, const ProgramDesc& program) { + auto& global_block = program.Block(0); + for (auto& op : global_block.AllOps()) { + if (op->Type() != "enqueue") { + continue; + } + auto input_arg_names = op->InputArgumentNames(); + PADDLE_ENFORCE_EQ(input_arg_names.size(), 1, + platform::errors::InvalidArgument( + "Number of input arguments for enqueue op must be 1, " + "but the value is %d.", + input_arg_names.size())); + std::string input_arg_name = input_arg_names[0]; + if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) { + skip_vars_[section_id].emplace_back(input_arg_name); + VLOG(3) << "add skip var name: " << input_arg_name; + } } } void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place) { - PADDLE_ENFORCE(root_scope_, "Null root_scope pointer"); - SectionWorker::cpu_id_.store(pipeline_config_.start_cpu_core_id()); - scope_queues_.resize(section_num_); - pipeline_scopes_.resize(pipeline_num_); - for (auto& var : main_program.Block(0).AllVars()) { - if (var->Persistable()) { - persistable_vars_.push_back(var->Name()); - } - } + PADDLE_ENFORCE_NOT_NULL(root_scope_, + platform::errors::InvalidArgument( + "root_scope pointer can not be nullptr")); + auto start_cpu_id = trainer_desc_.section_param().start_cpu_core_id(); + SectionWorker::cpu_id_.store(start_cpu_id); + minibatch_scopes_.resize(section_num_); + microbatch_scopes_.resize(section_num_); + skip_vars_.resize(section_num_); VLOG(3) << "Init ScopeQueues and create all scopes"; for (int i = 0; i < section_num_; ++i) { - for (int j = 0; j < pipeline_num_; ++j) { - scope_queues_[i].emplace_back(new ScopeQueue(scope_queue_size_)); - if (i == 0) { - pipeline_scopes_[j] = &root_scope_->NewScope(); - CopyParameters(*root_scope_, j); - InitFirstScopeQueue(scope_queues_[0].back().get(), j, main_program, - *root_scope_); - } + minibatch_scopes_[i] = &root_scope_->NewScope(); + std::shared_ptr program; + program.reset(new ProgramDesc( + trainer_desc_.section_param().section_config(i).program_desc())); + microbatch_scopes_[i].resize(num_microbatches_); + for (int j = 0; j < num_microbatches_; ++j) { + microbatch_scopes_[i][j] = &minibatch_scopes_[i]->NewScope(); + CopyParameters(i, j, *program, places_[i]); } + GetSkipVars(i, *program); } for (int i = 0; i < section_num_; ++i) { - for (int j = 0; j < pipeline_num_; ++j) { - for (size_t k = 0; k < workers_[i][j].size(); ++k) { - auto this_worker = - std::dynamic_pointer_cast( - workers_[i][j][k]); - this_worker->SetRootScope(root_scope_); - this_worker->SetCountMutex(worker_count_mutex_[i][j].get()); - this_worker->SetWorkerCount(worker_count_[i][j]); - this_worker->SetScopeQueue(scope_queues_[i][j].get(), - (i == section_num_ - 1) - ? scope_queues_[0][j].get() - : scope_queues_[i + 1][j].get()); - this_worker->SetVarNames(*in_var_names_[i], *out_var_names_[i]); - if (i != section_num_ - 1) { - // For data copy in adjacent different place - this_worker->SetNextSectionPlace( - std::dynamic_pointer_cast( - workers_[i + 1][j][0]) - ->place()); - } - } - } - } - - if (pipeline_num_ > 1 && sync_steps_ != -1) { - construct_sync_functor(); - } -} - -void PipelineTrainer::construct_sync_functor() { - std::vector cuda_places; - for (int i = 0; i < pipeline_num_; ++i) { - cuda_places.emplace_back(platform::CUDAPlace(i)); - } - nccl_ctx_map_.reset(new platform::NCCLContextMap(cuda_places)); - sync_functors_.resize(pipeline_num_); - SyncFunctor::sync_flag_ = 0; - SyncFunctor::pipeline_scopes_.resize(0); - - for (int j = 0; j < pipeline_num_; ++j) { - SyncFunctor* sync_function = new SyncFunctor(j, pipeline_num_, sync_steps_); - sync_function->SetSyncParam(*param_need_sync_); - sync_function->SetNcclCtxMap(nccl_ctx_map_.get()); - SyncFunctor::pipeline_scopes_.push_back(this->pipeline_scopes_[j]); - sync_functors_[j].reset(sync_function); - } - for (int i = section_num_ - 1; i >= 0; --i) { - if (SectionConfig::CUDAPlace == - pipeline_config_.section_config(i).place()) { - for (int j = 0; j < pipeline_num_; ++j) { - for (size_t k = 0; k < workers_[i][j].size(); ++k) { - auto this_worker = - std::dynamic_pointer_cast( - workers_[i][j][k]); - this_worker->SetSyncFunctor(sync_functors_[j].get()); - } - } - break; - } + auto this_worker = + std::dynamic_pointer_cast( + workers_[i]); + this_worker->SetRootScope(root_scope_); + this_worker->SetMinibatchScope(minibatch_scopes_[i]); + this_worker->SetMicrobatchScopes(microbatch_scopes_[i]); + this_worker->SetSkipVars(skip_vars_[i]); } } void PipelineTrainer::Run() { VLOG(3) << "Going to run"; for (int i = 0; i < section_num_; ++i) { - for (int j = 0; j < pipeline_num_; ++j) { - for (size_t k = 0; k < workers_[i][j].size(); ++k) { - if (!debug_) { - section_threads_.push_back( - std::thread(&DeviceWorker::TrainFiles, workers_[i][j][k].get())); - } else { - section_threads_.push_back(std::thread( - &DeviceWorker::TrainFilesWithProfiler, workers_[i][j][k].get())); - } - } + if (!debug_) { + section_threads_.push_back( + std::thread(&DeviceWorker::TrainFiles, workers_[i].get())); + } else { + section_threads_.push_back(std::thread( + &DeviceWorker::TrainFilesWithProfiler, workers_[i].get())); } } } @@ -309,18 +229,31 @@ void PipelineTrainer::Finalize() { if (need_dump_field_) { FinalizeDumpEnv(); } - for (const auto& var : persistable_vars_) { - auto* root_tensor = root_scope_->Var(var)->GetMutable(); - // TODO(hutuxian): Add a final all-reduce? - const auto& thread_tensor = - pipeline_scopes_[0]->FindVar(var)->Get(); - TensorCopySync(thread_tensor, platform::CPUPlace(), root_tensor); + VLOG(3) << "copying back parameters. "; + for (int i = 0; i < section_num_; ++i) { + std::shared_ptr program; + program.reset(new ProgramDesc( + trainer_desc_.section_param().section_config(i).program_desc())); + for (int j = 0; j < num_microbatches_; ++j) { + auto& global_block = program->Block(0); + for (auto& var : global_block.AllVars()) { + if (var->Persistable()) { + auto* ptr = root_scope_->FindVar(var->Name()); + LoDTensor* root_tensor = ptr->GetMutable(); + auto* minibatch_ptr = minibatch_scopes_[i]->Var(var->Name()); + const LoDTensor& minibatch_tensor = minibatch_ptr->Get(); + TensorCopy(*static_cast(&minibatch_tensor), places_[0], + static_cast(root_tensor)); + VLOG(4) << "Copy persitable var " << var->Name() << " to root scope"; + } + } + } } root_scope_->DropKids(); } Scope* PipelineTrainer::GetWorkerScope(int thread_id) { - return pipeline_scopes_[thread_id]; + return microbatch_scopes_[thread_id][0]; } } // end namespace framework diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index df8bd61554e590fb0d83960a0fca63f78229c9a4..03b7afbb8771fadbe07a352497fa69a299928cf7 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -10,6 +10,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #if defined(PADDLE_WITH_NCCL) +#include +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/garbage_collector.h" +#include "paddle/fluid/framework/program_desc.h" + #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" @@ -25,82 +30,17 @@ limitations under the License. */ namespace paddle { namespace framework { -uint64_t SyncFunctor::sync_flag_ = 0; -std::vector SyncFunctor::pipeline_scopes_; - -SyncFunctor::SyncFunctor(int rank_id, int rank_num, int sync_steps) - : rank_id_(rank_id), rank_num_(rank_num), sync_steps_(sync_steps) { - PADDLE_ENFORCE(rank_num > 1, "rank_num should larger than 1"); - counter_ = 0; - sync_signal_ = 0; - uint8_t* ptr = reinterpret_cast(&sync_signal_); - for (int i = 0; i < rank_num_; ++i) { - ptr[i] = 0xFF; - } -} - -int SyncFunctor::operator()(Scope* scope) { - ++counter_; - if (counter_ < sync_steps_) { - return 0; - } - if (counter_ == sync_steps_) { - reinterpret_cast(&sync_flag_)[rank_id_] = 0xFF; - } - - if (sync_flag_ == sync_signal_) { - static std::mutex mutex; - if (mutex.try_lock()) { - if (sync_flag_ == sync_signal_) { - Synchronize(); - sync_flag_ = 0; - } - mutex.unlock(); - } - } - - if (sync_flag_ == 0) { - counter_ = 0; - } - return 0; -} - -void SyncFunctor::Synchronize() { - for (const std::string& name : *sync_param_) { - platform::NCCLGroupGuard guard; - for (int i = 0; i < rank_num_; ++i) { - const platform::NCCLContext& nccl_ctx = nccl_ctx_map_->at(i); - LoDTensor* tensor = - pipeline_scopes_[i]->Var(name)->GetMutable(); - // TODO(hutuxian): do not depend on data type explicitly - float* data = - tensor->mutable_data(nccl_ctx_map_->DevCtx(i)->GetPlace()); - const int numel = tensor->numel(); - - paddle::framework::AttributeMap attrs; - attrs.insert({"scale", static_cast(1. / rank_num_)}); - auto scale_op = framework::OpRegistry::CreateOp("scale", {{"X", {name}}}, - {{"Out", {name}}}, attrs); - scale_op->Run(*(pipeline_scopes_[i]), - nccl_ctx_map_->DevCtx(i)->GetPlace()); - PADDLE_ENFORCE(platform::dynload::ncclAllReduce( - data, data, numel, ncclFloat, ncclSum, nccl_ctx.comm(), - dynamic_cast( - platform::DeviceContextPool::Instance().Get( - platform::CUDAPlace(i))) - ->stream())); - } - } - nccl_ctx_map_->WaitAll(); -} - std::atomic SectionWorker::cpu_id_(0); +std::mutex SectionWorker::thread_mutex; +std::condition_variable SectionWorker::thread_condition; +bool SectionWorker::threads_completed = false; +uint64_t SectionWorker::batch_id_(0); + void SectionWorker::Initialize(const TrainerDesc& desc) { dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); - std::shared_ptr program; - program.reset(new ProgramDesc( + program_.reset(new ProgramDesc( desc.section_param().section_config(section_id_).program_desc())); - for (auto& op_desc : program->Block(0).AllOps()) { + for (auto& op_desc : program_->Block(0).AllOps()) { ops_.push_back(OpRegistry::CreateOp(*op_desc)); } } @@ -136,314 +76,494 @@ void SectionWorker::AutoSetCPUAffinity(bool reuse) { (0 == CPU_ISSET(proc, &mask))) { LOG(WARNING) << "Fail to set thread affinity to CPU " << proc; } - SEC_LOG << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc; + VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc; } void SectionWorker::TrainFiles() { - SEC_LOG << "begin section_worker TrainFiles"; + VLOG(3) << "begin section_worker TrainFiles"; AutoSetCPUAffinity(true); - int64_t step_cnt = 0; - int64_t accum_num = 0; - int batch_size = 0; - Scope* scope = nullptr; - if (device_reader_ != nullptr) { - device_reader_->Start(); - } - while (in_scope_queue_->Receive(&scope)) { - if (device_reader_ != nullptr) { - device_reader_->AssignFeedVar(*scope); - batch_size = device_reader_->Next(); - if (batch_size <= 0) { - break; - } - SEC_LOG << "read batch size: " << batch_size; + int64_t max_memory_size = 0; + std::unique_ptr gc; + auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place_)) { + if (IsFastEagerDeletionModeEnabled()) { + gc.reset(new UnsafeFastGPUGarbageCollector( + BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size)); } else { - // TODO(hutuxian): Keep batch_size in scope? Or is there a better way to - // fetch batch_size? Some variables may not have batch_size. - PADDLE_ENFORCE( - in_var_names_->size(), - "Section without a reader or in variable is not supported by now"); - const LoDTensor& tensor = - scope->FindVar(in_var_names_->at(0))->Get(); - batch_size = - tensor.lod().size() ? tensor.lod()[0].size() - 1 : tensor.dims()[0]; - SEC_LOG << "input batch size: " << batch_size; + gc.reset(new DefaultStreamGarbageCollector( + BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size)); } + } else if (platform::is_cpu_place(place_)) { +#endif + gc.reset(new CPUGarbageCollector( + BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size)); +#ifdef PADDLE_WITH_CUDA + } +#endif - Scope* exe_scope = scope; - if (section_id_ > 0 && platform::is_gpu_place(place_)) { - SEC_LOG << "CPU2GPU memory copy"; - - if (scope->kids().empty()) { - exe_scope = &scope->NewScope(); - } else { - exe_scope = scope->kids().front(); - PADDLE_ENFORCE(scope->kids().size() == 1, "scope->kids().size(): %zu", - scope->kids().size()); + if (thread_id_ == 0) { + while (true) { + // Start a minibatch. + for (int i = 0; i < num_microbatches_; ++i) { + try { + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + // We run op with op_role = kLRSched only for the first microbatch + // to avoid increasing the @LR_DECAY_STEP@ multiple times. + bool run_first_mbatch = + op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)) || + op_role == static_cast(OpRole::kLRSched); + bool run_others = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)); + if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + } + } + } catch (platform::EOFException&) { + std::unique_lock lk(thread_mutex); + threads_completed = true; + VLOG(3) << "thread " << thread_id_ << " completed."; + VLOG(3) << "called notify all"; + thread_condition.notify_all(); + VLOG(0) << "EOF encountered"; + return; + } + if (i == 0) { + VLOG(3) << "called notify all"; + std::unique_lock lk(thread_mutex); + batch_id_ += 1; + thread_condition.notify_all(); + } } - - for (const std::string& name : *in_var_names_) { - const LoDTensor& src_tensor = scope->FindVar(name)->Get(); - if (platform::is_gpu_place(src_tensor.place())) { - continue; + // backward pass + for (int i = 0; i < num_microbatches_; ++i) { + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kBackward) || + op_role == (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss))) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + } } - LoDTensor* gpu_tensor = exe_scope->Var(name)->GetMutable(); - gpu_tensor->set_lod(src_tensor.lod()); - TensorCopy(*static_cast(&src_tensor), place_, *dev_ctx_, - static_cast(gpu_tensor)); } - } - - SEC_LOG << "begin running ops"; - - for (auto& op : ops_) { - op->Run(*exe_scope, place_); - } - exe_scope->DropKids(); - // Wait for GPU calc finising, as the cudaMemcpy and GPU calc may be in - // different streams - // No effect when it is a CPUDeviceContext - dev_ctx_->Wait(); - -#ifdef PADDLE_WITH_BOX_PS - auto box_ptr = BoxWrapper::GetInstance(); - auto& metric_list = box_ptr->GetMetricList(); - for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) { - auto* metric_msg = iter->second; - if (box_ptr->Phase() != metric_msg->MetricPhase()) { - continue; + // update pass + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kOptimize)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for minibatch scope"; + op->Run(*microbatch_scopes_[0], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], + op.get(), unused_vars_, gc.get()); + } + } } - metric_msg->add_data(exe_scope); + dev_ctx_->Wait(); } -#endif - if (section_id_ != section_num_ - 1 && platform::is_gpu_place(place_)) { - // FIXME: Temporarily we assume two adjacent sections are in different - // places, - // and we do data transformation only in sections in GPU place, so the - // data is - // transform from GPU to CPU - // A better way to handle such a data transformation is to record each - // place of - // joint-out variables, and do transform as required - - SEC_LOG << "GPU2CPU memory copy"; - - for (const std::string& name : *out_var_names_) { - const LoDTensor& src_tensor = - exe_scope->FindVar(name)->Get(); - LoDTensor* dst_tensor = scope->Var(name)->GetMutable(); - dst_tensor->set_lod(src_tensor.lod()); - TensorCopy(*static_cast(&src_tensor), - next_section_place_, *dev_ctx_, - static_cast(dst_tensor)); + } else { + while (true) { + { + PADDLE_ENFORCE_LE( + local_batch_id_, batch_id_, + platform::errors::InvalidArgument( + "local_batch_id_ (%d) must be less than or equal to " + "batch_id_ (%d)", + local_batch_id_, batch_id_)); + std::unique_lock lk(thread_mutex); + if (local_batch_id_ == batch_id_ && !threads_completed) { + thread_condition.wait(lk); + } + VLOG(3) << "thread " << thread_id_ << " local_batch_id_ " + << local_batch_id_ << " batch_id_ " << batch_id_; + if (threads_completed) { + VLOG(3) << "thread " << thread_id_ << " completed."; + lk.unlock(); + threads_completed = false; + return; + } + lk.unlock(); + local_batch_id_ += 1; } + // forward pass: + for (int i = 0; i < num_microbatches_; ++i) { + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + // We run op with op_role = kLRSched only for the first microbatch + // to avoid increasing the @LR_DECAY_STEP@ multiple times. + bool run_first_mbatch = + op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)) || + op_role == static_cast(OpRole::kLRSched); + bool run_others = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)); + if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + } + } + } + // backward pass + for (int i = 0; i < num_microbatches_; ++i) { + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kBackward) || + op_role == (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss))) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + } + } + } + // update pass + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kOptimize)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for minibatch scope"; + op->Run(*microbatch_scopes_[0], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], + op.get(), unused_vars_, gc.get()); + } + } + } + dev_ctx_->Wait(); } - - out_scope_queue_->Send(scope); - - if (sync_func_) { - (*sync_func_)(scope); - } - - ++step_cnt; - accum_num += batch_size; - } - - worker_count_mutex_->lock(); - --(*worker_count_); - worker_count_mutex_->unlock(); - - if (*worker_count_ <= 0) { - while (section_id_ < section_num_ - 1 && out_scope_queue_->Size()) { - sleep(1); - } - out_scope_queue_->Close(); } } void SectionWorker::TrainFilesWithProfiler() { - SEC_LOG << "begin section_worker TrainFiles with profiler"; + VLOG(3) << "begin section_worker TrainFiles with profiler"; AutoSetCPUAffinity(true); - int64_t step_cnt = 0; - int64_t accum_num = 0; - int batch_size = 0; - Scope* scope = nullptr; - - platform::Timer reader_timer; - platform::Timer cal_timer; - platform::Timer trans_timer; - platform::Timer sync_timer; - platform::Timer main_timer; - platform::Timer outer_timer; + platform::Timer batch_timer; + platform::Timer timeline; std::vector op_total_time; std::vector op_name; + std::vector op_max_time; + std::vector op_min_time; + std::vector op_count; for (auto& op : ops_) { op_name.push_back(op->Type()); } op_total_time.resize(ops_.size()); - for (size_t i = 0; i < op_total_time.size(); ++i) { - op_total_time[i] = 0.0; - } - platform::Timer timeline; - if (device_reader_ != nullptr) { - device_reader_->Start(); + op_max_time.resize(ops_.size()); + op_min_time.resize(ops_.size()); + for (size_t i = 0; i < op_min_time.size(); ++i) { + op_min_time[i] = DBL_MAX; } - - bool started = false; - while (in_scope_queue_->Receive(&scope)) { - if (UNLIKELY(!started)) { - outer_timer.Start(); - started = true; - } - main_timer.Resume(); - - if (device_reader_ != nullptr) { - reader_timer.Resume(); - device_reader_->AssignFeedVar(*scope); - batch_size = device_reader_->Next(); - reader_timer.Pause(); - if (batch_size <= 0) { - break; - } - SEC_LOG << "read batch size: " << batch_size; + op_count.resize(ops_.size()); + + int64_t max_memory_size = 0; + std::unique_ptr gc; + // const std::vector keep_vars; + auto unused_vars_ = GetUnusedVars(program_->Block(0), ops_, skip_vars_); +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(place_)) { + if (IsFastEagerDeletionModeEnabled()) { + gc.reset(new UnsafeFastGPUGarbageCollector( + BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size)); } else { - PADDLE_ENFORCE( - in_var_names_->size(), - "Section without a reader or in variable is not supported by now"); - const LoDTensor& tensor = - scope->FindVar(in_var_names_->at(0))->Get(); - batch_size = - tensor.lod().size() ? tensor.lod()[0].size() - 1 : tensor.dims()[0]; - SEC_LOG << "input batch size: " << batch_size; + gc.reset(new DefaultStreamGarbageCollector( + BOOST_GET_CONST(platform::CUDAPlace, place_), max_memory_size)); } + } else if (platform::is_cpu_place(place_)) { +#endif + gc.reset(new CPUGarbageCollector( + BOOST_GET_CONST(platform::CPUPlace, place_), max_memory_size)); +#ifdef PADDLE_WITH_CUDA + } +#endif - Scope* exe_scope = scope; - if (section_id_ > 0 && platform::is_gpu_place(place_)) { - SEC_LOG << "CPU2GPU memory copy"; - trans_timer.Resume(); - if (scope->kids().empty()) { - exe_scope = &scope->NewScope(); - } else { - exe_scope = scope->kids().front(); - PADDLE_ENFORCE(scope->kids().size() == 1, "scope->kids().size(): %zu", - scope->kids().size()); + if (thread_id_ == 0) { + while (true) { + // Start a minibatch. + // int batch_size = 0; + batch_timer.Start(); + for (int i = 0; i < num_microbatches_; ++i) { + try { + int op_idx = 0; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + // We run op with op_role = kLRSched only for the first microbatch + // to avoid increasing the @LR_DECAY_STEP@ multiple times. + bool run_first_mbatch = + op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)) || + op_role == static_cast(OpRole::kLRSched); + bool run_others = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)); + if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + timeline.Start(); + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + timeline.Pause(); + auto time = timeline.ElapsedUS(); + op_total_time[op_idx] += time; + if (time > op_max_time[op_idx]) { + op_max_time[op_idx] = time; + } + if (time < op_min_time[op_idx]) { + op_min_time[op_idx] = time; + } + op_count[op_idx] += 1; + op_total_time[op_idx] += time; + } + op_idx++; + } + } catch (platform::EOFException&) { + std::unique_lock lk(thread_mutex); + threads_completed = true; + VLOG(3) << "thread " << thread_id_ << " completed."; + VLOG(3) << "called notify all"; + thread_condition.notify_all(); + VLOG(0) << "EOF encountered"; + VLOG(0) << "============timeline============"; + for (size_t i = 0; i < ops_.size(); ++i) { + VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i] + << ", min_time: " << op_min_time[i] + << ", mean_time: " << op_total_time[i] / op_count[i]; + } + VLOG(0) << "================================"; + return; + } + if (i == 0) { + VLOG(3) << "called notify all"; + std::unique_lock lk(thread_mutex); + batch_id_ += 1; + thread_condition.notify_all(); + } } - - for (const std::string& name : *in_var_names_) { - const LoDTensor& src_tensor = scope->FindVar(name)->Get(); - if (platform::is_gpu_place(src_tensor.place())) { - continue; + // backward pass + for (int i = 0; i < num_microbatches_; ++i) { + int op_idx = 0; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kBackward) || + op_role == (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss))) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + timeline.Start(); + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + timeline.Pause(); + auto time = timeline.ElapsedUS(); + op_total_time[op_idx] += time; + if (time > op_max_time[op_idx]) { + op_max_time[op_idx] = time; + } + if (time < op_min_time[op_idx]) { + op_min_time[op_idx] = time; + } + op_count[op_idx] += 1; + op_total_time[op_idx] += time; + } + op_idx++; } - LoDTensor* gpu_tensor = exe_scope->Var(name)->GetMutable(); - gpu_tensor->set_lod(src_tensor.lod()); - TensorCopy(*static_cast(&src_tensor), place_, *dev_ctx_, - static_cast(gpu_tensor)); } - trans_timer.Pause(); - } - - SEC_LOG << "begin running ops"; - cal_timer.Resume(); - int op_id = 0; - dev_ctx_->Wait(); - for (auto& op : ops_) { - timeline.Start(); - op->Run(*exe_scope, place_); + // update pass + int op_idx = 0; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kOptimize)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for minibatch scope"; + timeline.Start(); + op->Run(*microbatch_scopes_[0], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], + op.get(), unused_vars_, gc.get()); + } + timeline.Pause(); + auto time = timeline.ElapsedUS(); + op_total_time[op_idx] += time; + if (time > op_max_time[op_idx]) { + op_max_time[op_idx] = time; + } + if (time < op_min_time[op_idx]) { + op_min_time[op_idx] = time; + } + op_count[op_idx] += 1; + op_total_time[op_idx] += time; + } + op_idx++; + } dev_ctx_->Wait(); - timeline.Pause(); - op_total_time[op_id++] += timeline.ElapsedUS(); + batch_timer.Pause(); + VLOG(0) << "batch time: " << batch_timer.ElapsedUS(); } - exe_scope->DropKids(); - // Wait for GPU calc finising, as the cudaMemcpy and GPU calc may be in - // different streams - // No effect when it is a CPUDeviceContext - dev_ctx_->Wait(); - cal_timer.Pause(); -#ifdef PADDLE_WITH_BOX_PS - auto box_ptr = BoxWrapper::GetInstance(); - auto& metric_list = box_ptr->GetMetricList(); - for (auto iter = metric_list.begin(); iter != metric_list.end(); iter++) { - auto* metric_msg = iter->second; - if (box_ptr->Phase() != metric_msg->MetricPhase()) { - continue; + } else { + while (true) { + { + PADDLE_ENFORCE_LE( + local_batch_id_, batch_id_, + platform::errors::InvalidArgument( + "local_batch_id_ (%d) must be less than or equal to " + "batch_id_ (%d)", + local_batch_id_, batch_id_)); + std::unique_lock lk(thread_mutex); + if (local_batch_id_ == batch_id_ && !threads_completed) { + thread_condition.wait(lk); + } + VLOG(3) << "thread " << thread_id_ << " local_batch_id_ " + << local_batch_id_ << " batch_id_ " << batch_id_; + if (threads_completed) { + VLOG(3) << "thread " << thread_id_ << " completed."; + lk.unlock(); + VLOG(0) << "============timeline============"; + for (size_t i = 0; i < ops_.size(); ++i) { + VLOG(0) << "op: " << op_name[i] << ", max_time: " << op_max_time[i] + << ", min_time: " << op_min_time[i] + << ", mean_time: " << op_total_time[i] / op_count[i]; + } + VLOG(0) << "================================"; + threads_completed = false; + return; + } + lk.unlock(); + local_batch_id_ += 1; } - metric_msg->add_data(exe_scope); - } -#endif - if (need_dump_field_) { - DumpField(*scope, dump_mode_, dump_interval_); - } - if (need_dump_param_ && pipeline_id_ == 0) { - DumpParam(*scope, step_cnt); - } - - if (section_id_ != section_num_ - 1 && platform::is_gpu_place(place_)) { - // FIXME: Temporarily we assume two adjacent sections are in different - // places, - // and we do data transformation only in sections in GPU place, so the - // data is - // transform from GPU to CPU - // A better way to handle such a data transformation is to record each - // place of - // joint-out variables, and do transform as required - - SEC_LOG << "GPU2CPU memory copy"; - trans_timer.Resume(); - for (const std::string& name : *out_var_names_) { - const LoDTensor& src_tensor = - exe_scope->FindVar(name)->Get(); - LoDTensor* dst_tensor = scope->Var(name)->GetMutable(); - dst_tensor->set_lod(src_tensor.lod()); - TensorCopy(*static_cast(&src_tensor), - next_section_place_, *dev_ctx_, - static_cast(dst_tensor)); + // forward pass: + for (int i = 0; i < num_microbatches_; ++i) { + int op_idx = 0; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + // We run op with op_role = kLRSched only for the first microbatch + // to avoid increasing the @LR_DECAY_STEP@ multiple times. + bool run_first_mbatch = + op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)) || + op_role == static_cast(OpRole::kLRSched); + bool run_others = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)); + if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + timeline.Start(); + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + timeline.Pause(); + auto time = timeline.ElapsedUS(); + op_total_time[op_idx] += time; + if (time > op_max_time[op_idx]) { + op_max_time[op_idx] = time; + } + if (time < op_min_time[op_idx]) { + op_min_time[op_idx] = time; + } + op_count[op_idx] += 1; + op_total_time[op_idx] += time; + } + op_idx++; + } } - trans_timer.Pause(); - } - - out_scope_queue_->Send(scope); - - if (sync_func_) { - sync_timer.Resume(); - (*sync_func_)(scope); - sync_timer.Pause(); - } - - ++step_cnt; - accum_num += batch_size; - main_timer.Pause(); - } - if (need_dump_field_ || need_dump_param_) { - writer_.Flush(); - } - outer_timer.Pause(); - - worker_count_mutex_->lock(); - --(*worker_count_); - worker_count_mutex_->unlock(); - - if (*worker_count_ <= 0) { - while (section_id_ < section_num_ - 1 && out_scope_queue_->Size()) { - sleep(1); + // backward pass + for (int i = 0; i < num_microbatches_; ++i) { + int op_idx = 0; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kBackward) || + op_role == (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss))) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for scope " << i; + timeline.Start(); + op->Run(*microbatch_scopes_[i], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), + unused_vars_, gc.get()); + } + timeline.Pause(); + auto time = timeline.ElapsedUS(); + op_total_time[op_idx] += time; + if (time > op_max_time[op_idx]) { + op_max_time[op_idx] = time; + } + if (time < op_min_time[op_idx]) { + op_min_time[op_idx] = time; + } + op_count[op_idx] += 1; + op_total_time[op_idx] += time; + } + op_idx++; + } + } + // update pass + int op_idx = 0; + for (auto& op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kOptimize)) { + VLOG(3) << "running an op " << op->Type() << " for " << thread_id_ + << " for minibatch scope"; + timeline.Start(); + op->Run(*microbatch_scopes_[0], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], + op.get(), unused_vars_, gc.get()); + } + timeline.Pause(); + auto time = timeline.ElapsedUS(); + op_total_time[op_idx] += time; + if (time > op_max_time[op_idx]) { + op_max_time[op_idx] = time; + } + if (time < op_min_time[op_idx]) { + op_min_time[op_idx] = time; + } + op_count[op_idx] += 1; + op_total_time[op_idx] += time; + } + op_idx++; + } + dev_ctx_->Wait(); } - out_scope_queue_->Close(); - } - LOG(ERROR) << "log_for_profile" - << " card:" << pipeline_id_ << " thread:" << thread_id_ - << " section:" << section_id_ << " step_count:" << step_cnt - << " batch_count:" << accum_num - << " read_time:" << reader_timer.ElapsedUS() - << " trans_time:" << trans_timer.ElapsedUS() - << " cal_time:" << cal_timer.ElapsedUS() - << " sync_time:" << sync_timer.ElapsedUS() - << " main_time:" << main_timer.ElapsedUS() - << " outer_time:" << outer_timer.ElapsedUS(); - for (size_t i = 0; i < ops_.size(); ++i) { - LOG(ERROR) << "op: " << op_name[i] - << ", mean time: " << op_total_time[i] / accum_num; } } } // namespace framework diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index c18ea33d041b9518fb60d2453830de8e4b4ff033..bb56b3ea3d251d53d6e8e494ec1c658574c2e96c 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -137,49 +137,31 @@ class PipelineTrainer : public TrainerBase { virtual Scope* GetWorkerScope(int thread_id); void InitDumpEnv() override; virtual std::string GetDumpPath(int tid); + void GetSkipVars(int section_id, const ProgramDesc& main_program); protected: int section_num_; - int pipeline_num_; - int scope_queue_size_; - int sync_steps_; + int num_microbatches_; + int start_cpu_core_id_; + std::vector feed_var_names_; + std::vector places_; + std::vector> skip_vars_; + TrainerDesc trainer_desc_; - SectionWorkerParameter pipeline_config_; - - // The in/output var names for each section - std::vector>> in_var_names_; - std::vector>> out_var_names_; - - // Counter for the running thread - std::vector> worker_count_; - std::vector>> worker_count_mutex_; - - // worker: [section_id][pipeline_id][thread_id] - std::vector>>> - workers_; std::vector section_threads_; - - // We use scope to maintain context info, and scopes - // will be deliverd between different sections. - std::vector>> scope_queues_; - std::vector pipeline_scopes_; - - // The parameters that should be syncronized between different cards using - // nccl all-reduce - std::shared_ptr> param_need_sync_; - std::vector persistable_vars_; - std::vector> sync_functors_; - std::shared_ptr nccl_ctx_map_; - - std::vector readers_; - - void InitFirstScopeQueue(ScopeQueue* scope_queue, int pipeline_id, - const ProgramDesc& main_program, - const Scope& root_scope); - void CopyParameters(const Scope& root_scope, int pipeline_id); - void construct_sync_functor(); + // worker: [section_id] + std::vector> workers_; + // minibatch_scopes_: [section_id] + std::vector minibatch_scopes_; + // microbatch_scopes_: [section_id][microbatch_id] + std::vector> microbatch_scopes_; + + void CopyParameters(int section_id, int microbatch_id, + const ProgramDesc& program, const platform::Place& place); + bool isPersistableVarGrad(std::string name); + bool isPersistable(VarDesc* var); }; #endif + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 9cbb063a3fab6810709c1504deed2ccf40743123..670ae074c7c7f0e3bcd91e157ba7b01b48d3b7ee 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -83,6 +83,7 @@ message SectionWorkerParameter { optional int64 sync_steps = 3 [ default = 1 ]; optional int32 start_cpu_core_id = 4 [ default = 1 ]; repeated string param_need_sync = 5; + optional int32 num_microbatches = 6; } message SectionConfig { @@ -99,6 +100,7 @@ message SectionConfig { optional int32 concurrency = 3 [ default = 1 ]; repeated string section_in_var_names = 4; repeated string section_out_var_names = 5; + optional int32 place_id = 6 [ default = -1 ]; } message FetchConfig { diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index ebc304eab7d6a8b59a81fa2cc4244fb81bf3b1a4..72e0351ec36c028593eb8f099a4e39aa314aac37 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -403,11 +403,8 @@ class Section(DeviceWorker): trainer_desc.device_worker_name = "SectionWorker" pipeline_opt = self._program._pipeline_opt section_param = trainer_desc.section_param - section_param.queue_size = pipeline_opt["queue_size"] - section_param.sync_steps = pipeline_opt["sync_steps"] + section_param.num_microbatches = pipeline_opt["num_microbatches"] section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"] - for e in pipeline_opt["param_need_sync"]: - section_param.param_need_sync.append(e) for i, program in enumerate(pipeline_opt["section_program_list"]): cfg = section_param.section_config.add() cfg.program_desc.ParseFromString(program["program"]._get_desc() @@ -415,6 +412,7 @@ class Section(DeviceWorker): # TODO: why does not work # cfg.program_desc.CopyFrom(program.program._get_desc()) place = pipeline_opt["place_list"][i] + place_id = pipeline_opt["place_id_list"][i] if isinstance(place, core.CPUPlace): cfg.place = cfg.CPUPlace elif isinstance(place, core.CUDAPlace): @@ -425,12 +423,7 @@ class Section(DeviceWorker): raise NotImplementedError( "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now." ) - - cfg.concurrency = pipeline_opt["concurrency_list"][i] - for var in program["input_set"]: - cfg.section_in_var_names.append(var) - for var in program["output_set"]: - cfg.section_out_var_names.append(var) + cfg.place_id = place_id class DeviceWorkerFactory(object): diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 6b145105132b91241fdf0f791355fd04ce891e2c..270208120ccb84ead53bcf72903adcaf52901316 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4474,7 +4474,7 @@ class PipelineOptimizer(object): "place_list": place_list, "place_id_list": place_id_list, "sync_steps": -1, - "queue_size": self._num_microbatches, + "num_microbatches": self._num_microbatches, "start_cpu_core_id": self._start_cpu_core_id, } return optimize_ops, params_grads, program_list diff --git a/python/paddle/fluid/tests/unittests/test_pipeline.py b/python/paddle/fluid/tests/unittests/test_pipeline.py index 40a1e0b0248caafff1517da67474db2b1b2c6d9a..1f884195a47f19ca0c69912dfa68cf608317ddc8 100644 --- a/python/paddle/fluid/tests/unittests/test_pipeline.py +++ b/python/paddle/fluid/tests/unittests/test_pipeline.py @@ -100,7 +100,7 @@ def build_network(input, layers=50, class_dim=1000): pool_type='max') if layers >= 50: for block in range(len(depth)): - with fluid.device_guard("cpu"): + with fluid.device_guard("gpu:0"): for i in range(depth[block]): conv = bottleneck_block( input=conv, @@ -118,7 +118,7 @@ def build_network(input, layers=50, class_dim=1000): initializer=fluid.initializer.Uniform(-stdv, stdv))) else: for block in range(len(depth)): - with fluid.device_guard("cpu"): + with fluid.device_guard("gpu:0"): for i in range(depth[block]): conv = basic_block( input=conv, @@ -140,38 +140,68 @@ def build_network(input, layers=50, class_dim=1000): class TestPipeline(unittest.TestCase): """ TestCases for Pipeline Training. """ + def _run(self, debug): + main_prog = fluid.Program() + startup_prog = fluid.Program() + with fluid.program_guard(main_prog, startup_prog): + with fluid.device_guard("cpu"): + image = fluid.layers.data( + name="image", shape=[3, 224, 224], dtype="float32") + label = fluid.layers.data( + name="label", shape=[1], dtype="int64") + data_loader = fluid.io.DataLoader.from_generator( + feed_list=[image, label], + capacity=64, + use_double_buffer=True, + iterable=False) + fc = build_network(image, layers=50) + with fluid.device_guard("gpu:0"): + out, prob = fluid.layers.softmax_with_cross_entropy( + logits=fc, label=label, return_softmax=True) + loss = fluid.layers.mean(out) + acc_top1 = fluid.layers.accuracy(input=prob, label=label, k=1) + acc_top5 = fluid.layers.accuracy(input=prob, label=label, k=5) + + base_lr = 0.1 + passes = [30, 60, 80, 90] + total_images = 1281167 + steps_per_pass = total_images // 128 + bd = [steps_per_pass * p for p in passes] + lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] + lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr) + optimizer = fluid.optimizer.MomentumOptimizer( + lr_val, + momentum=0.9, + regularization=fluid.regularizer.L2Decay(1e-4)) + optimizer = fluid.optimizer.PipelineOptimizer( + optimizer, num_microbatches=2) + optimizer.minimize(loss) + + def train_reader(): + for _ in range(4): + img = np.random.random(size=[3, 224, 224]).astype('float32') + label = np.random.random(size=[1]).astype('int64') + yield img, label + + data_loader.set_sample_generator(train_reader, batch_size=1) + place = fluid.CPUPlace() + + # The following dataset is only used for the + # interface 'train_from_dataset'. + # And it has no actual meaning. + dataset = fluid.DatasetFactory().create_dataset('FileInstantDataset') + dataset.set_batch_size(1) + dataset.set_thread(1) + dataset.set_filelist(['/tmp/tmp_2.txt']) + dataset.set_use_var([image, label]) + exe = fluid.Executor(place) + exe.run(startup_prog) + data_loader.start() + exe.train_from_dataset(main_prog, dataset, debug=debug) + def test_pipeline(self): - with fluid.device_guard("cpu"): - image = fluid.layers.data( - name="image", shape=[3, 224, 224], dtype="float32") - label = fluid.layers.data(name="label", shape=[1], dtype="int64") - data_loader = fluid.io.DataLoader.from_generator( - feed_list=[image, label], - capacity=64, - use_double_buffer=True, - iterable=False) - fc = build_network(image, layers=50) - with fluid.device_guard("gpu:0"): - out, prob = fluid.layers.softmax_with_cross_entropy( - logits=fc, label=label, return_softmax=True) - loss = fluid.layers.mean(out) - acc_top1 = fluid.layers.accuracy(input=prob, label=label, k=1) - acc_top5 = fluid.layers.accuracy(input=prob, label=label, k=5) - - base_lr = 0.1 - passes = [30, 60, 80, 90] - total_images = 1281167 - steps_per_pass = total_images // 128 - bd = [steps_per_pass * p for p in passes] - lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)] - lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr) - optimizer = fluid.optimizer.Momentum( - lr_val, - momentum=0.9, - regularization=fluid.regularizer.L2Decay(1e-4)) - optimizer = fluid.optimizer.PipelineOptimizer( - optimizer, num_microbatches=2) - optimizer.minimize(loss) + self._run(False) + self._run(True) def test_pipeline_noneoptimizer(self): with fluid.device_guard("gpu:0"):