From 08c25995a2eacfa4dc8fcecff5080ace6e9e43f6 Mon Sep 17 00:00:00 2001
From: dongdaxiang
Date: Wed, 6 Mar 2019 09:55:38 +0800
Subject: [PATCH] add run from dataset in executor.

---
 paddle/fluid/framework/executor.h       | 13 +++++++++++
 paddle/fluid/framework/multi_trainer.cc | 29 ++++++++++++++++---------
 2 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 825224437..48aeb151d 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -54,6 +54,7 @@ class Executor {
 
   explicit Executor(const platform::Place& place);
 
+  explicit Executor(Scope* scope, const platform::Place& place);
 
   /*
    * Close this Executor.
    * Calling this method will send complete messages to all pserver instances.
@@ -110,8 +111,20 @@ class Executor {
 
   void EnableMKLDNN(const ProgramDesc& program);
 
+  void RunFromTrainerDesc(const ProgramDesc& main_program,
+                          const std::string& trainer_desc_str,
+                          const bool debug);
+
+  void RunFromDataset(const ProgramDesc& main_program, const Dataset* dataset,
+                      const std::string& trainer_desc_str, const bool debug);
+
+ public:
+  std::shared_ptr fleet_ptr_;
+  Scope* root_scope_;
+
  private:
   const platform::Place place_;
+  int actual_thread_num_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc
index 6c9fa9608..d1ade19f5 100644
--- a/paddle/fluid/framework/multi_trainer.cc
+++ b/paddle/fluid/framework/multi_trainer.cc
@@ -21,25 +21,34 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-void MultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
+void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
+                              Dataset* dataset) {
   thread_num_ = trainer_desc.thread_num();
   // get filelist from trainer_desc here
   workers_.resize(thread_num_);
-  readers_.resize(thread_num_);
+
+  if (NULL == dataset) {
+    readers_.resize(thread_num_);
+    for (int i = 0; i < thread_num_; ++i) {
+      readers_[i] =
+          DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
+      readers_[i]->Init(trainer_desc.data_desc());
+    }
+    std::vector<std::string> filelist_vec;
+    for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
+      filelist_vec.push_back(trainer_desc.filelist(i));
+    }
+    readers_[0]->SetFileList(filelist_vec);
+  } else {
+    // readers_ = dataset.get_readers(); ?
+  }
+
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
        trainer_desc.device_worker_name());
-    readers_[i] =
-        DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
     workers_[i]->SetDeviceIndex(i);
-    readers_[i]->Init(trainer_desc.data_desc());
     workers_[i]->SetDataFeed(readers_[i]);
   }
-  std::vector<std::string> filelist_vec;
-  for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
-    filelist_vec.push_back(trainer_desc.filelist(i));
-  }
-  readers_[0]->SetFileList(filelist_vec);
 }
 
 // call only after all resources are set in current trainer
-- 
GitLab
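
Note on the multi_trainer.cc change: the new else branch is left as a placeholder ("// readers_ = dataset.get_readers(); ?"). Below is a minimal sketch, not part of the commit, of how that branch could be completed, assuming Dataset exposes its pre-built per-thread readers through a GetReaders() accessor whose return type matches readers_; both the accessor name and its return type are assumptions.

// Sketch only (assumed Dataset::GetReaders() accessor); not the committed code.
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
                              Dataset* dataset) {
  thread_num_ = trainer_desc.thread_num();
  workers_.resize(thread_num_);

  if (NULL == dataset) {
    // ... unchanged trainer_desc-driven DataFeed construction from the patch ...
  } else {
    // Reuse the readers the Dataset has already created, one per thread,
    // instead of building new DataFeed objects from trainer_desc.
    readers_ = dataset->GetReaders();  // assumed accessor
    // Keep the worker count consistent with the number of readers provided.
    thread_num_ = static_cast<int>(readers_.size());
    workers_.resize(thread_num_);
  }

  for (int i = 0; i < thread_num_; ++i) {
    workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
        trainer_desc.device_worker_name());
    workers_[i]->SetDeviceIndex(i);
    workers_[i]->SetDataFeed(readers_[i]);
  }
}

Letting the Dataset own reader construction keeps file assignment and shuffling in one place and leaves the trainer agnostic to how the data was produced.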
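
The executor.h change only declares RunFromDataset; no definition appears in this diff. One plausible body, sketched under the assumption that a TrainerFactory/TrainerBase pair exists in the framework with roughly the methods used below and that the Python side passes the TrainerDesc as a text-format protobuf string; none of this is confirmed by the patch.

// Sketch only; every trainer-side call below is an assumption about the
// surrounding framework, not the committed implementation.
#include <memory>
#include <string>

#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"  // TrainerDesc proto
#include "paddle/fluid/framework/trainer_factory.h"  // assumed header

namespace paddle {
namespace framework {

void Executor::RunFromDataset(const ProgramDesc& main_program,
                              const Dataset* dataset,
                              const std::string& trainer_desc_str,
                              const bool debug) {
  // Parse the TrainerDesc serialized by the Python front end.
  TrainerDesc trainer_desc;
  google::protobuf::TextFormat::ParseFromString(trainer_desc_str,
                                                &trainer_desc);

  // Create the trainer named in the desc (e.g. "MultiTrainer") and hand it
  // the dataset so it can reuse the dataset's readers.
  std::shared_ptr<TrainerBase> trainer =
      TrainerFactory::CreateTrainer(trainer_desc.class_name());
  trainer->Initialize(trainer_desc, const_cast<Dataset*>(dataset));

  // Bind the trainer to this executor's scope and place, run, and clean up.
  trainer->SetScope(root_scope_);
  trainer->SetDebug(debug);
  trainer->InitTrainerEnv(main_program, place_);
  trainer->Run();
  trainer->Finalize();
}

}  // namespace framework
}  // namespace paddle

RunFromTrainerDesc could presumably share the same body minus the dataset argument, which would explain why both declarations were added together.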