提交 10ed9e0a 编写于 作者: H heqiaozhi

download & run & instance

上级 57ac412b
...@@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) { ...@@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) {
} }
} }
void AsyncExecutor::PrepareDenseThread() { void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
DensePullThreadParam param; if (mode == "mpi") {
param.ps_client = _pslib_ptr->_worker_ptr;; DensePullThreadParam param;
param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO param.ps_client = _pslib_ptr->_worker_ptr;;
param.training_thread_num = actual_thread_num; param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO
param.root_scope = root_scope_; param.training_thread_num = actual_thread_num;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name; //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start(); _pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
_pull_dense_thread->start();
}
} }
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...@@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_var_names, const std::vector<std::string>& fetch_var_names,
const std::string& mode,
const bool debug) { const bool debug) {
std::vector<std::thread> threads; std::vector<std::thread> threads;
...@@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// todo: should be factory method for creating datafeed // todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers; std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist); PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
PrepareDenseThread(); PrepareDenseThread(mode);
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers; std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num); workers.resize(actual_thread_num);
for (auto& worker : workers) { for (auto& worker : workers) {
worker.reset(new AsyncExecutorThreadWorker); if (mode == "mpi") {
worker.reset(new AsyncExecutorThreadWorker);
} else {
worker.reset(new ExecutorThreadWorker);
}
} }
// prepare thread resource here // prepare thread resource here
...@@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, ...@@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) { for (auto& th : threads) {
th.join(); th.join();
} }
_pull_dense_thread->stop(); if (mode == "mpi") {
_pull_dense_thread->stop();
}
root_scope_->DropKids(); root_scope_->DropKids();
return; return;
......
...@@ -61,6 +61,7 @@ class AsyncExecutor { ...@@ -61,6 +61,7 @@ class AsyncExecutor {
const std::vector<std::string>& filelist, const std::vector<std::string>& filelist,
const int thread_num, const int thread_num,
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_names,
const std::string& mode,
const bool debug = false); const bool debug = false);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index); //void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
void InitServer(const std::string& dist_desc, int index); void InitServer(const std::string& dist_desc, int index);
...@@ -79,7 +80,7 @@ class AsyncExecutor { ...@@ -79,7 +80,7 @@ class AsyncExecutor {
const std::vector<std::string>& fetch_var_names, const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index, Scope* root_scope, const int thread_index,
const bool debug); const bool debug);
void PrepareDenseThread(); void PrepareDenseThread(const std::string& mode);
public: public:
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr; std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread; std::shared_ptr<DensePullThread> _pull_dense_thread;
......
...@@ -87,9 +87,8 @@ class AsyncExecutor(object): ...@@ -87,9 +87,8 @@ class AsyncExecutor(object):
scope = global_scope() scope = global_scope()
self.executor = core.AsyncExecutor(scope, p) self.executor = core.AsyncExecutor(scope, p)
self.instance = ps_instance.PaddlePSInstance(1, 2)
def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False):
""" """
Run program by this AsyncExecutor. Training dataset will be in filelist. Run program by this AsyncExecutor. Training dataset will be in filelist.
Users can also inspect certain variables by naming them in parameter Users can also inspect certain variables by naming them in parameter
...@@ -151,10 +150,11 @@ class AsyncExecutor(object): ...@@ -151,10 +150,11 @@ class AsyncExecutor(object):
self.executor.run_from_files(program_desc, self.executor.run_from_files(program_desc,
data_feed.desc(), filelist, thread_num, data_feed.desc(), filelist, thread_num,
fetch_var_names, debug) fetch_var_names, mode, debug)
def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12): def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12):
hadoop_home = "$HADOOP_HOME" #hadoop_home = "$HADOOP_HOME"
hadoop_home = "~/tools/hadoop-xingtian/hadoop/"
configs = { configs = {
"fs.default.name": fs_default_name, "fs.default.name": fs_default_name,
...@@ -169,8 +169,11 @@ class AsyncExecutor(object): ...@@ -169,8 +169,11 @@ class AsyncExecutor(object):
self.instance.get_worker_index(), self.instance.get_worker_index(),
self.instance.get_node_cnt() / 2, self.instance.get_node_cnt() / 2,
multi_processes=process_num) multi_processes=process_num)
self.instance.barrier_all() #wait for download_data #TODO only barrier worker
def config_distributed_nodes(self, dist_opt): def config_distributed_nodes(self):
self.instance = ps_instance.PaddlePSInstance(1, 2)
return self.instance
# get total rank # get total rank
# get rank index # get rank index
...@@ -196,11 +199,15 @@ class AsyncExecutor(object): ...@@ -196,11 +199,15 @@ class AsyncExecutor(object):
self.executor.gather_servers(ips, self.instance.get_node_cnt()) self.executor.gather_servers(ips, self.instance.get_node_cnt())
self.instance.barrier_all() #wait all worker start self.instance.barrier_all() #wait all worker start
self.instance.barrier_all() #wait init model self.instance.barrier_all() #wait init model
self.instance.barrier_all() #wait for download_data self.instance.barrier_all() #wait for download_data #TODO remove this once the download_data barrier only involves workers
self.instance.barrier_all() #wait worker do all things self.instance.barrier_all() #wait worker do all things
self.instance.barrier_all() #sync self.instance.barrier_all() #sync
def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi): def init_worker(self, dist_desc, startup_program):
place = core.CPUPlace()
executor = Executor(place)
executor.run(startup_program)
self.instance.barrier_all() #wait all server start self.instance.barrier_all() #wait all server start
ips = self.instance.gather_ips() ips = self.instance.gather_ips()
self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid) self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
...@@ -208,8 +215,6 @@ class AsyncExecutor(object): ...@@ -208,8 +215,6 @@ class AsyncExecutor(object):
if self.instance.is_first_worker(): if self.instance.is_first_worker():
self.executor.init_model() self.executor.init_model()
self.instance.barrier_all() #wait init model self.instance.barrier_all() #wait init model
self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12)
self.instance.barrier_all() #wait for download_data
def init_model(self): def init_model(self):
self.executor.init_model() self.executor.init_model()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册