# Patch summary (commentary only; everything from the first "diff --git" header onward is unchanged).
#
# paddle/fluid/framework/async_executor.cc
#   * AsyncExecutor::PrepareDenseThread now takes `const std::string& mode` and only builds and
#     starts the DensePullThread when mode == "mpi"; for any other mode it does nothing.
#   * AsyncExecutor::RunFromFile gains a `const std::string& mode` parameter placed before `debug`.
#     It forwards mode to PrepareDenseThread, creates AsyncExecutorThreadWorker instances when
#     mode == "mpi" and ExecutorThreadWorker instances otherwise, and calls
#     _pull_dense_thread->stop() only when mode == "mpi".
#
# paddle/fluid/framework/async_executor.h
#   * Declarations of RunFromFile and PrepareDenseThread updated to match the new signatures.
#
# python/paddle/fluid/async_executor.py
#   * The PaddlePSInstance(1, 2) construction is removed from its previous location (the hunk
#     context suggests the constructor -- TODO confirm) and moved into config_distributed_nodes,
#     which drops its `dist_opt` parameter and now returns the instance.
#   * run() gains `mode=""` before `debug` and forwards it to run_from_files.
#   * download_data() replaces hadoop_home "$HADOOP_HOME" with a fixed path (old value kept as a
#     comment) and adds a barrier_all() after the download call.
#   * init_worker() now takes (dist_desc, startup_program), runs startup_program through an
#     Executor on CPUPlace first, and no longer calls download_data or its trailing barrier.
#   * One barrier comment is extended with a TODO; the barrier call itself is unchanged.
#
# NOTE(review): hadoop_home is hard-coded to "~/tools/hadoop-xingtian/hadoop/", which looks like a
#   developer-local path -- confirm before merging.
# NOTE(review): `mode` is inserted before `debug` in both the C++ and Python signatures, so any
#   caller passing `debug` positionally would now be passing it as `mode` -- verify callers.
# NOTE(review): run_from_files now receives one extra argument; the binding that exposes it is not
#   part of this patch -- confirm it was updated to accept `mode`.
# NOTE(review): this copy of the patch has its newlines collapsed and appears to have lost C++
#   template arguments (e.g. `std::shared_ptr(new DensePullThread(param))`); it presumably will
#   not apply as stored -- regenerate from the original commit if it needs to be applied.
diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc index 45a914b70eaea45a394e654794ea1a12c48dbe69..f0ca375f950dc98533a4cd79e372eadef9770dc2 100644 --- a/paddle/fluid/framework/async_executor.cc +++ b/paddle/fluid/framework/async_executor.cc @@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) { } } -void AsyncExecutor::PrepareDenseThread() { - DensePullThreadParam param; - param.ps_client = _pslib_ptr->_worker_ptr;; - param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO - param.training_thread_num = actual_thread_num; - param.root_scope = root_scope_; - //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO - param.dense_params = &_param_config.dense_variable_name; - - _pull_dense_thread = std::shared_ptr(new DensePullThread(param)); - _pull_dense_thread->start(); - +void AsyncExecutor::PrepareDenseThread(const std::string& mode) { + if (mode == "mpi") { + DensePullThreadParam param; + param.ps_client = _pslib_ptr->_worker_ptr;; + param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO + param.training_thread_num = actual_thread_num; + param.root_scope = root_scope_; + //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO + param.dense_params = &_param_config.dense_variable_name; + + _pull_dense_thread = std::shared_ptr(new DensePullThread(param)); + _pull_dense_thread->start(); + } } void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, @@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, const std::vector& filelist, const int thread_num, const std::vector& fetch_var_names, + const std::string& mode, const bool debug) { std::vector threads; @@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, // todo: should be factory method for creating datafeed std::vector> readers; PrepareReaders(readers, actual_thread_num, data_feed_desc, 
filelist); - PrepareDenseThread(); + PrepareDenseThread(mode); std::vector> workers; workers.resize(actual_thread_num); for (auto& worker : workers) { - worker.reset(new AsyncExecutorThreadWorker); + if (mode == "mpi") { + worker.reset(new AsyncExecutorThreadWorker); + } else { + worker.reset(new ExecutorThreadWorker); + } } // prepare thread resource here @@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, for (auto& th : threads) { th.join(); } - _pull_dense_thread->stop(); + if (mode == "mpi") { + _pull_dense_thread->stop(); + } root_scope_->DropKids(); return; diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h index 4b461262173fb8828a4bbf7ebad34715a2f4fa96..93010f8a9b06f8edb0d028a058390bb22e21b0f7 100644 --- a/paddle/fluid/framework/async_executor.h +++ b/paddle/fluid/framework/async_executor.h @@ -61,6 +61,7 @@ class AsyncExecutor { const std::vector& filelist, const int thread_num, const std::vector& fetch_names, + const std::string& mode, const bool debug = false); //void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index); void InitServer(const std::string& dist_desc, int index); @@ -79,7 +80,7 @@ class AsyncExecutor { const std::vector& fetch_var_names, Scope* root_scope, const int thread_index, const bool debug); - void PrepareDenseThread(); + void PrepareDenseThread(const std::string& mode); public: std::shared_ptr _pslib_ptr; std::shared_ptr _pull_dense_thread; diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py index cce7ec5cca1151adab2326094450f0bcb06ab504..e760d58fd22f2be0448ffcad2cf76145f3755e61 100644 --- a/python/paddle/fluid/async_executor.py +++ b/python/paddle/fluid/async_executor.py @@ -87,9 +87,8 @@ class AsyncExecutor(object): scope = global_scope() self.executor = core.AsyncExecutor(scope, p) - self.instance = ps_instance.PaddlePSInstance(1, 2) - def run(self, program, data_feed, filelist, 
thread_num, fetch, debug=False): + def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False): """ Run program by this AsyncExecutor. Training dataset will be in filelist. Users can also inspect certain variables by naming them in parameter @@ -151,10 +150,11 @@ class AsyncExecutor(object): self.executor.run_from_files(program_desc, data_feed.desc(), filelist, thread_num, - fetch_var_names, debug) + fetch_var_names, mode, debug) def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12): - hadoop_home = "$HADOOP_HOME" + #hadoop_home = "$HADOOP_HOME" + hadoop_home = "~/tools/hadoop-xingtian/hadoop/" configs = { "fs.default.name": fs_default_name, @@ -169,8 +169,11 @@ class AsyncExecutor(object): self.instance.get_worker_index(), self.instance.get_node_cnt() / 2, multi_processes=process_num) + self.instance.barrier_all() #wait for download_data #TODO only barriere worker - def config_distributed_nodes(self, dist_opt): + def config_distributed_nodes(self): + self.instance = ps_instance.PaddlePSInstance(1, 2) + return self.instance # get total rank # get rank index @@ -196,11 +199,15 @@ class AsyncExecutor(object): self.executor.gather_servers(ips, self.instance.get_node_cnt()) self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait init model - self.instance.barrier_all() #wait for download_data + self.instance.barrier_all() #wait for download_data #TODO remove this after only barrier worker self.instance.barrier_all() #wait worker do all things self.instance.barrier_all() #sync - def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi): + def init_worker(self, dist_desc, startup_program): + place = core.CPUPlace() + executor = Executor(place) + executor.run(startup_program) + self.instance.barrier_all() #wait all server start ips = self.instance.gather_ips() self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid) @@ -208,8 +215,6 
@@ class AsyncExecutor(object): if self.instance.is_first_worker(): self.executor.init_model() self.instance.barrier_all() #wait init model - self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12) - self.instance.barrier_all() #wait for download_data def init_model(self): self.executor.init_model()