提交 08c25995 编写于 作者: D dongdaxiang

add run from dataset in executor.

上级 c28bbdf8
......@@ -54,6 +54,7 @@ class Executor {
explicit Executor(const platform::Place& place);
explicit Executor(Scope* scope, const platform::Place& place);
/*
* Close this Executor.
* Calling this method will send complete messages to all pserver instances.
......@@ -110,8 +111,20 @@ class Executor {
void EnableMKLDNN(const ProgramDesc& program);
void RunFromTrainerDesc(const ProgramDesc& main_program,
const std::string& trainer_desc_str,
const bool debug);
void RunFromDataset(const ProgramDesc& main_program, const Dataset* dataset,
const std::string& trainer_desc_str, const bool debug);
public:
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
Scope* root_scope_;
private:
const platform::Place place_;
int actual_thread_num_;
};
} // namespace framework
......
......@@ -21,25 +21,34 @@ limitations under the License. */
namespace paddle {
namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset* dataset) {
thread_num_ = trainer_desc.thread_num();
// get filelist from trainer_desc here
workers_.resize(thread_num_);
readers_.resize(thread_num_);
if (NULL == dataset) {
readers_.resize(thread_num_);
for (int i = 0; i < thread_num_; ++i) {
readers_[i] =
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
readers_[i]->Init(trainer_desc.data_desc());
}
std::vector<std::string> filelist_vec;
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
filelist_vec.push_back(trainer_desc.filelist(i));
}
readers_[0]->SetFileList(filelist_vec);
} else {
// readers_ = dataset.get_readers(); ?
}
for (int i = 0; i < thread_num_; ++i) {
workers_[i] = DeviceWorkerFactory::CreateDeviceWorker(
trainer_desc.device_worker_name());
readers_[i] =
DataFeedFactory::CreateDataFeed(trainer_desc.data_desc().name());
workers_[i]->SetDeviceIndex(i);
readers_[i]->Init(trainer_desc.data_desc());
workers_[i]->SetDataFeed(readers_[i]);
}
std::vector<std::string> filelist_vec;
for (unsigned i = 0; i < trainer_desc.filelist_size(); ++i) {
filelist_vec.push_back(trainer_desc.filelist(i));
}
readers_[0]->SetFileList(filelist_vec);
}
// call only after all resources are set in current trainer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册