提交 08c25995 编写于 作者: D dongdaxiang

Add RunFromDataset support to the Executor.

上级 c28bbdf8
...@@ -54,6 +54,7 @@ class Executor { ...@@ -54,6 +54,7 @@ class Executor {
explicit Executor(const platform::Place& place); explicit Executor(const platform::Place& place);
explicit Executor(Scope* scope, const platform::Place& place);
/* /*
* Close this Executor. * Close this Executor.
* Calling this method will send complete messages to all pserver instances. * Calling this method will send complete messages to all pserver instances.
...@@ -110,8 +111,20 @@ class Executor { ...@@ -110,8 +111,20 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program); void EnableMKLDNN(const ProgramDesc& program);
void RunFromTrainerDesc(const ProgramDesc& main_program,
const std::string& trainer_desc_str,
const bool debug);
void RunFromDataset(const ProgramDesc& main_program, const Dataset* dataset,
const std::string& trainer_desc_str, const bool debug);
public:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
Scope* root_scope_;
private: private:
const platform::Place place_; const platform::Place place_;
int actual_thread_num_;
}; };
} // namespace framework } // namespace framework
......
...@@ -21,25 +21,34 @@ limitations under the License. */ ...@@ -21,25 +21,34 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc) { void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here // get filelist from trainer_desc here
workers_.resize(thread_num_); workers_.resize(thread_num_);
if (NULL == dataset) {
readers_.resize(thread_num_); readers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) { for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
readers_[i] = readers_[i] =
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name()); DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
workers_[i]->SetDeviceIndex(i);
readers_[i]->Init(trainer_desc.data_desc()); readers_[i]->Init(trainer_desc.data_desc());
workers_[i]->SetDataFeed(readers_[i]);
} }
std::vector<std::string> filelist_vec; std::vector<std::string> filelist_vec;
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) { for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
filelist_vec.push_back(trainer_desc.filelist(i)); filelist_vec.push_back(trainer_desc.filelist(i));
} }
readers_[0]->SetFileList(filelist_vec); readers_[0]->SetFileList(filelist_vec);
} else {
// readers_ = dataset.get_readers(); ?
}
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
workers_[i]->SetDeviceIndex(i);
workers_[i]->SetDataFeed(readers_[i]);
}
} }
// call only after all resources are set in current trainer // call only after all resources are set in current trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册