提交 c59cdf3a 编写于 作者: D dongdaxiang

refine executor_thread_worker.h and executor_thread_worker.cc code style

上级 c4cb4142
...@@ -58,7 +58,8 @@ bool DensePullThread::check_update_param(uint64_t table_id) { ...@@ -58,7 +58,8 @@ bool DensePullThread::check_update_param(uint64_t table_id) {
{ {
std::lock_guard<std::mutex> lock(_mutex_for_version); std::lock_guard<std::mutex> lock(_mutex_for_version);
auto& version = _training_versions[table_id]; auto& version = _training_versions[table_id];
_current_version[table_id] = *(std::min_element(version.begin(), version.end())); _current_version[table_id] =
*(std::min_element(version.begin(), version.end()));
} }
if (_current_version[table_id] - _last_versions[table_id] < _threshold) { if (_current_version[table_id] - _last_versions[table_id] < _threshold) {
return false; return false;
...@@ -93,7 +94,8 @@ void DensePullThread::wait_all() { ...@@ -93,7 +94,8 @@ void DensePullThread::wait_all() {
t.wait(); t.wait();
auto status = t.get(); auto status = t.get();
if (status != 0) { if (status != 0) {
LOG(WARNING) << "pull dense failed times:" << ++_pull_dense_fail_times; LOG(WARNING) << "pull dense failed times:" <<
++_pull_dense_fail_times;
} }
} }
...@@ -105,7 +107,8 @@ void DensePullThread::wait_all() { ...@@ -105,7 +107,8 @@ void DensePullThread::wait_all() {
_pull_dense_status.resize(0); _pull_dense_status.resize(0);
} }
void DensePullThread::increase_thread_version(int thread_id, uint64_t table_id) { void DensePullThread::increase_thread_version(
int thread_id, uint64_t table_id) {
std::lock_guard<std::mutex> lock(_mutex_for_version); std::lock_guard<std::mutex> lock(_mutex_for_version);
_training_versions[table_id][thread_id]++; _training_versions[table_id][thread_id]++;
} }
...@@ -169,10 +172,6 @@ void ExecutorThreadWorker::SetFetchVarNames( ...@@ -169,10 +172,6 @@ void ExecutorThreadWorker::SetFetchVarNames(
fetch_var_names.end()); fetch_var_names.end());
} }
void ExecutorThreadWorker::SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {
}
void ExecutorThreadWorker::SetDevice() { void ExecutorThreadWorker::SetDevice() {
#if defined _WIN32 || defined __APPLE__ #if defined _WIN32 || defined __APPLE__
...@@ -332,10 +331,12 @@ void AsyncExecutorThreadWorker::TrainFiles() { ...@@ -332,10 +331,12 @@ void AsyncExecutorThreadWorker::TrainFiles() {
} // end while () } // end while ()
} }
void AsyncExecutorThreadWorker::SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) { void AsyncExecutorThreadWorker::SetPSlibPtr(
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {
_pslib_ptr = pslib_ptr; _pslib_ptr = pslib_ptr;
} }
void AsyncExecutorThreadWorker::SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) { void AsyncExecutorThreadWorker::SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {
_pull_dense_thread = dpt; _pull_dense_thread = dpt;
} }
void AsyncExecutorThreadWorker::TrainOneNetwork() { void AsyncExecutorThreadWorker::TrainOneNetwork() {
...@@ -347,7 +348,8 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() { ...@@ -347,7 +348,8 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() {
} }
bool need_skip = false; bool need_skip = false;
for (auto t = 0u; t < _param_config->skip_op.size(); ++t) { for (auto t = 0u; t < _param_config->skip_op.size(); ++t) {
if (op->Type().find(_param_config->skip_op[t]) != std::string::npos) { if (op->Type().find(_param_config->skip_op[t]) !=
std::string::npos) {
need_skip = true; need_skip = true;
break; break;
} }
...@@ -359,13 +361,13 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() { ...@@ -359,13 +361,13 @@ void AsyncExecutorThreadWorker::TrainOneNetwork() {
UpdateParams(); UpdateParams();
} }
void AsyncExecutorThreadWorker::SetParamConfig(
void AsyncExecutorThreadWorker::SetParamConfig(AsyncWorkerParamConfig* param_config) { AsyncWorkerParamConfig* param_config) {
_param_config = param_config; _param_config = param_config;
} }
void AsyncExecutorThreadWorker::PrepareParams() { void AsyncExecutorThreadWorker::PrepareParams() {
for (auto table_id: _param_config->sparse_table_id) { for (auto table_id : _param_config->sparse_table_id) {
PullSparse(table_id); PullSparse(table_id);
for (auto& t : _pull_sparse_status) { for (auto& t : _pull_sparse_status) {
t.wait(); t.wait();
...@@ -378,7 +380,7 @@ void AsyncExecutorThreadWorker::PrepareParams() { ...@@ -378,7 +380,7 @@ void AsyncExecutorThreadWorker::PrepareParams() {
} }
_pull_sparse_status.resize(0); _pull_sparse_status.resize(0);
for (auto table_id: _param_config->sparse_table_id) { for (auto table_id : _param_config->sparse_table_id) {
FillSparse(table_id); FillSparse(table_id);
} }
} }
...@@ -440,180 +442,198 @@ void AsyncExecutorThreadWorker::PushDense(int table_id) { ...@@ -440,180 +442,198 @@ void AsyncExecutorThreadWorker::PushDense(int table_id) {
void AsyncExecutorThreadWorker::PullSparse(int table_id) { void AsyncExecutorThreadWorker::PullSparse(int table_id) {
auto& features = _features[table_id]; auto& features = _features[table_id];
auto& feature_value = _feature_value[table_id]; auto& feature_value = _feature_value[table_id];
auto fea_dim = _param_config->fea_dim; auto fea_dim = _param_config->fea_dim;
// slot id starts from 1 // slot id starts from 1
features.clear(); features.clear();
features.resize(0); features.resize(0);
features.reserve(MAX_FEASIGN_NUM); features.reserve(MAX_FEASIGN_NUM);
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias(); const std::vector<std::string>& feed_vec =
// slot_idx = 0 is label TODO thread_reader_->GetUseSlotAlias();
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) { // slot_idx = 0 is label TODO
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]); for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
LoDTensor* tensor = var->GetMutable<LoDTensor>(); Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
int64_t* ids = tensor->data<int64_t>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
int len = tensor->numel(); int64_t* ids = tensor->data<int64_t>();
for (auto i = 0u; i < len; ++i) { int len = tensor->numel();
//todo: current trick - filter feasign=use_slot_mod(bug: datafeed fill use_slot_mod for empty slot) for (auto i = 0u; i < len; ++i) {
if (ids[i] == 0u) { // todo(colourful-tree): current trick - filter feasign=use_slot_mod(
continue; // bug: datafeed fill use_slot_mod for empty slot)
} if (ids[i] == 0u) {
features.push_back(static_cast<uint64_t>(ids[i])); continue;
} }
} features.push_back(static_cast<uint64_t>(ids[i]));
check_pull_push_memory(features, feature_value, fea_dim);
std::vector<float*> pull_feature_value;
for (auto i = 0u; i < features.size(); ++i) {
pull_feature_value.push_back(feature_value[i].data());
}
for (int i = 0; i < features.size(); ++i) {
} }
auto status = _pslib_ptr->_worker_ptr->pull_sparse( }
pull_feature_value.data(), table_id, features.data(), features.size()); check_pull_push_memory(features, feature_value, fea_dim);
_pull_sparse_status.push_back(std::move(status));
std::vector<float*> pull_feature_value;
auto& push_g = _feature_push_value[table_id]; for (auto i = 0u; i < features.size(); ++i) {
check_pull_push_memory(features, push_g, fea_dim); pull_feature_value.push_back(feature_value[i].data());
}
collect_feasign_info(table_id);
auto status = _pslib_ptr->_worker_ptr->pull_sparse(
pull_feature_value.data(), table_id, features.data(), features.size());
_pull_sparse_status.push_back(std::move(status));
auto& push_g = _feature_push_value[table_id];
check_pull_push_memory(features, push_g, fea_dim);
collect_feasign_info(table_id);
} }
void AsyncExecutorThreadWorker::FillSparse(int table_id) { void AsyncExecutorThreadWorker::FillSparse(int table_id) {
auto slot_dim = _param_config->slot_dim; auto slot_dim = _param_config->slot_dim;
auto fea_dim = _param_config->fea_dim; auto fea_dim = _param_config->fea_dim;
auto& features = _features[table_id]; auto& features = _features[table_id];
auto& fea_value = _feature_value[table_id]; auto& fea_value = _feature_value[table_id];
CHECK(features.size() > 0) << "feature size check failed"; CHECK(features.size() > 0) << "feature size check failed";
auto fea_idx = 0u; auto fea_idx = 0u;
std::vector<float> init_value(fea_dim); std::vector<float> init_value(fea_dim);
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias(); const std::vector<std::string>& feed_vec =
// slot_idx = 0 is label TODO thread_reader_->GetUseSlotAlias();
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) { // slot_idx = 0 is label TODO
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]); for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
LoDTensor* tensor = var->GetMutable<LoDTensor>(); Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
int64_t* ids = tensor->data<int64_t>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
int len = tensor->numel(); int64_t* ids = tensor->data<int64_t>();
Variable* var_emb = thread_scope_->FindVar(_param_config->slot_input_vec[table_id][slot_idx - 1]); int len = tensor->numel();
LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>(); Variable* var_emb = thread_scope_->FindVar(
float* ptr = tensor_emb->mutable_data<float>({len, slot_dim}, platform::CPUPlace()); _param_config->slot_input_vec[table_id][slot_idx - 1]);
memset(ptr, 0, sizeof(float) * len * slot_dim); LoDTensor* tensor_emb = var_emb->GetMutable<LoDTensor>();
auto& tensor_lod = tensor->lod()[0]; float* ptr = tensor_emb->mutable_data<float>(
{len, slot_dim}, platform::CPUPlace());
LoD data_lod{tensor_lod}; memset(ptr, 0, sizeof(float) * len * slot_dim);
tensor_emb->set_lod(data_lod); auto& tensor_lod = tensor->lod()[0];
for (auto index = 0u; index < len; ++index){ LoD data_lod{tensor_lod};
if (ids[index] == 0u) { tensor_emb->set_lod(data_lod);
memcpy(ptr + slot_dim * index, init_value.data() + 2, sizeof(float) * slot_dim);
continue; for (auto index = 0u; index < len; ++index) {
} if (ids[index] == 0u) {
memcpy(ptr + slot_dim * index, fea_value[fea_idx].data() + 2, sizeof(float) * slot_dim); memcpy(ptr + slot_dim * index,
fea_idx++; init_value.data() + 2, sizeof(float) * slot_dim);
} continue;
}
memcpy(ptr + slot_dim * index,
fea_value[fea_idx].data() + 2, sizeof(float) * slot_dim);
fea_idx++;
} }
}
} }
void AsyncExecutorThreadWorker::PushSparse(int table_id) { void AsyncExecutorThreadWorker::PushSparse(int table_id) {
auto slot_dim = _param_config->slot_dim; auto slot_dim = _param_config->slot_dim;
auto fea_dim = _param_config->fea_dim; auto fea_dim = _param_config->fea_dim;
auto& features = _features[table_id]; auto& features = _features[table_id];
CHECK(features.size() < 1000000) << "features size is too big, may be wrong:" << features.size(); auto& push_g = _feature_push_value[table_id];
auto& push_g = _feature_push_value[table_id]; check_pull_push_memory(features, push_g, fea_dim);
check_pull_push_memory(features, push_g, fea_dim); CHECK(push_g.size() == features.size() + 1) <<
CHECK(push_g.size() == features.size() + 1) << "push_g size:" << push_g.size() << " features size:" << features.size(); "push_g size:" << push_g.size() << " features size:" << features.size();
uint64_t fea_idx = 0u; uint64_t fea_idx = 0u;
auto& fea_info = _fea_info[table_id]; auto& fea_info = _fea_info[table_id];
int offset = 2; int offset = 2;
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias(); const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
// slot_idx = 0 is label // slot_idx = 0 is label
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) { for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
if (_param_config->slot_alias_to_table.find(feed_vec[slot_idx]) == _param_config->slot_alias_to_table.end()) { if (_param_config->slot_alias_to_table.find(
LOG(ERROR) << "ERROR slot_idx:" << slot_idx << " name:" << feed_vec[slot_idx]; feed_vec[slot_idx]) == _param_config->slot_alias_to_table.end()) {
} else if (_param_config->slot_alias_to_table[feed_vec[slot_idx]] != table_id) { LOG(ERROR) << "ERROR slot_idx:" << slot_idx <<
continue; " name:" << feed_vec[slot_idx];
} } else if (
Variable* g_var = thread_scope_->FindVar(_param_config->gradient_var[table_id][slot_idx - 1]); _param_config->slot_alias_to_table[feed_vec[slot_idx]] != table_id) {
CHECK(g_var != nullptr) << "var[" << _param_config->gradient_var[table_id][slot_idx - 1] << "] not found"; continue;
LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
if (g_tensor == NULL) {
LOG(ERROR) << "var[" << _param_config->gradient_var[table_id][slot_idx - 1] << "] not found";
exit(-1);
}
float* g = g_tensor->data<float>();
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
CHECK(var != nullptr) << "var[" << feed_vec[slot_idx] << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == NULL) {
LOG(ERROR) << "var[" << feed_vec[slot_idx] << "] not found";
exit(-1);
}
int len = tensor->numel();
CHECK(slot_dim * len == g_tensor->numel()) << "len:" << len << " g_numel:" << g_tensor->numel();
CHECK(len == tensor->numel()) << "len:" << len << "t_numel:" << tensor->numel();
int64_t* ids = tensor->data<int64_t>();
for (auto id_idx = 0u; id_idx < len; ++id_idx){
if (ids[id_idx] == 0) {
g += slot_dim;
continue;
}
memcpy(push_g[fea_idx].data() + offset, g, sizeof(float) * slot_dim);
push_g[fea_idx][0] = 1.0f;
CHECK(fea_idx < fea_info.size()) << "fea_idx:" << fea_idx << " size:" << fea_info.size();
push_g[fea_idx][1] = static_cast<float>(fea_info[fea_idx].label);
g += slot_dim;
fea_idx++;
}
} }
CHECK(fea_idx == features.size()) << "fea_idx:" << fea_idx << " features size:" << features.size(); Variable* g_var = thread_scope_->FindVar(
CHECK(features.size() > 0); _param_config->gradient_var[table_id][slot_idx - 1]);
CHECK(g_var != nullptr) << "var[" <<
std::vector<float*> push_g_vec; _param_config->gradient_var[table_id][slot_idx - 1] << "] not found";
for (auto i = 0u; i < features.size(); ++i) { LoDTensor* g_tensor = g_var->GetMutable<LoDTensor>();
push_g_vec.push_back(push_g[i].data()); if (g_tensor == NULL) {
LOG(ERROR) << "var[" <<
_param_config->gradient_var[table_id][slot_idx - 1] << "] not found";
exit(-1);
}
float* g = g_tensor->data<float>();
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
CHECK(var != nullptr) << "var[" << feed_vec[slot_idx] << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
if (tensor == NULL) {
LOG(ERROR) << "var[" << feed_vec[slot_idx] << "] not found";
exit(-1);
}
int len = tensor->numel();
CHECK(slot_dim * len == g_tensor->numel()) <<
"len:" << len << " g_numel:" << g_tensor->numel();
CHECK(len == tensor->numel()) << "len:" <<
len << "t_numel:" << tensor->numel();
int64_t* ids = tensor->data<int64_t>();
for (auto id_idx = 0u; id_idx < len; ++id_idx) {
if (ids[id_idx] == 0) {
g += slot_dim;
continue;
}
memcpy(push_g[fea_idx].data() + offset,
g, sizeof(float) * slot_dim);
push_g[fea_idx][0] = 1.0f;
CHECK(fea_idx < fea_info.size()) << "fea_idx:" <<
fea_idx << " size:" << fea_info.size();
push_g[fea_idx][1] = static_cast<float>(fea_info[fea_idx].label);
g += slot_dim;
fea_idx++;
} }
auto status = _pslib_ptr->_worker_ptr->push_sparse( }
table_id, features.data(), (const float**)push_g_vec.data(), features.size()); CHECK(fea_idx == features.size()) << "fea_idx:" <<
_push_sparse_status.push_back(std::move(status)); fea_idx << " features size:" << features.size();
CHECK_GT(features.size(), 0);
std::vector<float*> push_g_vec;
for (auto i = 0u; i < features.size(); ++i) {
push_g_vec.push_back(push_g[i].data());
}
auto status = _pslib_ptr->_worker_ptr->push_sparse(
table_id, features.data(),
(const float**)push_g_vec.data(), features.size());
_push_sparse_status.push_back(std::move(status));
} }
void AsyncExecutorThreadWorker::collect_feasign_info( void AsyncExecutorThreadWorker::collect_feasign_info(
int table_id) { int table_id) {
auto& fea_info = _fea_info[table_id]; auto& fea_info = _fea_info[table_id];
auto& feature = _features[table_id]; auto& feature = _features[table_id];
fea_info.resize(feature.size()); fea_info.resize(feature.size());
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias();
const std::vector<std::string>& feed_vec = thread_reader_->GetUseSlotAlias(); Variable* var = thread_scope_->FindVar(feed_vec[0]);
Variable* var = thread_scope_->FindVar(feed_vec[0]); LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label = tensor->data<int64_t>();
int global_index = 0;
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) {
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]);
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
int64_t* label = tensor->data<int64_t>(); int64_t* ids = tensor->data<int64_t>();
int global_index = 0; int fea_idx = 0;
for (auto slot_idx = 1u; slot_idx < feed_vec.size(); ++slot_idx) { for (auto ins_idx = 1u; ins_idx < tensor->lod()[0].size(); ++ins_idx) {
Variable* var = thread_scope_->FindVar(feed_vec[slot_idx]); for (; fea_idx < tensor->lod()[0][ins_idx]; ++fea_idx) {
LoDTensor* tensor = var->GetMutable<LoDTensor>(); if (ids[fea_idx] == 0u) {
int64_t* ids = tensor->data<int64_t>(); continue;
int fea_idx = 0;
for (auto ins_idx = 1u; ins_idx < tensor->lod()[0].size(); ++ins_idx) {
for (; fea_idx < tensor->lod()[0][ins_idx]; ++fea_idx) {
if (ids[fea_idx] == 0u) {
continue;
}
FeasignInfo info{slot_idx, ins_idx, label[ins_idx - 1]};
fea_info[global_index++] = std::move(info);
}
} }
FeasignInfo info{slot_idx, ins_idx, label[ins_idx - 1]};
fea_info[global_index++] = std::move(info);
}
} }
CHECK(global_index == feature.size()) << "expect fea info size:" << feature.size() }
<< " real:" << global_index; CHECK(global_index == feature.size()) <<
"expect fea info size:" << feature.size()
<< " real:" << global_index;
} }
void AsyncExecutorThreadWorker::check_pull_push_memory( void AsyncExecutorThreadWorker::check_pull_push_memory(
......
...@@ -35,21 +35,22 @@ const static uint32_t MAX_FEASIGN_NUM = 1000 * 100 * 100; ...@@ -35,21 +35,22 @@ const static uint32_t MAX_FEASIGN_NUM = 1000 * 100 * 100;
void CreateTensor(Variable* var, proto::VarType::Type var_type); void CreateTensor(Variable* var, proto::VarType::Type var_type);
struct AsyncWorkerParamConfig { struct AsyncWorkerParamConfig {
int slot_dim; int slot_dim;
int fea_dim; int fea_dim;
int32_t tmp_push_dense_wait_times; int32_t tmp_push_dense_wait_times;
int32_t tmp_push_sparse_wait_times; int32_t tmp_push_sparse_wait_times;
std::vector<std::string> skip_op; std::vector<std::string> skip_op;
std::map<uint64_t, std::vector<std::string>> dense_variable_name; std::map<uint64_t, std::vector<std::string>> dense_variable_name;
std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name; std::map<uint64_t, std::vector<std::string>> dense_gradient_variable_name;
std::vector<int> dense_table_id; std::vector<int> dense_table_id;
std::vector<uint32_t> dense_table_size; // fea_dim for each dense table // fea_dim for each dense table
std::vector<int> sparse_table_id; std::vector<uint32_t> dense_table_size;
std::map<uint64_t, std::vector<std::string>> slot_input_vec; //6048slot 6050slot //name std::vector<int> sparse_table_id;
std::map<uint64_t, std::vector<std::string>> gradient_var; //6048slot_embed std::map<uint64_t, std::vector<std::string>> slot_input_vec;
std::map<std::string, uint64_t> slot_alias_to_table; //TODO done std::map<uint64_t, std::vector<std::string>> gradient_var;
std::map<std::string, uint64_t> slot_alias_to_table;
}; };
struct DensePullThreadParam { struct DensePullThreadParam {
...@@ -62,8 +63,8 @@ struct DensePullThreadParam { ...@@ -62,8 +63,8 @@ struct DensePullThreadParam {
}; };
class DensePullThread { class DensePullThread {
public: public:
DensePullThread(DensePullThreadParam& param) : explicit DensePullThread(const DensePullThreadParam& param) :
_running(false) { _running(false) {
_ps_client = param.ps_client; _ps_client = param.ps_client;
_threshold = param.threshold; _threshold = param.threshold;
...@@ -96,11 +97,11 @@ public: ...@@ -96,11 +97,11 @@ public:
void pull_dense2(uint64_t table_id); void pull_dense2(uint64_t table_id);
void wait_all(); void wait_all();
private: private:
void run(); void run();
bool check_update_param(uint64_t table_id); bool check_update_param(uint64_t table_id);
private: private:
std::shared_ptr<paddle::ps::PSClient> _ps_client; std::shared_ptr<paddle::ps::PSClient> _ps_client;
int _thread_num; int _thread_num;
int _threshold; int _threshold;
...@@ -153,9 +154,13 @@ class ExecutorThreadWorker { ...@@ -153,9 +154,13 @@ class ExecutorThreadWorker {
virtual void TrainFiles(); virtual void TrainFiles();
// set fetch variable names from python interface assigned by users // set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names); void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
virtual void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); virtual void SetPSlibPtr(
virtual void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt) {}; std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
virtual void SetParamConfig(AsyncWorkerParamConfig* param_config) {}; virtual void SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {}
virtual void SetParamConfig(
AsyncWorkerParamConfig * param_config) {}
private: private:
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
void CreateThreadOperators(const framework::ProgramDesc& program); void CreateThreadOperators(const framework::ProgramDesc& program);
...@@ -178,32 +183,37 @@ class ExecutorThreadWorker { ...@@ -178,32 +183,37 @@ class ExecutorThreadWorker {
Scope* root_scope_; Scope* root_scope_;
// a thread scope, father scope is global score which is shared // a thread scope, father scope is global score which is shared
Scope* thread_scope_; Scope* thread_scope_;
//private:
std::vector<std::string> fetch_var_names_; std::vector<std::string> fetch_var_names_;
std::vector<std::vector<float>> fetch_values_; std::vector<std::vector<float>> fetch_values_;
bool debug_; bool debug_;
}; };
class AsyncExecutorThreadWorker: public ExecutorThreadWorker { class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
public: public:
AsyncExecutorThreadWorker(){}; AsyncExecutorThreadWorker() {}
virtual ~AsyncExecutorThreadWorker() {} virtual ~AsyncExecutorThreadWorker() {}
void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); void SetPSlibPtr(std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt); void SetPullDenseThread(std::shared_ptr<DensePullThread> dpt);
void SetParamConfig(AsyncWorkerParamConfig* param_config); void SetParamConfig(AsyncWorkerParamConfig* param_config);
void TrainFiles(); void TrainFiles();
void TrainOneNetwork(); void TrainOneNetwork();
void PrepareParams(); void PrepareParams();
void UpdateParams(); void UpdateParams();
void PullSparse(int table_id); void PullSparse(int table_id);
void FillSparse(int table_id); void FillSparse(int table_id);
void PushSparse(int table_id); void PushSparse(int table_id);
void PushDense(int table_id); void PushDense(int table_id);
void check_pull_push_memory(std::vector<uint64_t>& features, std::vector<float*>& push_g, int dim); void check_pull_push_memory(
void check_pull_push_memory(std::vector<uint64_t>& features, std::vector<std::vector<float>>& push_g, int dim); const std::vector<uint64_t>& features,
std::vector<float*>& push_g,
int dim);
void check_pull_push_memory(const std::vector<uint64_t>& features,
std::vector<std::vector<float>>& push_g,
int dim);
void collect_feasign_info(int table_id); void collect_feasign_info(int table_id);
private:
private:
struct FeasignInfo { struct FeasignInfo {
uint32_t slot; uint32_t slot;
uint32_t ins; uint32_t ins;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册