提交 c71279bc 编写于 作者: D dongdaxiang

refine code style for async_executor.h and async_executor.cc

上级 33ee5cad
...@@ -66,15 +66,20 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT ...@@ -66,15 +66,20 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
} }
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) { void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib()); _pslib_ptr =
_pslib_ptr->init_server(dist_desc, index);//TODO done std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
_pslib_ptr->init_server(dist_desc, index);
InitParamConfig(); InitParamConfig();
} }
void AsyncExecutor::InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) { void AsyncExecutor::InitWorker(const std::string& dist_desc,
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib()); const std::vector<uint64_t>& host_sign_list,
_pslib_ptr->init_worker(dist_desc, host_sign_list.data(), node_num, index);//TODO done int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
_pslib_ptr->init_worker(
dist_desc, host_sign_list.data(), node_num, index);
InitParamConfig(); InitParamConfig();
} }
...@@ -87,43 +92,65 @@ void AsyncExecutor::StopServer() { ...@@ -87,43 +92,65 @@ void AsyncExecutor::StopServer() {
_pslib_ptr->stop_server(); _pslib_ptr->stop_server();
} }
void AsyncExecutor::GatherServers(std::vector<uint64_t>& host_sign_list, int node_num) { void AsyncExecutor::GatherServers(
std::vector<uint64_t>& host_sign_list, int node_num) {
_pslib_ptr->gather_servers(host_sign_list.data(), node_num); _pslib_ptr->gather_servers(host_sign_list.data(), node_num);
} }
void AsyncExecutor::InitParamConfig() { void AsyncExecutor::InitParamConfig() {
for (int i = 0; i < _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param_size(); ++i) { for (int i = 0; i <
if (_pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).table_class().find("SparseTable") != -1) { _pslib_ptr->get_param()->server_param().\
_param_config.fea_dim = _pslib_ptr->get_param()->server_param().downpour_server_param().downpour_table_param(i).accessor().fea_dim(); //TODO downpour_server_param().\
downpour_table_param_size();
++i) {
if (_pslib_ptr->get_param()->server_param().\
downpour_server_param().downpour_table_param(i).\
table_class().find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()->server_param().\
downpour_server_param().\
downpour_table_param(i).\
accessor().fea_dim();
break; break;
} }
} }
_param_config.slot_dim = _param_config.fea_dim - 2; //TODO _param_config.slot_dim = _param_config.fea_dim - 2;
_param_config.tmp_push_dense_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_dense_per_batch()); _param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
_param_config.tmp_push_sparse_wait_times = (int32_t)(_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch()); _pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); ++t) { _pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
_param_config.skip_op.push_back(_pslib_ptr->get_param()->trainer_param().skip_op(t));
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
++t) {
_param_config.skip_op.push_back(
_pslib_ptr->get_param()->trainer_param().skip_op(t));
} }
//sparse
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); ++t) { for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size();
++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t); auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name; std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) { for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i)); tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] = table.table_id(); _param_config.slot_alias_to_table[table.slot_key(i)] =
table.table_id();
} }
std::vector<std::string> tmp_sparse_gradient_variable_name; std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto i = 0u; i < table.slot_gradient_size(); ++i) { for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
tmp_sparse_gradient_variable_name.push_back( tmp_sparse_gradient_variable_name.push_back(
table.slot_gradient(i)); table.slot_gradient(i));
} }
_param_config.slot_input_vec[table.table_id()] = std::move(tmp_sparse_variable_name); _param_config.slot_input_vec[table.table_id()] =
_param_config.gradient_var[table.table_id()] = std::move(tmp_sparse_gradient_variable_name); std::move(tmp_sparse_variable_name);
_param_config.gradient_var[table.table_id()] =
std::move(tmp_sparse_gradient_variable_name);
_param_config.sparse_table_id.push_back(table.table_id()); _param_config.sparse_table_id.push_back(table.table_id());
} }
//dense
for (auto t = 0u; t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); ++t) { for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().dense_table_size();
++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t); auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
std::vector<std::string> tmp_dense_variable_name; std::vector<std::string> tmp_dense_variable_name;
for (int i = 0u; i < table.dense_variable_name_size(); ++i) { for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
...@@ -134,20 +161,18 @@ void AsyncExecutor::InitParamConfig() { ...@@ -134,20 +161,18 @@ void AsyncExecutor::InitParamConfig() {
tmp_dense_gradient_variable_name.push_back( tmp_dense_gradient_variable_name.push_back(
table.dense_gradient_variable_name(i)); table.dense_gradient_variable_name(i));
} }
_param_config.dense_variable_name[table.table_id()] = std::move(tmp_dense_variable_name); _param_config.dense_variable_name[table.table_id()] =
_param_config.dense_gradient_variable_name[table.table_id()] = std::move(tmp_dense_gradient_variable_name); std::move(tmp_dense_variable_name);
_param_config.dense_gradient_variable_name[table.table_id()] =
std::move(tmp_dense_gradient_variable_name);
_param_config.dense_table_id.push_back(table.table_id()); _param_config.dense_table_id.push_back(table.table_id());
_param_config.dense_table_size.push_back(table.fea_dim()); //TODO _param_config.dense_table_size.push_back(table.fea_dim());
} }
} }
void AsyncExecutor::InitModel() { void AsyncExecutor::InitModel() {
//TODO only rank = 0 do this for (auto table_id : _param_config.dense_table_id) {
//std::vector<int> all_dense_table_id; //TODO
//all_dense_table_id.push_back(0); //done
for (auto table_id: _param_config.dense_table_id) {
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
//std::vector<std::string> variables; //TODO
for (auto& t : _param_config.dense_variable_name[table_id]) { for (auto& t : _param_config.dense_variable_name[table_id]) {
Variable* var = root_scope_->FindVar(t); Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found"; CHECK(var != nullptr) << "var[" << t << "] not found";
...@@ -169,7 +194,9 @@ void AsyncExecutor::InitModel() { ...@@ -169,7 +194,9 @@ void AsyncExecutor::InitModel() {
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(regions.data(), regions.size(), table_id); auto push_status =
_pslib_ptr->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait(); push_status.wait();
auto status = push_status.get(); auto status = push_status.get();
if (status != 0) { if (status != 0) {
...@@ -185,7 +212,7 @@ void AsyncExecutor::SaveModel(const std::string& path) { ...@@ -185,7 +212,7 @@ void AsyncExecutor::SaveModel(const std::string& path) {
ret = _pslib_ptr->_worker_ptr->save(path, 0); ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait(); ret.wait();
int32_t feasign_cnt = ret.get(); int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // TODO should be feasign_cnt < 0, because server bug if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0
LOG(FATAL) << "save model failed"; LOG(FATAL) << "save model failed";
exit(-1); exit(-1);
} }
...@@ -195,13 +222,13 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) { ...@@ -195,13 +222,13 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if (mode == "mpi") { if (mode == "mpi") {
DensePullThreadParam param; DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;; param.ps_client = _pslib_ptr->_worker_ptr;;
param.threshold = 1;//GlobalConfig::instance().pull_dense_per_batch; //TODO param.threshold = 1;
param.training_thread_num = actual_thread_num; param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_; param.root_scope = root_scope_;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
param.dense_params = &_param_config.dense_variable_name; param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param)); _pull_dense_thread = std::shared_ptr<DensePullThread>(
new DensePullThread(param));
_pull_dense_thread->start(); _pull_dense_thread->start();
} }
} }
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <time.h>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
...@@ -22,8 +23,7 @@ limitations under the License. */ ...@@ -22,8 +23,7 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include <typeinfo> #include <typeinfo>
#include <vector> #include <vector>
#include <random> //local_random_engine #include <random> // local_random_engine
#include <time.h> //local_random_engine
#include "paddle/fluid/framework/data_feed.pb.h" #include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor_thread_worker.h" #include "paddle/fluid/framework/executor_thread_worker.h"
...@@ -43,8 +43,9 @@ inline std::default_random_engine& local_random_engine() { ...@@ -43,8 +43,9 @@ inline std::default_random_engine& local_random_engine() {
struct engine_wrapper_t { struct engine_wrapper_t {
std::default_random_engine engine; std::default_random_engine engine;
engine_wrapper_t() { engine_wrapper_t() {
static std::atomic<unsigned long> x(0); static std::atomic<uint64> x(0);
std::seed_seq sseq = {x++, x++, x++, (unsigned long)(current_realtime() * 1000)}; std::seed_seq sseq = {x++, x++, x++,
static_cast<uint64>(current_realtime() * 1000)};
engine.seed(sseq); engine.seed(sseq);
} }
}; };
...@@ -63,16 +64,18 @@ class AsyncExecutor { ...@@ -63,16 +64,18 @@ class AsyncExecutor {
const std::vector<std::string>& fetch_names, const std::vector<std::string>& fetch_names,
const std::string& mode, const std::string& mode,
const bool debug = false); const bool debug = false);
//void ConfigPslib(const char* dist_desc, uint64_t* host_sign_list, int node_num, int index);
void InitServer(const std::string& dist_desc, int index); void InitServer(const std::string& dist_desc, int index);
void InitWorker(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index); void InitWorker(
//void ConfigWorker() {} const std::string& dist_desc,
const std::vector<uint64_t>& host_sign_list,
int node_num, int index);
uint64_t StartServer(); uint64_t StartServer();
void StopServer(); void StopServer();
void GatherServers(std::vector<uint64_t>& host_sign_list, int node_num); void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
void InitModel(); void InitModel();
void SaveModel(const std::string& path); void SaveModel(const std::string& path);
void InitParamConfig(); void InitParamConfig();
private: private:
void CreateThreads(ExecutorThreadWorker* worker, void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program, const ProgramDesc& main_program,
...@@ -81,6 +84,7 @@ class AsyncExecutor { ...@@ -81,6 +84,7 @@ class AsyncExecutor {
Scope* root_scope, const int thread_index, Scope* root_scope, const int thread_index,
const bool debug); const bool debug);
void PrepareDenseThread(const std::string& mode); void PrepareDenseThread(const std::string& mode);
public: public:
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr; std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread; std::shared_ptr<DensePullThread> _pull_dense_thread;
...@@ -88,6 +92,7 @@ class AsyncExecutor { ...@@ -88,6 +92,7 @@ class AsyncExecutor {
platform::Place place_; platform::Place place_;
AsyncWorkerParamConfig _param_config; AsyncWorkerParamConfig _param_config;
private: private:
int actual_thread_num; int actual_thread_num;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册