提交 3c01cdef 编写于 作者: H heqiaozhi

refine executor_thread_worker.cc & executor_thread_worker.h code style

上级 c71279bc
...@@ -303,7 +303,7 @@ void ExecutorThreadWorker::SetRootScope(Scope* g_scope) { ...@@ -303,7 +303,7 @@ void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope; root_scope_ = g_scope;
} }
//AsyncExecutor // AsyncExecutor
void AsyncExecutorThreadWorker::TrainFiles() { void AsyncExecutorThreadWorker::TrainFiles() {
SetDevice(); SetDevice();
...@@ -330,7 +330,6 @@ void AsyncExecutorThreadWorker::TrainFiles() { ...@@ -330,7 +330,6 @@ void AsyncExecutorThreadWorker::TrainFiles() {
print_fetch_var(thread_scope_, fetch_var_names_[i]); print_fetch_var(thread_scope_, fetch_var_names_[i]);
} // end for (int i = 0...) } // end for (int i = 0...)
} // end while () } // end while ()
LOG(ERROR) << "TRAIN DONE";
} }
void AsyncExecutorThreadWorker::SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) { void AsyncExecutorThreadWorker::SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {
...@@ -360,44 +359,12 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() { ...@@ -360,44 +359,12 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() {
UpdateParams(); UpdateParams();
} }
void AsyncExecutorThreadWorker::BindingSlotVariableMemory() {
/*
std::vector<int> ins_slot_offset(batch_size + 1, 0);
for (auto i = 1u; i <= batch_size; ++i) {
ins_slot_offset[i] += ins_slot_offset[i - 1] + slot_dim;
}
std::vector<int> tensor_lod(batch_size + 1, 0);
for (auto i = 1u; i <= batch_size; ++i) {
tensor_lod[i] += tensor_lod[i - 1] + 1;
}
auto& used_slots = reader->get_use_slot_alias();
slot_input_vec.resize(used_slots.size() - 1);
for (auto slot_idx = 1u; slot_idx < used_slots.size(); ++slot_idx) {
auto var = slot_input_variable_name[slot_idx];
auto v = thread_scope->FindVar(var);
CHECK(v != nullptr) << "var[" << var << "] not found";
LoDTensor* tensor = v->GetMutable<LoDTensor>();
float* tensor_ptr = tensor->mutable_data<float>({batch_size, slot_dim}, platform::CPUPlace());
memset(tensor_ptr, 0, sizeof(float) * ins_slot_offset.back());
LoD data_lod{tensor_lod};
tensor->set_lod(data_lod);
slot_input_vec[slot_idx - 1].reset(tensor);
}
*/
}
void AsyncExecutorThreadWorker::SetParamConfig(AsyncWorkerParamConfig* param_config) { void AsyncExecutorThreadWorker::SetParamConfig(AsyncWorkerParamConfig* param_config) {
_param_config = param_config; _param_config = param_config;
} }
void AsyncExecutorThreadWorker::PrepareParams() { void AsyncExecutorThreadWorker::PrepareParams() {
//int table_id = 0; //TODO
for (auto table_id: _param_config->sparse_table_id) { for (auto table_id: _param_config->sparse_table_id) {
PullSparse(table_id); PullSparse(table_id);
for (auto& t : _pull_sparse_status) { for (auto& t : _pull_sparse_status) {
...@@ -423,9 +390,7 @@ void AsyncExecutorThreadWorker::UpdateParams() { ...@@ -423,9 +390,7 @@ void AsyncExecutorThreadWorker::UpdateParams() {
for (auto i : _param_config->dense_table_id) { for (auto i : _param_config->dense_table_id) {
PushDense(i); PushDense(i);
} }
// _param_config->tmp_push_dense_wait_times
int32_t tmp_push_dense_wait_times = -1; int32_t tmp_push_dense_wait_times = -1;
// _param_config->tmp_push_sparse_wait_times
int32_t tmp_push_sparse_wait_times = -1; int32_t tmp_push_sparse_wait_times = -1;
static uint32_t push_dense_wait_times = static uint32_t push_dense_wait_times =
static_cast<uint32_t>(tmp_push_dense_wait_times); static_cast<uint32_t>(tmp_push_dense_wait_times);
...@@ -509,17 +474,15 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) { ...@@ -509,17 +474,15 @@ void AsyncExecutorThreadWorker::PullSparse(int table_id) {
pull_feature_value.data(), table_id, features.data(), features.size()); pull_feature_value.data(), table_id, features.data(), features.size());
_pull_sparse_status.push_back(std::move(status)); _pull_sparse_status.push_back(std::move(status));
//to save time
auto& push_g = _feature_push_value[table_id]; auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, push_g, fea_dim); check_pull_push_memory(features, push_g, fea_dim);
//binding_slot_embed_with_concat(); TODO collect_feasign_info(table_id);
collect_feasign_info(table_id); //TODO
} }
void AsyncExecutorThreadWorker::FillSparse(int table_id) { void AsyncExecutorThreadWorker::FillSparse(int table_id) {
auto slot_dim = _param_config->slot_dim; // TODO auto slot_dim = _param_config->slot_dim;
auto fea_dim = _param_config->fea_dim; //TODO auto fea_dim = _param_config->fea_dim;
auto& features = _features[table_id]; auto& features = _features[table_id];
auto& fea_value = _feature_value[table_id]; auto& fea_value = _feature_value[table_id];
...@@ -544,53 +507,35 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) { ...@@ -544,53 +507,35 @@ void AsyncExecutorThreadWorker::FillSparse(int table_id) {
LoD data_lod{tensor_lod}; LoD data_lod{tensor_lod};
tensor_emb->set_lod(data_lod); tensor_emb->set_lod(data_lod);
//float* ptr = tensor_emb->data<float>();
for (auto index = 0u; index < len; ++index){ for (auto index = 0u; index < len; ++index){
//if (_current_train_job.use_cvm_feature()) {
// if (ids[index] == 0u) {
// memcpy(ptr + slot_dim * index, init_value.data(), sizeof(float) * slot_dim);
// continue;
// }
// memcpy(ptr + slot_dim * index, fea_value[fea_idx].data(), sizeof(float) * slot_dim);
// (ptr + slot_dim * index)[0] = log((ptr + slot_dim * index)[0] + 1);
// (ptr + slot_dim * index)[1] = log((ptr + slot_dim * index)[1] + 1) - (ptr + slot_dim * index)[0];
// fea_idx++;
//} else {
if (ids[index] == 0u) { if (ids[index] == 0u) {
memcpy(ptr + slot_dim * index, init_value.data() + 2, sizeof(float) * slot_dim); memcpy(ptr + slot_dim * index, init_value.data() + 2, sizeof(float) * slot_dim);
continue; continue;
} }
memcpy(ptr + slot_dim * index, fea_value[fea_idx].data() + 2, sizeof(float) * slot_dim); memcpy(ptr + slot_dim * index, fea_value[fea_idx].data() + 2, sizeof(float) * slot_dim);
fea_idx++; fea_idx++;
//}
} }
} }
} }
void AsyncExecutorThreadWorker::PushSparse(int table_id) { void AsyncExecutorThreadWorker::PushSparse(int table_id) {
auto slot_dim = _param_config->slot_dim; //TODO auto slot_dim = _param_config->slot_dim;
auto fea_dim = _param_config->fea_dim;//_current_train_job.fea_dim();TODO auto fea_dim = _param_config->fea_dim;
auto& features = _features[table_id]; auto& features = _features[table_id];
CHECK(features.size() < 1000000) << "features size:" << features.size(); CHECK(features.size() < 1000000) << "features size is too big, may be wrong:" << features.size();
//std::vector<std::string> gradient_var;
//auto& gradient_var = GlobalConfig::instance().input_gradient_variable_name; //TODO
auto& push_g = _feature_push_value[table_id]; auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, push_g, fea_dim); check_pull_push_memory(features, push_g, fea_dim);
CHECK(push_g.size() == features.size() + 1) << "push_g size:" << push_g.size() << " features size:" << features.size(); CHECK(push_g.size() == features.size() + 1) << "push_g size:" << push_g.size() << " features size:" << features.size();
uint64_t fea_idx = 0u; uint64_t fea_idx = 0u;
auto& fea_info = _fea_info[table_id]; auto& fea_info = _fea_info[table_id];
int offset = 0; int offset = 2;
//if (!_current_train_job.use_cvm_feature()) { //TODO
offset = 2;
//}
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias(); const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
// slot_idx = 0 is label TODO // slot_idx = 0 is label
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) { for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
if (_param_config->slot_alias_to_table.find(feed_vec[slot_idx]) == _param_config->slot_alias_to_table.end()) { if (_param_config->slot_alias_to_table.find(feed_vec[slot_idx]) == _param_config->slot_alias_to_table.end()) {
LOG(ERROR) << "ERROR slot_idx:" << slot_idx << " name:" << feed_vec[slot_idx]; LOG(ERROR) << "ERROR slot_idx:" << slot_idx << " name:" << feed_vec[slot_idx];
} else if (_param_config->slot_alias_to_table[feed_vec[slot_idx]] != table_id) { } else if (_param_config->slot_alias_to_table[feed_vec[slot_idx]] != table_id) {
LOG(ERROR) << "ERROR continue";
continue; continue;
} }
Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[table_id][slot_idx - 1]); Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[table_id][slot_idx - 1]);
...@@ -609,7 +554,6 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) { ...@@ -609,7 +554,6 @@ void AsyncExecutorThreadWorker::PushSparse(int table_id) {
LOG(ERROR) << "var[" << feed_vec[slot_idx] << "] not found"; LOG(ERROR) << "var[" << feed_vec[slot_idx] << "] not found";
exit(-1); exit(-1);
} }
//int len = tensor->lod()[0].back();
int len = tensor->numel(); int len = tensor->numel();
CHECK(slot_dim * len == g_tensor->numel()) << "len:" << len << " g_numel:" << g_tensor->numel(); CHECK(slot_dim * len == g_tensor->numel()) << "len:" << len << " g_numel:" << g_tensor->numel();
CHECK(len == tensor->numel()) << "len:" << len << "t_numel:" << tensor->numel(); CHECK(len == tensor->numel()) << "len:" << len << "t_numel:" << tensor->numel();
......
...@@ -155,7 +155,6 @@ class ExecutorThreadWorker { ...@@ -155,7 +155,6 @@ class ExecutorThreadWorker {
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names); void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {}; virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {};
virtual void BindingSlotVariableMemory() {};
virtual void SetParamConfig(AsyncWorkerParamConfig* param_config) {}; virtual void SetParamConfig(AsyncWorkerParamConfig* param_config) {};
private: private:
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
...@@ -191,7 +190,6 @@ public: ...@@ -191,7 +190,6 @@ public:
virtual ~AsyncExecutorThreadWorker() {} virtual ~AsyncExecutorThreadWorker() {}
void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt); void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt);
void BindingSlotVariableMemory();
void SetParamConfig(AsyncWorkerParamConfig* param_config); void SetParamConfig(AsyncWorkerParamConfig* param_config);
void TrainFiles(); void TrainFiles();
void TrainOneNetwork(); void TrainOneNetwork();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册