diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h
index 825224437e0cdda03c56faf1b50833abd8b8c2ab..48aeb151d57aa27ac88419b7a83b4aafe1163c22 100644
--- a/paddle/fluid/framework/executor.h
+++ b/paddle/fluid/framework/executor.h
@@ -54,6 +54,7 @@ class Executor {
 
   explicit Executor(const platform::Place& place);
 
+  explicit Executor(Scope* scope, const platform::Place& place);
   /*
    * Close this Executor.
    * Calling this method will send complete messages to all pserver instances.
@@ -110,8 +111,20 @@ class Executor {
 
   void EnableMKLDNN(const ProgramDesc& program);
 
+  void RunFromTrainerDesc(const ProgramDesc& main_program,
+                          const std::string& trainer_desc_str,
+                          const bool debug);
+
+  void RunFromDataset(const ProgramDesc& main_program, const Dataset* dataset,
+                      const std::string& trainer_desc_str, const bool debug);
+
+ public:
+  std::shared_ptr<FleetWrapper> fleet_ptr_;
+  Scope* root_scope_;
+
  private:
   const platform::Place place_;
+  int actual_thread_num_;
 };
 
 }  // namespace framework
diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc
index 6c9fa960841fd95f85eee9a227f43efd93af553f..d1ade19f56017a77840fb699f8963979818058f5 100644
--- a/paddle/fluid/framework/multi_trainer.cc
+++ b/paddle/fluid/framework/multi_trainer.cc
@@ -21,25 +21,34 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-void MultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
+void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
+                              Dataset* dataset) {
   thread_num_ = trainer_desc.thread_num();
   // get filelist from trainer_desc here
   workers_.resize(thread_num_);
-  readers_.resize(thread_num_);
+
+  if (NULL == dataset) {
+    readers_.resize(thread_num_);
+    for (int i = 0; i < thread_num_; ++i) {
+      readers_[i] =
+          DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
+      readers_[i]->Init(trainer_desc.data_desc());
+    }
+    std::vector<std::string> filelist_vec;
+    for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
+      filelist_vec.push_back(trainer_desc.filelist(i));
+    }
+    readers_[0]->SetFileList(filelist_vec);
+  } else {
+    // readers_ = dataset.get_readers(); ?
+  }
+
   for (int i = 0; i < thread_num_; ++i) {
     workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
         trainer_desc.device_worker_name());
-    readers_[i] =
-        DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
     workers_[i]->SetDeviceIndex(i);
-    readers_[i]->Init(trainer_desc.data_desc());
     workers_[i]->SetDataFeed(readers_[i]);
   }
-  std::vector<std::string> filelist_vec;
-  for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
-    filelist_vec.push_back(trainer_desc.filelist(i));
-  }
-  readers_[0]->SetFileList(filelist_vec);
 }
 
 // call only after all resources are set in current trainer
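
A note on the remaining placeholder in MultiTrainer::Initialize above: the else branch currently carries only the question "// readers_ = dataset.get_readers(); ?". Below is a minimal sketch of what that branch could look like. It is an assumption rather than part of this patch: the accessor name Dataset::GetReaders() and its return type are hypothetical (the patch deliberately leaves the call unnamed), and since dataset is a Dataset* the call would use -> rather than the . shown in the placeholder.

  } else {
    // Hypothetical sketch, not part of this patch: reuse the per-thread
    // readers the Dataset has already created and initialized, instead of
    // building new DataFeed instances from trainer_desc.
    // GetReaders() is an assumed accessor name; substitute whatever the
    // Dataset class actually exposes.
    readers_ = dataset->GetReaders();
    // Keep thread_num_ consistent with the number of readers the dataset
    // provides, so the worker-creation loop below stays in bounds.
    thread_num_ = static_cast<int>(readers_.size());
    workers_.resize(thread_num_);
  }

One design point this sketch assumes: when a Dataset is supplied, the dataset rather than trainer_desc.thread_num() effectively decides the degree of parallelism, so thread_num_ is clamped to the number of readers before the workers are wired up.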