Commit 10ed9e0a authored by heqiaozhi

download & run & instance

Thread a `mode` argument through AsyncExecutor::RunFromFile so that
DensePullThread and AsyncExecutorThreadWorker are engaged only when
mode == "mpi"; rework download_data, init_worker, and
config_distributed_nodes in the Python wrapper accordingly.

Parent 57ac412b
@@ -191,18 +191,19 @@ void AsyncExecutor::SaveModel(const std::string& path) {
   }
 }
-void AsyncExecutor::PrepareDenseThread() {
-  DensePullThreadParam param;
-  param.ps_client = _pslib_ptr->_worker_ptr;
-  param.threshold = 1;  // GlobalConfig::instance().pull_dense_per_batch; TODO
-  param.training_thread_num = actual_thread_num;
-  param.root_scope = root_scope_;
-  // param.dense_params = &GlobalConfig::instance().dense_variable_name; TODO
-  param.dense_params = &_param_config.dense_variable_name;
-  _pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
-  _pull_dense_thread->start();
+void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
+  if (mode == "mpi") {
+    DensePullThreadParam param;
+    param.ps_client = _pslib_ptr->_worker_ptr;
+    param.threshold = 1;  // GlobalConfig::instance().pull_dense_per_batch; TODO
+    param.training_thread_num = actual_thread_num;
+    param.root_scope = root_scope_;
+    // param.dense_params = &GlobalConfig::instance().dense_variable_name; TODO
+    param.dense_params = &_param_config.dense_variable_name;
+    _pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
+    _pull_dense_thread->start();
+  }
 }
 void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
@@ -210,6 +211,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
                                 const std::vector<std::string>& filelist,
                                 const int thread_num,
                                 const std::vector<std::string>& fetch_var_names,
+                                const std::string& mode,
                                 const bool debug) {
   std::vector<std::thread> threads;
@@ -251,11 +253,15 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   // todo: should be factory method for creating datafeed
   std::vector<std::shared_ptr<DataFeed>> readers;
   PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
-  PrepareDenseThread();
+  PrepareDenseThread(mode);
   std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
   workers.resize(actual_thread_num);
   for (auto& worker : workers) {
-    worker.reset(new AsyncExecutorThreadWorker);
+    if (mode == "mpi") {
+      worker.reset(new AsyncExecutorThreadWorker);
+    } else {
+      worker.reset(new ExecutorThreadWorker);
+    }
   }
   // prepare thread resource here
@@ -274,7 +280,9 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
   for (auto& th : threads) {
     th.join();
   }
-  _pull_dense_thread->stop();
+  if (mode == "mpi") {
+    _pull_dense_thread->stop();
+  }
   root_scope_->DropKids();
   return;
......
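Net effect of the C++ changes above: the pslib-specific machinery (DensePullThread plus AsyncExecutorThreadWorker) is engaged only when mode == "mpi"; any other value falls back to the plain ExecutorThreadWorker with no dense-pull thread. A minimal sketch of how a caller would select between the two paths through the Python wrapper (main_program, data_feed, and filelist are placeholder names, not part of this commit):

    import paddle.fluid as fluid

    exe = fluid.AsyncExecutor()
    # Local run: the default mode="" skips DensePullThread and uses the
    # plain ExecutorThreadWorker in each training thread.
    exe.run(main_program, data_feed, filelist, thread_num=4,
            fetch=["loss"], debug=True)
    # Distributed run: mode="mpi" starts DensePullThread and switches the
    # per-thread workers to AsyncExecutorThreadWorker.
    exe.run(main_program, data_feed, filelist, thread_num=4,
            fetch=["loss"], mode="mpi")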
@@ -61,6 +61,7 @@ class AsyncExecutor {
                    const std::vector<std::string>& filelist,
                    const int thread_num,
                    const std::vector<std::string>& fetch_names,
+                   const std::string& mode,
                    const bool debug = false);
   // void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
   void InitServer(const std::string& dist_desc, int index);
@@ -79,7 +80,7 @@ class AsyncExecutor {
                     const std::vector<std::string>& fetch_var_names,
                     Scope* root_scope, const int thread_index,
                     const bool debug);
-  void PrepareDenseThread();
+  void PrepareDenseThread(const std::string& mode);
  public:
   std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
   std::shared_ptr<DensePullThread> _pull_dense_thread;
......
@@ -87,9 +87,8 @@ class AsyncExecutor(object):
         scope = global_scope()
         self.executor = core.AsyncExecutor(scope, p)
-        self.instance = ps_instance.PaddlePSInstance(1, 2)
-    def run(self, program, data_feed, filelist, thread_num, fetch, debug=False):
+    def run(self, program, data_feed, filelist, thread_num, fetch, mode="", debug=False):
         """
         Run program by this AsyncExecutor. Training dataset will be in filelist.
         Users can also inspect certain variables by naming them in parameter
@@ -151,10 +150,11 @@ class AsyncExecutor(object):
         self.executor.run_from_files(program_desc,
                                      data_feed.desc(), filelist, thread_num,
-                                     fetch_var_names, debug)
+                                     fetch_var_names, mode, debug)
     def download_data(self, afs_path, local_path, fs_default_name, ugi, process_num=12):
-        hadoop_home = "$HADOOP_HOME"
+        # hadoop_home = "$HADOOP_HOME"
+        hadoop_home = "~/tools/hadoop-xingtian/hadoop/"
         configs = {
             "fs.default.name": fs_default_name,
@@ -169,8 +169,11 @@ class AsyncExecutor(object):
                            self.instance.get_worker_index(),
                            self.instance.get_node_cnt() / 2,
                            multi_processes=process_num)
+        self.instance.barrier_all()  # wait for download_data; TODO: barrier workers only
-    def config_distributed_nodes(self, dist_opt):
+    def config_distributed_nodes(self):
+        self.instance = ps_instance.PaddlePSInstance(1, 2)
+        return self.instance
         # get total rank
         # get rank index
@@ -196,11 +199,15 @@ class AsyncExecutor(object):
         self.executor.gather_servers(ips, self.instance.get_node_cnt())
         self.instance.barrier_all()  # wait all workers start
         self.instance.barrier_all()  # wait init model
-        self.instance.barrier_all()  # wait for download_data
+        self.instance.barrier_all()  # wait for download_data; TODO: remove once only workers barrier
         self.instance.barrier_all()  # wait workers do all things
         self.instance.barrier_all()  # sync
-    def init_worker(self, dist_desc, afs_path, local_path, fs_default_name, ugi):
+    def init_worker(self, dist_desc, startup_program):
+        place = core.CPUPlace()
+        executor = Executor(place)
+        executor.run(startup_program)
         self.instance.barrier_all()  # wait all servers start
         ips = self.instance.gather_ips()
         self.executor.init_worker(dist_desc, ips, self.instance.get_node_cnt(), self.instance._rankid)
@@ -208,8 +215,6 @@ class AsyncExecutor(object):
         if self.instance.is_first_worker():
             self.executor.init_model()
         self.instance.barrier_all()  # wait init model
-        self.download_data(afs_path, local_path, fs_default_name, ugi, process_num=12)
-        self.instance.barrier_all()  # wait for download_data
     def init_model(self):
         self.executor.init_model()
......
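On the Python side the commit redistributes responsibilities: config_distributed_nodes() now builds and returns the PaddlePSInstance itself, init_worker() takes a startup_program and runs it on CPUPlace before connecting to the servers, and the download_data() call was removed from init_worker, so trainers must invoke it explicitly. A hedged end-to-end sketch of the resulting driver flow; dist_desc, the two programs, the HDFS arguments, and the instance.is_server()/exe.init_server() calls are illustrative assumptions, not shown in this diff:

    import paddle.fluid as fluid

    exe = fluid.AsyncExecutor()
    instance = exe.config_distributed_nodes()  # builds PaddlePSInstance(1, 2)

    if instance.is_server():  # assumed role check on PaddlePSInstance
        exe.init_server(dist_desc)  # assumed server-side entry point
    else:
        # init_worker now runs the startup program itself (on CPUPlace)
        exe.init_worker(dist_desc, startup_program)
        # download_data moved out of init_worker and must be called
        # explicitly; its trailing barrier_all() keeps trainers in step
        exe.download_data(afs_path, local_path, fs_default_name, ugi,
                          process_num=12)
        exe.run(main_program, data_feed, filelist, thread_num=4,
                fetch=["loss"], mode="mpi")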