From d3ca359e445884ffca4b147607607517aad4791b Mon Sep 17 00:00:00 2001
From: heqiaozhi
Date: Wed, 5 Dec 2018 19:30:37 +0800
Subject: [PATCH] config init & adapt to interface

---
 paddle/fluid/framework/async_executor.cc      | 55 +++++++++++++++++--
 paddle/fluid/framework/async_executor.h       |  3 +-
 .../fluid/framework/executor_thread_worker.cc | 44 ++++++++-------
 .../fluid/framework/executor_thread_worker.h  | 15 +++--
 4 files changed, 85 insertions(+), 32 deletions(-)

diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc
index 94ed8c2fca4..292b05c5884 100644
--- a/paddle/fluid/framework/async_executor.cc
+++ b/paddle/fluid/framework/async_executor.cc
@@ -67,21 +67,63 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers,  // NOLINT
 void AsyncExecutor::ConfigPslib(const std::string& dist_desc,
                                 std::vector<uint64_t>& host_sign_list,
                                 int node_num, int index) {
   _pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
       new paddle::distributed::PSlib());
-  _pslib_ptr->init_and_config(dist_desc, host_sign_list, node_num, index);  //TODO
+  _pslib_ptr->init_and_config(dist_desc, host_sign_list, node_num, index);  //TODO done
 }
 
 void AsyncExecutor::StartServer() {
+  InitParamConfig();
   _pslib_ptr->run_server();
 }
 
+void AsyncExecutor::InitParamConfig() {
+  _param_config.fea_dim =
+      _pslib_ptr->get_param()->trainer_param().sparse_table(0).feature_dim();  //TODO
+  _param_config.slot_dim = _param_config.fea_dim - 2;  //TODO
+  _param_config.tmp_push_dense_wait_times = (int32_t)(
+      _pslib_ptr->get_param()->trainer_param().pull_dense_per_batch());
+  _param_config.tmp_push_sparse_wait_times = (int32_t)(
+      _pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
+  // sparse
+  for (auto t = 0u;
+       t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) {
+    auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
+    std::vector<std::string> tmp_sparse_variable_name;
+    for (int i = 0; i < table.slot_value_size(); ++i) {
+      tmp_sparse_variable_name.push_back(table.slot_value(i));
+      _param_config.slot_alias_to_table[table.slot_value(i)] = table.table_id();
+    }
+    std::vector<std::string> tmp_sparse_gradient_variable_name;
+    for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
+      tmp_sparse_gradient_variable_name.push_back(table.slot_gradient(i));
+    }
+    _param_config.slot_input_vec[table.table_id()] =
+        std::move(tmp_sparse_variable_name);
+    _param_config.gradient_var[table.table_id()] =
+        std::move(tmp_sparse_gradient_variable_name);
+    _param_config.sparse_table_id.push_back(table.table_id());
+  }
+  // dense
+  for (auto t = 0u;
+       t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) {
+    auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
+    std::vector<std::string> tmp_dense_variable_name;
+    for (int i = 0; i < table.dense_variable_name_size(); ++i) {
+      tmp_dense_variable_name.push_back(table.dense_variable_name(i));
+    }
+    std::vector<std::string> tmp_dense_gradient_variable_name;
+    for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
+      tmp_dense_gradient_variable_name.push_back(
+          table.dense_gradient_variable_name(i));
+    }
+    _param_config.dense_variable_name[table.table_id()] =
+        std::move(tmp_dense_variable_name);
+    _param_config.dense_gradient_variable_name[table.table_id()] =
+        std::move(tmp_dense_gradient_variable_name);
+    _param_config.dense_table_id.push_back(table.table_id());
+    _param_config.dense_table_size.push_back(table.fea_dim());  //TODO
+  }
+}
+
 void AsyncExecutor::InitModel() {
   //TODO only rank = 0 do this
-  std::vector<int> all_dense_table_id;  //TODO
-  all_dense_table_id.push_back(0);
-  for (auto table_id : all_dense_table_id) {
+  //std::vector<int> all_dense_table_id;  //TODO
+  //all_dense_table_id.push_back(0);  //done
+  for (auto table_id : _param_config.dense_table_id) {
     std::vector<paddle::ps::Region> regions;
-    std::vector<std::string> variables;  //TODO
-    for (auto& t : variables) {
+    //std::vector<std::string> variables;  //TODO
+    for (auto& t : _param_config.dense_variable_name[table_id]) {
       Variable* var = root_scope_->FindVar(t);
       CHECK(var != nullptr) << "var[" << t << "] not found";
       LoDTensor* tensor = var->GetMutable<LoDTensor>();
@@ -131,6 +173,7 @@ void AsyncExecutor::PrepareDenseThread() {
   param.training_thread_num = actual_thread_num;
   param.root_scope = root_scope_;
   //param.dense_params = &GlobalConfig::instance().dense_variable_name;  //TODO
+  param.dense_params = &_param_config.dense_variable_name;
 
   _pull_dense_thread =
       std::shared_ptr<DensePullThread>(new DensePullThread(param));
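The substance of the async_executor.cc change is InitParamConfig(): table ids, slot names, and dense variable names now come from the PSlib trainer_param configuration instead of a hard-coded table 0, and every per-table name list is keyed by table_id. A minimal standalone sketch of the lookup structure this builds follows; SparseTableDef and all concrete names are illustrative stand-ins for the protobuf accessors, not PSlib types.

    // Sketch of the table-keyed config that InitParamConfig assembles.
    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct SparseTableDef {                   // stand-in for trainer_param().sparse_table(t)
      uint64_t table_id;
      std::vector<std::string> slot_value;    // input variable names
      std::vector<std::string> slot_gradient; // gradient variable names
    };

    int main() {
      std::vector<SparseTableDef> tables = {
          {0, {"6048slot", "6050slot"},
              {"6048slot_embed@GRAD", "6050slot_embed@GRAD"}}};

      // Same shapes as AsyncWorkerParamConfig: everything keyed by table id.
      std::map<uint64_t, std::vector<std::string>> slot_input_vec;
      std::map<uint64_t, std::vector<std::string>> gradient_var;
      std::unordered_map<std::string, uint64_t> slot_alias_to_table;

      for (const auto& table : tables) {
        slot_input_vec[table.table_id] = table.slot_value;
        gradient_var[table.table_id] = table.slot_gradient;
        for (const auto& slot : table.slot_value) {
          slot_alias_to_table[slot] = table.table_id;  // slot name -> owning table
        }
      }

      // PushSparse-style routing: skip slots that belong to another table.
      uint64_t current_table = 0;
      for (const auto& slot : {"6048slot", "6050slot"}) {
        if (slot_alias_to_table[slot] != current_table) continue;
        std::cout << slot << " -> table " << current_table << std::endl;
      }
      return 0;
    }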
diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h
index 67f4e5deeee..21e4a66fcef 100644
--- a/paddle/fluid/framework/async_executor.h
+++ b/paddle/fluid/framework/async_executor.h
@@ -68,7 +68,7 @@ class AsyncExecutor {
   void StartServer();
   void InitModel();
   void SaveModel(const std::string& path);
-
+  void InitParamConfig();
  private:
   void CreateThreads(ExecutorThreadWorker* worker,
                      const ProgramDesc& main_program,
@@ -86,6 +86,7 @@ class AsyncExecutor {
   AsyncWorkerParamConfig _param_config;
 
  private:
   int actual_thread_num;
+
 };

diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc
index 19d8818be74..f7c05e400d7 100644
--- a/paddle/fluid/framework/executor_thread_worker.cc
+++ b/paddle/fluid/framework/executor_thread_worker.cc
@@ -382,33 +382,38 @@ void AsyncExecutorThreadWorker::BindingSlotVariableMemory() {
   }
   */
 }
-void AsyncExecutorThreadWorker::SetParamConfig(AsyncWorkerParamConfig* pc) {
-  _param_config = pc;
+
+void AsyncExecutorThreadWorker::SetParamConfig(
+    AsyncWorkerParamConfig* param_config) {
+  _param_config = param_config;
 }
 
 void AsyncExecutorThreadWorker::PrepareParams() {
-  int table_id = 0;  //TODO
-  PullSparse(table_id);
-  for (auto& t : _pull_sparse_status) {
-    t.wait();
-    auto status = t.get();
-    if (status != 0) {
-      LOG(ERROR) << "pull sparse failed, status[" << status << "]";
-      exit(-1);
+  //int table_id = 0;  //TODO
+  for (auto table_id : _param_config->sparse_table_id) {
+    PullSparse(table_id);
+    for (auto& t : _pull_sparse_status) {
+      t.wait();
+      auto status = t.get();
+      if (status != 0) {
+        LOG(ERROR) << "pull sparse failed, status[" << status << "]";
+        exit(-1);
+      }
     }
   }
   _pull_sparse_status.resize(0);
-  FillSparse(table_id);
+  for (auto table_id : _param_config->sparse_table_id) {
+    FillSparse(table_id);
+  }
 }
 
 void AsyncExecutorThreadWorker::UpdateParams() {
-  //for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {  //TODO
-  for (int i = 0; i < 1; ++i) {
+  for (auto i : _param_config->sparse_table_id) {  //TODO
+  //for (int i = 0; i < 1; ++i) {
     PushSparse(i);
   }
   //for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {  //TODO
-  for (int i = 1; i < 2; ++i) {
+  for (auto i : _param_config->dense_table_id) {
     PushDense(i);
   }
   int32_t tmp_push_dense_wait_times = _param_config->tmp_push_dense_wait_times;  //TODO
@@ -437,14 +442,13 @@ void AsyncExecutorThreadWorker::UpdateParams() {
   }
 
   //for (auto dense_table_id : GlobalConfig::instance().dense_table_id) {  //TODO
-  int dense_table_id = 1;
+  for (auto dense_table_id : _param_config->dense_table_id) {
     _pull_dense_thread->increase_thread_version(thread_id_, dense_table_id);
+  }
   //}
 }
 
 void AsyncExecutorThreadWorker::PushDense(int table_id) {
-  //auto table_id = GlobalConfig::instance().dense_table_id[table_id_index];  //TODO
-
   std::vector<paddle::ps::Region> regions;
   //auto& variables = GlobalConfig::instance().dense_gradient_variable_name[table_id];
   std::vector<std::string> variables;
@@ -529,7 +533,7 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) {
     int64_t* ids = tensor->data<int64_t>();
     int len = tensor->numel();
-    Variable* var_emb = thread_scope_->FindVar(_param_config->slot_input_vec[slot_idx - 1]);
+    Variable* var_emb = thread_scope_->FindVar(_param_config->slot_input_vec[table_id][slot_idx - 1]);
     LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
     float* ptr = tensor_emb->data<float>();
@@ -575,10 +579,10 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) {
 
   // slot_idx = 0 is label  //TODO
   for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
-    if (_slot_alias_to_table[feed_vec[slot_idx]] != table_id) {
+    if (_param_config->slot_alias_to_table[feed_vec[slot_idx]] != table_id) {
       continue;
     }
-    Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[slot_idx - 1]);
+    Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[table_id][slot_idx - 1]);
     LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
     //int count = g_tensor->numel();
     float* g = g_tensor->data<float>();
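PrepareParams() above now pulls and fills every configured sparse table rather than just table 0, and each pull's status future is checked before FillSparse copies values into the slot tensors. A simplified sketch of that pull/wait/fill ordering, with std::async standing in for the asynchronous PSlib pull (FakePull and the table ids are invented for illustration):

    #include <cstdint>
    #include <future>
    #include <iostream>
    #include <vector>

    // Stand-in for the pull_sparse call; returns 0 on success, like the
    // status value checked in PrepareParams.
    int32_t FakePull(int table_id) { return 0; }

    int main() {
      std::vector<int> sparse_table_id = {0, 1};
      std::vector<std::future<int32_t>> pull_status;  // like _pull_sparse_status

      // One asynchronous pull per sparse table.
      for (int table_id : sparse_table_id) {
        pull_status.push_back(std::async(std::launch::async, FakePull, table_id));
      }
      // Every pull must report success before any FillSparse-style copy runs,
      // otherwise embeddings would be built from incomplete parameter state.
      for (auto& s : pull_status) {
        int32_t status = s.get();
        if (status != 0) {
          std::cerr << "pull sparse failed, status[" << status << "]" << std::endl;
          return -1;
        }
      }
      pull_status.clear();
      for (int table_id : sparse_table_id) {
        std::cout << "fill table " << table_id << std::endl;  // FillSparse(table_id)
      }
      return 0;
    }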
diff --git a/paddle/fluid/framework/executor_thread_worker.h b/paddle/fluid/framework/executor_thread_worker.h
index 63f383cd479..4e3255a590c 100644
--- a/paddle/fluid/framework/executor_thread_worker.h
+++ b/paddle/fluid/framework/executor_thread_worker.h
@@ -40,8 +40,14 @@ struct AsyncWorkerParamConfig {
   int32_t tmp_push_dense_wait_times;
   int32_t tmp_push_sparse_wait_times;
 
-  std::vector<std::string> slot_input_vec;  //6048slot 6050slot //name
-  std::vector<std::string> gradient_var;    //6048slot_embed
+  std::map<uint64_t, std::vector<std::string>> dense_variable_name;
+  std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name;
+  std::vector<int> dense_table_id;
+  std::vector<uint32_t> dense_table_size;  // fea_dim for each dense table
+  std::vector<int> sparse_table_id;
+  std::map<uint64_t, std::vector<std::string>> slot_input_vec;  //6048slot 6050slot //name
+  std::map<uint64_t, std::vector<std::string>> gradient_var;    //6048slot_embed
+  std::unordered_map<std::string, uint64_t> slot_alias_to_table;  //TODO done
 };
 
 struct DensePullThreadParam {
@@ -148,7 +154,7 @@ class ExecutorThreadWorker {
   virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
   virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {};
   virtual void BindingSlotVariableMemory() {};
-  virtual void SetParamConfig(AsyncWorkerParamConfig* pc) {};
+  virtual void SetParamConfig(AsyncWorkerParamConfig* param_config) {};
  private:
   void CreateThreadScope(const framework::ProgramDesc& program);
   void CreateThreadOperators(const framework::ProgramDesc& program);
@@ -184,7 +190,7 @@ public:
   void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
   void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt);
   void BindingSlotVariableMemory();
-  void SetParamConfig(AsyncWorkerParamConfig* pc);
+  void SetParamConfig(AsyncWorkerParamConfig* param_config);
   void TrainFiles();
   void TrainOneNetwork();
   void PrepareParams();
@@ -209,7 +215,6 @@ private:
   std::map<uint64_t, std::vector<std::vector<float>>> _feature_value;
   std::map<uint64_t, std::vector<std::vector<float>>> _feature_push_value;
 
-  std::unordered_map<std::string, uint64_t> _slot_alias_to_table;  //TODO
   std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
-- 
GitLab
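The header change replaces the flat slot_input_vec/gradient_var vectors (and the worker-local _slot_alias_to_table) with per-table maps in AsyncWorkerParamConfig, shared by pointer with every worker through SetParamConfig. A toy sketch of that wiring and the two-level [table_id][slot_idx - 1] lookup; Worker and all concrete names here are illustrative, not the real classes:

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    struct AsyncWorkerParamConfig {
      std::map<uint64_t, std::vector<std::string>> gradient_var;
    };

    class Worker {
     public:
      void SetParamConfig(AsyncWorkerParamConfig* param_config) {
        _param_config = param_config;  // shared config, read-only after init
      }
      void PushSparse(uint64_t table_id, size_t slot_idx) {
        // slot_idx = 0 is the label slot, so gradients start at slot_idx - 1.
        std::cout << _param_config->gradient_var[table_id][slot_idx - 1]
                  << std::endl;
      }
     private:
      AsyncWorkerParamConfig* _param_config = nullptr;
    };

    int main() {
      AsyncWorkerParamConfig config;
      config.gradient_var[0] = {"6048slot_embed@GRAD"};  // illustrative name
      Worker w;
      w.SetParamConfig(&config);   // one config object, handed to each thread
      w.PushSparse(0, 1);          // prints the gradient var for table 0, slot 1
      return 0;
    }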