diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc
index 6efe5cafe722c34184d48c60b1e05c37529eed2a..c62d62a5dc473fa7f9648f1b418fefa825711837 100644
--- a/paddle/fluid/framework/async_executor.cc
+++ b/paddle/fluid/framework/async_executor.cc
@@ -66,15 +66,20 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers,  // NOLINT
 }
 
 void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
-  _pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
-  _pslib_ptr->init_server(dist_desc, index);//TODO done
-
+  _pslib_ptr =
+      std::shared_ptr<paddle::distributed::PSlib>(
+          new paddle::distributed::PSlib());
+  _pslib_ptr->init_server(dist_desc, index);
   InitParamConfig();
 }
 
-void AsyncExecutor::InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) {
-  _pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
-  _pslib_ptr->init_worker(dist_desc, host_sign_list.data(), node_num, index);//TODO done
+void AsyncExecutor::InitWorker(const std::string& dist_desc,
+                               const std::vector<uint64_t>& host_sign_list,
+                               int node_num, int index) {
+  _pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
+      new paddle::distributed::PSlib());
+  _pslib_ptr->init_worker(
+      dist_desc, host_sign_list.data(), node_num, index);
   InitParamConfig();
 }
 
@@ -87,43 +92,65 @@ void AsyncExecutor::StopServer() {
   _pslib_ptr->stop_server();
 }
 
-void AsyncExecutor::GatherServers(std::vector<uint64_t>& host_sign_list, int node_num) {
+void AsyncExecutor::GatherServers(
+    const std::vector<uint64_t>& host_sign_list, int node_num) {
   _pslib_ptr->gather_servers(host_sign_list.data(), node_num);
 }
 
 void AsyncExecutor::InitParamConfig() {
-  for (int i = 0; i < _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param_size(); ++i) {
-    if (_pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).table_class().find("SparseTable") != -1) {
-      _param_config.fea_dim = _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).accessor().fea_dim(); //TODO
+  for (int i = 0; i <
+      _pslib_ptr->get_param()->server_param().\
+          downpour_server_param().\
+          downpour_table_param_size();
+      ++i) {
+    if (_pslib_ptr->get_param()->server_param().\
+        downpour_server_param().downpour_table_param(i).\
+        table_class().find("SparseTable") != -1) {
+      _param_config.fea_dim = _pslib_ptr->get_param()->server_param().\
+                              downpour_server_param().\
+                              downpour_table_param(i).\
+                              accessor().fea_dim();
       break;
     }
   }
-  _param_config.slot_dim = _param_config.fea_dim - 2; //TODO
-  _param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
-  _param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
-
-  for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); ++t) {
-    _param_config.skip_op.push_back(_pslib_ptr->get_param()->trainer_param().skip_op(t));
+  _param_config.slot_dim = _param_config.fea_dim - 2;
+  _param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
+      _pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
+  _param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
+      _pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
+
+  for (auto t = 0u;
+      t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
+      ++t) {
+    _param_config.skip_op.push_back(
+        _pslib_ptr->get_param()->trainer_param().skip_op(t));
   }
-  //sparse
-  for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
+
+  for (auto t = 0u;
+      t < _pslib_ptr->get_param()->trainer_param().sparse_table_size();
+      ++t) {
     auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
     std::vector<std::string> tmp_sparse_variable_name;
     for (int i = 0u; i < table.slot_value_size(); ++i) {
       tmp_sparse_variable_name.push_back(table.slot_value(i));
-      _param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id();
+      _param_config.slot_alias_to_table[table.slot_key(i)] =
+          table.table_id();
     }
     std::vector<std::string> tmp_sparse_gradient_variable_name;
     for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
       tmp_sparse_gradient_variable_name.push_back(
           table.slot_gradient(i));
     }
-    _param_config.slot_input_vec[table.table_id()] = std::move(tmp_sparse_variable_name);
-    _param_config.gradient_var[table.table_id()] = std::move(tmp_sparse_gradient_variable_name);
+    _param_config.slot_input_vec[table.table_id()] =
+        std::move(tmp_sparse_variable_name);
+    _param_config.gradient_var[table.table_id()] =
+        std::move(tmp_sparse_gradient_variable_name);
     _param_config.sparse_table_id.push_back(table.table_id());
   }
-  //dense
-  for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
+
+  for (auto t = 0u;
+      t < _pslib_ptr->get_param()->trainer_param().dense_table_size();
+      ++t) {
     auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
     std::vector<std::string> tmp_dense_variable_name;
     for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
@@ -134,20 +161,18 @@ void AsyncExecutor::InitParamConfig() {
       tmp_dense_gradient_variable_name.push_back(
           table.dense_gradient_variable_name(i));
     }
-    _param_config.dense_variable_name[table.table_id()] = std::move(tmp_dense_variable_name);
-    _param_config.dense_gradient_variable_name[table.table_id()] = std::move(tmp_dense_gradient_variable_name);
+    _param_config.dense_variable_name[table.table_id()] =
+        std::move(tmp_dense_variable_name);
+    _param_config.dense_gradient_variable_name[table.table_id()] =
+        std::move(tmp_dense_gradient_variable_name);
     _param_config.dense_table_id.push_back(table.table_id());
-    _param_config.dense_table_size.push_back(table.fea_dim()); //TODO
+    _param_config.dense_table_size.push_back(table.fea_dim());
   }
 }
 
 void AsyncExecutor::InitModel() {
-  //TODO only rank = 0 do this
-  //std::vector<int> all_dense_table_id; //TODO
-  //all_dense_table_id.push_back(0); //done
-  for (auto table_id: _param_config.dense_table_id) {
+  for (auto table_id : _param_config.dense_table_id) {
     std::vector<paddle::ps::Region> regions;
-    //std::vector<std::string> variables; //TODO
     for (auto& t : _param_config.dense_variable_name[table_id]) {
       Variable* var = root_scope_->FindVar(t);
       CHECK(var != nullptr) << "var[" << t << "] not found";
@@ -169,13 +194,15 @@ void AsyncExecutor::InitModel() {
       regions.emplace_back(std::move(reg));
     }
 
-    auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id);
+    auto push_status =
+        _pslib_ptr->_worker_ptr->push_dense_param(
+            regions.data(), regions.size(), table_id);
     push_status.wait();
     auto status = push_status.get();
     if (status != 0) {
       LOG(FATAL) << "push dense param failed, status[" << status << "]";
       exit(-1);
-    } 
+    }
   }
 }
 
@@ -185,7 +212,7 @@ void AsyncExecutor::SaveModel(const std::string& path) {
   ret = _pslib_ptr->_worker_ptr->save(path, 0);
   ret.wait();
   int32_t feasign_cnt = ret.get();
-  if (feasign_cnt == -1) { // TODO should be feasign_cnt < 0, because server bug
+  if (feasign_cnt == -1) {  // (colourful-tree) TODO should be feasign_cnt < 0
     LOG(FATAL) << "save model failed";
     exit(-1);
   }
@@ -195,13 +222,13 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
   if (mode == "mpi") {
     DensePullThreadParam param;
     param.ps_client = _pslib_ptr->_worker_ptr;
-    param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO
+    param.threshold = 1;
    param.training_thread_num = actual_thread_num;
     param.root_scope = root_scope_;
-    //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
     param.dense_params = &_param_config.dense_variable_name;
-    _pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
+    _pull_dense_thread = std::shared_ptr<DensePullThread>(
+        new DensePullThread(param));
     _pull_dense_thread->start();
   }
 }
diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h
index 93010f8a9b06f8edb0d028a058390bb22e21b0f7..184566dd39e358d5ed1083ee4f5f0e3cf99370bf 100644
--- a/paddle/fluid/framework/async_executor.h
+++ b/paddle/fluid/framework/async_executor.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <time.h>
 #include <map>
 #include <memory>
 #include <mutex>  // NOLINT
@@ -22,8 +23,7 @@ limitations under the License. */
 #include <thread>  // NOLINT
 #include <typeinfo>
 #include <vector>
-#include <random>  //local_random_engine
-#include <time.h>  //local_random_engine
+#include <random>  // local_random_engine
 #include "paddle/fluid/framework/data_feed.pb.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/executor_thread_worker.h"
@@ -43,9 +43,10 @@ inline std::default_random_engine& local_random_engine() {
   struct engine_wrapper_t {
     std::default_random_engine engine;
     engine_wrapper_t() {
-      static std::atomic<unsigned long> x(0);
-      std::seed_seq sseq = {x++, x++, x++, (unsigned long)(current_realtime() * 1000)};
-      engine.seed(sseq);
+      static std::atomic<unsigned long> x(0);
+      std::seed_seq sseq = {x++, x++, x++,
+          static_cast<unsigned long>(current_realtime() * 1000)};
+      engine.seed(sseq);
     }
   };
   thread_local engine_wrapper_t r;
@@ -61,18 +62,20 @@ class AsyncExecutor {
                      const std::vector<std::string>& filelist,
                      const int thread_num,
                      const std::vector<std::string>& fetch_names,
-                     const std::string& mode,
+                     const std::string& mode,
                      const bool debug = false);
-  //void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
   void InitServer(const std::string& dist_desc, int index);
-  void InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index);
-  //void ConfigWorker() {}
+  void InitWorker(
+      const std::string& dist_desc,
+      const std::vector<uint64_t>& host_sign_list,
+      int node_num, int index);
   uint64_t StartServer();
   void StopServer();
-  void GatherServers(std::vector<uint64_t>& host_sign_list, int node_num);
+  void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
   void InitModel();
   void SaveModel(const std::string& path);
   void InitParamConfig();
+
  private:
   void CreateThreads(ExecutorThreadWorker* worker,
                      const ProgramDesc& main_program,
@@ -81,6 +84,7 @@ class AsyncExecutor {
                      Scope* root_scope, const int thread_index,
                      const bool debug);
   void PrepareDenseThread(const std::string& mode);
+
  public:
   std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
   std::shared_ptr<DensePullThread> _pull_dense_thread;
@@ -88,6 +92,7 @@ class AsyncExecutor {
   platform::Place place_;
 
   AsyncWorkerParamConfig _param_config;
+
  private:
   int actual_thread_num;
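
Note for reviewers (not part of the patch): the re-signed entry points are easiest to check against their intended call order. Below is a hypothetical driver sketch. The server/worker role split, the out-of-band exchange of `host_sign_list`, and the helper names `RunServerNode`/`RunWorkerNode` are illustrative assumptions; only the AsyncExecutor member functions and their signatures come from this diff.

    // Hypothetical driver: only the AsyncExecutor calls are from this patch.
    #include <cstdint>
    #include <string>
    #include <vector>

    #include "paddle/fluid/framework/async_executor.h"

    using paddle::framework::AsyncExecutor;

    void RunServerNode(AsyncExecutor* exe, const std::string& dist_desc,
                       const std::vector<uint64_t>& host_sign_list,
                       int node_num, int rank) {
      exe->InitServer(dist_desc, rank);    // also runs InitParamConfig()
      uint64_t sign = exe->StartServer();  // this node's signature
      (void)sign;  // assumed exchanged with peers out of band (e.g. via MPI)
      exe->GatherServers(host_sign_list, node_num);
      // ... serve until training finishes ...
      exe->StopServer();
    }

    void RunWorkerNode(AsyncExecutor* exe, const std::string& dist_desc,
                       const std::vector<uint64_t>& host_sign_list,
                       int node_num, int rank) {
      exe->InitWorker(dist_desc, host_sign_list, node_num, rank);
      if (rank == 0) {
        // The deleted "TODO only rank = 0 do this" suggests a single node
        // should seed the dense tables via push_dense_param():
        exe->InitModel();
      }
      // ... RunFromFile(...) training passes; afterwards, on one worker:
      // exe->SaveModel(save_path);  // LOG(FATAL)s if pslib save() returns -1
    }

One observation on the main signature change: taking `host_sign_list` as `const std::vector<uint64_t>&` only compiles because both call sites read the buffer through `.data()`, so the pslib entry points must accept read-only input; the header declaration and the .cc definition of GatherServers now agree on constness (the pre-edit draft had them mismatched).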