提交 d3ca359e 编写于 作者: H heqiaozhi

config init & adapt to interface

上级 45177aa2
...@@ -67,21 +67,63 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT ...@@ -67,21 +67,63 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
void AsyncExecutor::ConfigPslib(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) { void AsyncExecutor::ConfigPslib(const std::string& dist_desc, std::vector<uint64_t>& host_sign_list, int node_num, int index) {
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib()); _pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(new paddle::distributed::PSlib());
_pslib_ptr->init_and_config(dist_desc, host_sign_list, node_num, index);//TODO _pslib_ptr->init_and_config(dist_desc, host_sign_list, node_num, index);//TODO done
} }
void AsyncExecutor::StartServer() { void AsyncExecutor::StartServer() {
InitParamConfig();
_pslib_ptr->run_server(); _pslib_ptr->run_server();
} }
void AsyncExecutor::InitParamConfig() {
  // Populate _param_config from the trainer configuration carried by PSlib.
  // Must run after ConfigPslib() has initialized _pslib_ptr (it is invoked
  // from StartServer()).  No parameters; fills the member _param_config.
  //
  // Hoist the loop-invariant accessor chain instead of re-evaluating
  // _pslib_ptr->get_param()->trainer_param() on every use.
  const auto& trainer_param = _pslib_ptr->get_param()->trainer_param();

  _param_config.fea_dim = trainer_param.sparse_table(0).feature_dim();  // TODO: per-table feature_dim
  // NOTE(review): assumes the first two columns of a feature are not slots
  // (e.g. show/click) — confirm against the table layout.
  _param_config.slot_dim = _param_config.fea_dim - 2;  // TODO
  // NOTE(review): push_dense wait is read from pull_dense_per_batch and
  // push_sparse wait from push_dense_per_batch — the pairing looks swapped;
  // confirm against the proto definition before relying on it.
  _param_config.tmp_push_dense_wait_times =
      static_cast<int32_t>(trainer_param.pull_dense_per_batch());
  _param_config.tmp_push_sparse_wait_times =
      static_cast<int32_t>(trainer_param.push_dense_per_batch());

  // Sparse tables: per-table input/gradient variable names, the table id
  // list, and the slot-alias -> table-id mapping used when pushing gradients.
  for (auto t = 0u; t < trainer_param.sparse_table_size(); ++t) {
    const auto& table = trainer_param.sparse_table(t);
    std::vector<std::string> tmp_sparse_variable_name;
    for (auto i = 0u; i < table.slot_value_size(); ++i) {
      tmp_sparse_variable_name.push_back(table.slot_value(i));
      _param_config.slot_alias_to_table[table.slot_value(i)] = table.table_id();
    }
    std::vector<std::string> tmp_sparse_gradient_variable_name;
    for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
      tmp_sparse_gradient_variable_name.push_back(table.slot_gradient(i));
    }
    _param_config.slot_input_vec[table.table_id()] =
        std::move(tmp_sparse_variable_name);
    _param_config.gradient_var[table.table_id()] =
        std::move(tmp_sparse_gradient_variable_name);
    _param_config.sparse_table_id.push_back(table.table_id());
  }

  // Dense tables: per-table parameter/gradient variable names, ids, and the
  // per-table size used when pulling/pushing dense parameters.
  for (auto t = 0u; t < trainer_param.dense_table_size(); ++t) {
    const auto& table = trainer_param.dense_table(t);
    std::vector<std::string> tmp_dense_variable_name;
    for (auto i = 0u; i < table.dense_variable_name_size(); ++i) {
      tmp_dense_variable_name.push_back(table.dense_variable_name(i));
    }
    std::vector<std::string> tmp_dense_gradient_variable_name;
    for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
      tmp_dense_gradient_variable_name.push_back(
          table.dense_gradient_variable_name(i));
    }
    _param_config.dense_variable_name[table.table_id()] =
        std::move(tmp_dense_variable_name);
    _param_config.dense_gradient_variable_name[table.table_id()] =
        std::move(tmp_dense_gradient_variable_name);
    _param_config.dense_table_id.push_back(table.table_id());
    // NOTE(review): stores fea_dim as the dense table "size" — confirm these
    // are the same quantity for dense tables.
    _param_config.dense_table_size.push_back(table.fea_dim());  // TODO
  }
}
void AsyncExecutor::InitModel() { void AsyncExecutor::InitModel() {
//TODO only rank = 0 do this //TODO only rank = 0 do this
std::vector<int> all_dense_table_id; //TODO //std::vector<int> all_dense_table_id; //TODO
all_dense_table_id.push_back(0); //all_dense_table_id.push_back(0); //done
for (auto table_id: all_dense_table_id) { for (auto table_id: _param_config.dense_table_id) {
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
std::vector<std::string> variables; //TODO //std::vector<std::string> variables; //TODO
for (auto& t : variables) { for (auto& t : _param_config.dense_variable_name[table_id]) {
Variable* var = root_scope_->FindVar(t); Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found"; CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
...@@ -131,6 +173,7 @@ void AsyncExecutor::PrepareDenseThread() { ...@@ -131,6 +173,7 @@ void AsyncExecutor::PrepareDenseThread() {
param.training_thread_num = actual_thread_num; param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_; param.root_scope = root_scope_;
//param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO //param.dense_params = &GlobalConfig::instance().dense_variable_name; //TODO
param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param)); _pull_dense_thread = std::shared_ptr<DensePullThread>(new DensePullThread(param));
......
...@@ -68,7 +68,7 @@ class AsyncExecutor { ...@@ -68,7 +68,7 @@ class AsyncExecutor {
void StartServer(); void StartServer();
void InitModel(); void InitModel();
void SaveModel(const std::string& path); void SaveModel(const std::string& path);
void InitParamConfig();
private: private:
void CreateThreads(ExecutorThreadWorker* worker, void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program, const ProgramDesc& main_program,
...@@ -86,6 +86,7 @@ class AsyncExecutor { ...@@ -86,6 +86,7 @@ class AsyncExecutor {
AsyncWorkerParamConfig _param_config; AsyncWorkerParamConfig _param_config;
private: private:
int actual_thread_num; int actual_thread_num;
}; };
......
...@@ -382,33 +382,38 @@ void AsyncExecutorThreadWorker::BindingSlotVariableMemory() { ...@@ -382,33 +382,38 @@ void AsyncExecutorThreadWorker::BindingSlotVariableMemory() {
} }
*/ */
} }
void AsyncExecutorThreadWorker::SetParamConfig(AsyncWorkerParamConfig* pc) {
_param_config = pc; void AsyncExecutorThreadWorker::SetParamConfig(AsyncWorkerParamConfig* param_config) {
_param_config = param_config;
} }
void AsyncExecutorThreadWorker::PrepareParams() { void AsyncExecutorThreadWorker::PrepareParams() {
int table_id = 0; //TODO //int table_id = 0; //TODO
PullSparse(table_id); for (auto table_id: _param_config->sparse_table_id) {
for (auto& t : _pull_sparse_status) { PullSparse(table_id);
t.wait(); for (auto& t : _pull_sparse_status) {
auto status = t.get(); t.wait();
if (status != 0) { auto status = t.get();
LOG(ERROR) << "pull sparse failed, status[" << status << "]"; if (status != 0) {
exit(-1); LOG(ERROR) << "pull sparse failed, status[" << status << "]";
exit(-1);
}
} }
} }
_pull_sparse_status.resize(0); _pull_sparse_status.resize(0);
FillSparse(table_id); for (auto table_id: _param_config->sparse_table_id) {
FillSparse(table_id);
}
} }
void AsyncExecutorThreadWorker::UpdateParams() { void AsyncExecutorThreadWorker::UpdateParams() {
//for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {//TODO for (auto i: _param_config->sparse_table_id) {//TODO
for (int i = 0; i < 1; ++i) { //for (int i = 0; i < 1; ++i) {
PushSparse(i); PushSparse(i);
} }
//for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {//TODO //for (auto i = 0u; i < GlobalConfig::instance().dense_table_id.size(); ++i) {//TODO
for (int i = 1; i < 2; ++i) { for (auto i: _param_config->dense_table_id) {
PushDense(i); PushDense(i);
} }
int32_t tmp_push_dense_wait_times = _param_config->tmp_push_dense_wait_times; //TODO int32_t tmp_push_dense_wait_times = _param_config->tmp_push_dense_wait_times; //TODO
...@@ -437,14 +442,13 @@ void AsyncExecutorThreadWorker::UpdateParams() { ...@@ -437,14 +442,13 @@ void AsyncExecutorThreadWorker::UpdateParams() {
} }
//for (auto dense_table_id : GlobalConfig::instance().dense_table_id) {//TODO //for (auto dense_table_id : GlobalConfig::instance().dense_table_id) {//TODO
int dense_table_id = 1; for (auto dense_table_id: _param_config->dense_table_id) {
_pull_dense_thread->increase_thread_version(thread_id_, dense_table_id); _pull_dense_thread->increase_thread_version(thread_id_, dense_table_id);
}
//} //}
} }
void AsyncExecutorThreadWorker::PushDense(int table_id) { void AsyncExecutorThreadWorker::PushDense(int table_id) {
//auto table_id = GlobalConfig::instance().dense_table_id[table_id_index]; TODO
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
//auto& variables = GlobalConfig::instance().dense_gradient_variable_name[table_id]; //auto& variables = GlobalConfig::instance().dense_gradient_variable_name[table_id];
std::vector<std::string> variables; std::vector<std::string> variables;
...@@ -529,7 +533,7 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) { ...@@ -529,7 +533,7 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) {
int64_t* ids = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
int len = tensor->numel(); int len = tensor->numel();
Variable* var_emb = thread_scope_->FindVar(_param_config->slot_input_vec[slot_idx - 1]); Variable* var_emb = thread_scope_->FindVar(_param_config->slot_input_vec[table_id][slot_idx - 1]);
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>(); LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
float* ptr = tensor_emb->data<float>(); float* ptr = tensor_emb->data<float>();
...@@ -575,10 +579,10 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) { ...@@ -575,10 +579,10 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) {
// slot_idx = 0 is label TODO // slot_idx = 0 is label TODO
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) { for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
if (_slot_alias_to_table[feed_vec[slot_idx]] != table_id) { if (_param_config->slot_alias_to_table[feed_vec[slot_idx]] != table_id) {
continue; continue;
} }
Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[slot_idx - 1]); Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[table_id][slot_idx - 1]);
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>(); LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
//int count = g_tensor->numel(); //int count = g_tensor->numel();
float* g = g_tensor->data<float>(); float* g = g_tensor->data<float>();
......
...@@ -40,8 +40,14 @@ struct AsyncWorkerParamConfig { ...@@ -40,8 +40,14 @@ struct AsyncWorkerParamConfig {
int32_t tmp_push_dense_wait_times; int32_t tmp_push_dense_wait_times;
int32_t tmp_push_sparse_wait_times; int32_t tmp_push_sparse_wait_times;
std::vector<std::string> slot_input_vec; //6048slot 6050slot //name std::map<uint64_t, std::vector<std::string>> dense_variable_name;
std::vector<std::string> gradient_var; //6048slot_embed std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name;
std::vector<int> dense_table_id;
std::vector<uint32_t> dense_table_size; // fea_dim for each dense table
std::vector<int> sparse_table_id;
std::map<uint64_t, std::vector<std::string>> slot_input_vec; //6048slot 6050slot //name
std::map<uint64_t, std::vector<std::string>> gradient_var; //6048slot_embed
std::unordered_map<std::string, uint64_t> slot_alias_to_table; //TODO done
}; };
struct DensePullThreadParam { struct DensePullThreadParam {
...@@ -148,7 +154,7 @@ class ExecutorThreadWorker { ...@@ -148,7 +154,7 @@ class ExecutorThreadWorker {
virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {}; virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {};
virtual void BindingSlotVariableMemory() {}; virtual void BindingSlotVariableMemory() {};
virtual void SetParamConfig(AsyncWorkerParamConfig* pc) {}; virtual void SetParamConfig(AsyncWorkerParamConfig* param_config) {};
private: private:
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
void CreateThreadOperators(const framework::ProgramDesc& program); void CreateThreadOperators(const framework::ProgramDesc& program);
...@@ -184,7 +190,7 @@ public: ...@@ -184,7 +190,7 @@ public:
void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt); void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt);
void BindingSlotVariableMemory(); void BindingSlotVariableMemory();
void SetParamConfig(AsyncWorkerParamConfig* pc); void SetParamConfig(AsyncWorkerParamConfig* param_config);
void TrainFiles(); void TrainFiles();
void TrainOneNetwork(); void TrainOneNetwork();
void PrepareParams(); void PrepareParams();
...@@ -209,7 +215,6 @@ private: ...@@ -209,7 +215,6 @@ private:
std::map<uint64_t, std::vector<std::vector<float>>> _feature_value; std::map<uint64_t, std::vector<std::vector<float>>> _feature_value;
std::map<uint64_t, std::vector<std::vector<float>>> _feature_push_value; std::map<uint64_t, std::vector<std::vector<float>>> _feature_push_value;
std::unordered_map<std::string, uint64_t> _slot_alias_to_table; //TODO
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr; std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册