提交 c9b79989 编写于 作者: D dongdaxiang

fix tag in async_executor

上级 95b887c4
...@@ -102,139 +102,139 @@ void AsyncExecutor::GatherServers( ...@@ -102,139 +102,139 @@ void AsyncExecutor::GatherServers(
} }
void AsyncExecutor::InitParamConfig() { void AsyncExecutor::InitParamConfig() {
for (int i = 0; i < for (int i = 0; i <
_pslib_ptr->get_param()->server_param().\ _pslib_ptr->get_param()->server_param(). \
downpour_server_param().\ downpour_server_param(). \
downpour_table_param_size(); downpour_table_param_size();
++i) { ++i) {
if (_pslib_ptr->get_param()->server_param().\ if (_pslib_ptr->get_param()->server_param(). \
downpour_server_param().downpour_table_param(i).\ downpour_server_param().downpour_table_param(i). \
table_class().find("SparseTable") != -1) { table_class().find("SparseTable") != -1) {
_param_config.fea_dim = _pslib_ptr->get_param()->server_param().\ _param_config.fea_dim = _pslib_ptr->get_param()->server_param(). \
downpour_server_param().\ downpour_server_param(). \
downpour_table_param(i).\ downpour_table_param(i). \
accessor().fea_dim(); accessor().fea_dim();
break; break;
}
} }
_param_config.slot_dim = _param_config.fea_dim - 2; }
_param_config.tmp_push_dense_wait_times = static_cast<int32_t>( _param_config.slot_dim = _param_config.fea_dim - 2;
_pslib_ptr->get_param()->trainer_param().push_dense_per_batch()); _param_config.tmp_push_dense_wait_times = static_cast<int32_t>(
_param_config.tmp_push_sparse_wait_times = static_cast<int32_t>( _pslib_ptr->get_param()->trainer_param().push_dense_per_batch());
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch()); _param_config.tmp_push_sparse_wait_times = static_cast<int32_t>(
_pslib_ptr->get_param()->trainer_param().push_sparse_per_batch());
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().skip_op_size(); for (auto t = 0u;
++t) { t < _pslib_ptr->get_param()->trainer_param().skip_op_size();
_param_config.skip_op.push_back( ++t) {
_pslib_ptr->get_param()->trainer_param().skip_op(t)); _param_config.skip_op.push_back(
_pslib_ptr->get_param()->trainer_param().skip_op(t));
}
for (auto t = 0u;
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size();
++t) {
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] =
table.table_id();
} }
std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto t = 0u; for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
t < _pslib_ptr->get_param()->trainer_param().sparse_table_size(); tmp_sparse_gradient_variable_name.push_back(
++t) { table.slot_gradient(i));
auto& table = _pslib_ptr->get_param()->trainer_param().sparse_table(t);
std::vector<std::string> tmp_sparse_variable_name;
for (int i = 0u; i < table.slot_value_size(); ++i) {
tmp_sparse_variable_name.push_back(table.slot_value(i));
_param_config.slot_alias_to_table[table.slot_key(i)] =
table.table_id();
}
std::vector<std::string> tmp_sparse_gradient_variable_name;
for (auto i = 0u; i < table.slot_gradient_size(); ++i) {
tmp_sparse_gradient_variable_name.push_back(
table.slot_gradient(i));
}
_param_config.slot_input_vec[table.table_id()] =
std::move(tmp_sparse_variable_name);
_param_config.gradient_var[table.table_id()] =
std::move(tmp_sparse_gradient_variable_name);
_param_config.sparse_table_id.push_back(table.table_id());
} }
_param_config.slot_input_vec[table.table_id()] =
for (auto t = 0u; std::move(tmp_sparse_variable_name);
t < _pslib_ptr->get_param()->trainer_param().dense_table_size(); _param_config.gradient_var[table.table_id()] =
++t) { std::move(tmp_sparse_gradient_variable_name);
auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t); _param_config.sparse_table_id.push_back(table.table_id());
std::vector<std::string> tmp_dense_variable_name; }
for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
tmp_dense_variable_name.push_back(table.dense_variable_name(i)); for (auto t = 0u;
} t < _pslib_ptr->get_param()->trainer_param().dense_table_size();
std::vector<std::string> tmp_dense_gradient_variable_name; ++t) {
for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) { auto& table = _pslib_ptr->get_param()->trainer_param().dense_table(t);
tmp_dense_gradient_variable_name.push_back( std::vector<std::string> tmp_dense_variable_name;
table.dense_gradient_variable_name(i)); for (int i = 0u; i < table.dense_variable_name_size(); ++i) {
} tmp_dense_variable_name.push_back(table.dense_variable_name(i));
_param_config.dense_variable_name[table.table_id()] = }
std::move(tmp_dense_variable_name); std::vector<std::string> tmp_dense_gradient_variable_name;
_param_config.dense_gradient_variable_name[table.table_id()] = for (auto i = 0u; i < table.dense_gradient_variable_name_size(); ++i) {
std::move(tmp_dense_gradient_variable_name); tmp_dense_gradient_variable_name.push_back(
_param_config.dense_table_id.push_back(table.table_id()); table.dense_gradient_variable_name(i));
_param_config.dense_table_size.push_back(table.fea_dim());
} }
_param_config.dense_variable_name[table.table_id()] =
std::move(tmp_dense_variable_name);
_param_config.dense_gradient_variable_name[table.table_id()] =
std::move(tmp_dense_gradient_variable_name);
_param_config.dense_table_id.push_back(table.table_id());
_param_config.dense_table_size.push_back(table.fea_dim());
}
} }
void AsyncExecutor::InitModel() { void AsyncExecutor::InitModel() {
for (auto table_id : _param_config.dense_table_id) { for (auto table_id : _param_config.dense_table_id) {
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
for (auto& t : _param_config.dense_variable_name[table_id]) { for (auto& t : _param_config.dense_variable_name[table_id]) {
Variable* var = root_scope_->FindVar(t); Variable* var = root_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found"; CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->data<float>(); float* g = tensor->data<float>();
CHECK(g != nullptr) << "var[" << t << "] value not initialized"; CHECK(g != nullptr) << "var[" << t << "] value not initialized";
float init_range = 0.2; float init_range = 0.2;
int rown = tensor->dims()[0]; int rown = tensor->dims()[0];
init_range /= sqrt(rown); init_range /= sqrt(rown);
std::normal_distribution<float> ndistr(0.0, 1.0); std::normal_distribution<float> ndistr(0.0, 1.0);
for (auto i = 0u; i < tensor->numel(); ++i) { for (auto i = 0u; i < tensor->numel(); ++i) {
g[i] = ndistr(local_random_engine()) * init_range; g[i] = ndistr(local_random_engine()) * init_range;
} }
paddle::ps::Region reg(g, tensor->numel()); paddle::ps::Region reg(g, tensor->numel());
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
}
auto push_status =
_pslib_ptr->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
} }
auto push_status =
_pslib_ptr->_worker_ptr->push_dense_param(
regions.data(), regions.size(), table_id);
push_status.wait();
auto status = push_status.get();
if (status != 0) {
LOG(FATAL) << "push dense param failed, status[" << status << "]";
exit(-1);
}
}
} }
void AsyncExecutor::SaveModel(const std::string& path) { void AsyncExecutor::SaveModel(const std::string& path) {
auto ret = _pslib_ptr->_worker_ptr->flush(); auto ret = _pslib_ptr->_worker_ptr->flush();
ret.wait(); ret.wait();
ret = _pslib_ptr->_worker_ptr->save(path, 0); ret = _pslib_ptr->_worker_ptr->save(path, 0);
ret.wait(); ret.wait();
int32_t feasign_cnt = ret.get(); int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0 if (feasign_cnt == -1) { // (colourful-tree) TODO should be feasign_cnt < 0
LOG(FATAL) << "save model failed"; LOG(FATAL) << "save model failed";
exit(-1); exit(-1);
} }
} }
void AsyncExecutor::PrepareDenseThread(const std::string& mode) { void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
if (mode == "mpi") { if (mode == "mpi") {
DensePullThreadParam param; DensePullThreadParam param;
param.ps_client = _pslib_ptr->_worker_ptr;; param.ps_client = _pslib_ptr->_worker_ptr;;
param.threshold = 1; param.threshold = 1;
param.training_thread_num = actual_thread_num; param.training_thread_num = actual_thread_num;
param.root_scope = root_scope_; param.root_scope = root_scope_;
param.dense_params = &_param_config.dense_variable_name; param.dense_params = &_param_config.dense_variable_name;
_pull_dense_thread = std::shared_ptr<DensePullThread>( _pull_dense_thread = std::shared_ptr<DensePullThread>(
new DensePullThread(param)); new DensePullThread(param));
_pull_dense_thread->start(); _pull_dense_thread->start();
} }
} }
#endif #endif
......
...@@ -45,7 +45,8 @@ inline std::default_random_engine& local_random_engine() { ...@@ -45,7 +45,8 @@ inline std::default_random_engine& local_random_engine() {
engine_wrapper_t() { engine_wrapper_t() {
static std::atomic<uint64_t> x(0); static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, std::seed_seq sseq = {x++, x++, x++,
static_cast<uint64_t>(current_realtime() * 1000)}; static_cast<uint64_t>(
current_realtime() * 1000)};
engine.seed(sseq); engine.seed(sseq);
} }
}; };
...@@ -77,6 +78,7 @@ class AsyncExecutor { ...@@ -77,6 +78,7 @@ class AsyncExecutor {
void SaveModel(const std::string& path); void SaveModel(const std::string& path);
void InitParamConfig(); void InitParamConfig();
#endif #endif
private: private:
void CreateThreads(ExecutorThreadWorker* worker, void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program, const ProgramDesc& main_program,
...@@ -87,6 +89,7 @@ class AsyncExecutor { ...@@ -87,6 +89,7 @@ class AsyncExecutor {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
void PrepareDenseThread(const std::string& mode); void PrepareDenseThread(const std::string& mode);
#endif #endif
public: public:
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr; std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
......
...@@ -33,87 +33,87 @@ namespace framework { ...@@ -33,87 +33,87 @@ namespace framework {
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
int DensePullThread::start() { int DensePullThread::start() {
_running = true; _running = true;
_t = std::thread(&DensePullThread::run, this); _t = std::thread(&DensePullThread::run, this);
return 0; return 0;
} }
void DensePullThread::run() { void DensePullThread::run() {
while (_running) { while (_running) {
_pull_dense_status.resize(0); _pull_dense_status.resize(0);
for (auto& t : _dense_variable_name) { for (auto& t : _dense_variable_name) {
if (check_update_param(t.first)) { if (check_update_param(t.first)) {
auto status = pull_dense(t.first); auto status = pull_dense(t.first);
_pull_dense_status.emplace_back(std::move(status)); _pull_dense_status.emplace_back(std::move(status));
reset_thread_version(t.first); reset_thread_version(t.first);
} }
} }
if (_pull_dense_status.size() != 0) { if (_pull_dense_status.size() != 0) {
wait_all(); wait_all();
}
usleep(_sleep_time_ms * 1000);
} }
usleep(_sleep_time_ms * 1000);
}
} }
bool DensePullThread::check_update_param(uint64_t table_id) { bool DensePullThread::check_update_param(uint64_t table_id) {
{ {
std::lock_guard<std::mutex> lock(_mutex_for_version); std::lock_guard<std::mutex> lock(_mutex_for_version);
auto& version = _training_versions[table_id]; auto& version = _training_versions[table_id];
_current_version[table_id] = _current_version[table_id] =
*(std::min_element(version.begin(), version.end())); *(std::min_element(version.begin(), version.end()));
} }
if (_current_version[table_id] - _last_versions[table_id] < _threshold) { if (_current_version[table_id] - _last_versions[table_id] < _threshold) {
return false; return false;
} }
return true; return true;
} }
void DensePullThread::reset_thread_version(uint64_t table_id) { void DensePullThread::reset_thread_version(uint64_t table_id) {
std::lock_guard<std::mutex> lock(_mutex_for_version); std::lock_guard<std::mutex> lock(_mutex_for_version);
_last_versions[table_id] = _current_version[table_id]; _last_versions[table_id] = _current_version[table_id];
} }
std::future<int32_t> DensePullThread::pull_dense(uint64_t table_id) { std::future<int32_t> DensePullThread::pull_dense(uint64_t table_id) {
auto& regions = _regions[table_id]; auto& regions = _regions[table_id];
regions.clear(); regions.clear();
auto& variables = _dense_variable_name[table_id]; auto& variables = _dense_variable_name[table_id];
regions.resize(variables.size()); regions.resize(variables.size());
for (auto i = 0u; i < variables.size(); ++i) { for (auto i = 0u; i < variables.size(); ++i) {
auto& t = variables[i]; auto& t = variables[i];
Variable* var = _root_scope->FindVar(t); Variable* var = _root_scope->FindVar(t);
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* w = tensor->data<float>(); float* w = tensor->data<float>();
paddle::ps::Region reg(w, tensor->numel()); paddle::ps::Region reg(w, tensor->numel());
regions[i] = std::move(reg); regions[i] = std::move(reg);
} }
return _ps_client->pull_dense(regions.data(), regions.size(), table_id); return _ps_client->pull_dense(regions.data(), regions.size(), table_id);
} }
void DensePullThread::wait_all() { void DensePullThread::wait_all() {
for (auto& t : _pull_dense_status) { for (auto& t : _pull_dense_status) {
t.wait(); t.wait();
auto status = t.get(); auto status = t.get();
if (status != 0) { if (status != 0) {
LOG(WARNING) << "pull dense failed times:" << LOG(WARNING) << "pull dense failed times:" <<
++_pull_dense_fail_times; ++_pull_dense_fail_times;
}
} }
}
if (_pull_dense_fail_times > 20) {
LOG(FATAL) << "pull dense failed times more than 20 times"; if (_pull_dense_fail_times > 20) {
exit(-1); LOG(FATAL) << "pull dense failed times more than 20 times";
} exit(-1);
}
_pull_dense_status.resize(0);
_pull_dense_status.resize(0);
} }
void DensePullThread::increase_thread_version( void DensePullThread::increase_thread_version(
int thread_id, uint64_t table_id) { int thread_id, uint64_t table_id) {
std::lock_guard<std::mutex> lock(_mutex_for_version); std::lock_guard<std::mutex> lock(_mutex_for_version);
_training_versions[table_id][thread_id]++; _training_versions[table_id][thread_id]++;
} }
#endif #endif
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) { void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0); auto& block = program.Block(0);
...@@ -336,56 +336,56 @@ void AsyncExecutorThreadWorker::TrainFiles() { ...@@ -336,56 +336,56 @@ void AsyncExecutorThreadWorker::TrainFiles() {
void AsyncExecutorThreadWorker::SetPSlibPtr( void AsyncExecutorThreadWorker::SetPSlibPtr(
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) { std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {
_pslib_ptr = pslib_ptr; _pslib_ptr = pslib_ptr;
} }
void AsyncExecutorThreadWorker::SetPullDenseThread( void AsyncExecutorThreadWorker::SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) { std::shared_ptr<DensePullThread> dpt) {
_pull_dense_thread = dpt; _pull_dense_thread = dpt;
} }
void AsyncExecutorThreadWorker::TrainOneNetwork() { void AsyncExecutorThreadWorker::TrainOneNetwork() {
PrepareParams(); PrepareParams();
for (auto& op : ops_) { for (auto& op : ops_) {
if (op->Type().find("sgd") != std::string::npos) { if (op->Type().find("sgd") != std::string::npos) {
continue; continue;
} }
bool need_skip = false; bool need_skip = false;
for (auto t = 0u; t < _param_config->skip_op.size(); ++t) { for (auto t = 0u; t < _param_config->skip_op.size(); ++t) {
if (op->Type().find(_param_config->skip_op[t]) != if (op->Type().find(_param_config->skip_op[t]) !=
std::string::npos) { std::string::npos) {
need_skip = true; need_skip = true;
break; break;
} }
} }
if (!need_skip) { if (!need_skip) {
op->Run(*thread_scope_, place_); op->Run(*thread_scope_, place_);
}
} }
UpdateParams(); }
UpdateParams();
} }
void AsyncExecutorThreadWorker::SetParamConfig( void AsyncExecutorThreadWorker::SetParamConfig(
AsyncWorkerParamConfig* param_config) { AsyncWorkerParamConfig* param_config) {
_param_config = param_config; _param_config = param_config;
} }
void AsyncExecutorThreadWorker::PrepareParams() { void AsyncExecutorThreadWorker::PrepareParams() {
for (auto table_id : _param_config->sparse_table_id) { for (auto table_id : _param_config->sparse_table_id) {
PullSparse(table_id); PullSparse(table_id);
for (auto& t : _pull_sparse_status) { for (auto& t : _pull_sparse_status) {
t.wait(); t.wait();
auto status = t.get(); auto status = t.get();
if (status != 0) { if (status != 0) {
LOG(ERROR) << "pull sparse failed, status[" << status << "]"; LOG(ERROR) << "pull sparse failed, status[" << status << "]";
exit(-1); exit(-1);
} }
}
} }
_pull_sparse_status.resize(0); }
_pull_sparse_status.resize(0);
for (auto table_id : _param_config->sparse_table_id) { for (auto table_id : _param_config->sparse_table_id) {
FillSparse(table_id); FillSparse(table_id);
} }
} }
void AsyncExecutorThreadWorker::UpdateParams() { void AsyncExecutorThreadWorker::UpdateParams() {
...@@ -426,21 +426,20 @@ void AsyncExecutorThreadWorker::UpdateParams() { ...@@ -426,21 +426,20 @@ void AsyncExecutorThreadWorker::UpdateParams() {
} }
void AsyncExecutorThreadWorker::PushDense(int table_id) { void AsyncExecutorThreadWorker::PushDense(int table_id) {
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
for (auto& t : _param_config->dense_gradient_variable_name[table_id]) { for (auto& t : _param_config->dense_gradient_variable_name[table_id]) {
Variable* var = thread_scope_->FindVar(t); Variable* var = thread_scope_->FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found"; CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
int count = tensor->numel(); int count = tensor->numel();
float* g = tensor->data<float>(); float* g = tensor->data<float>();
paddle::ps::Region reg(g, count); paddle::ps::Region reg(g, count);
regions.emplace_back(std::move(reg)); regions.emplace_back(std::move(reg));
} }
auto status = _pslib_ptr->_worker_ptr->push_dense( auto status = _pslib_ptr->_worker_ptr->push_dense(
regions.data(), regions.size(), table_id); regions.data(), regions.size(), table_id);
_push_dense_status.push_back(std::move(status)); _push_dense_status.push_back(std::move(status));
} }
void AsyncExecutorThreadWorker::PullSparse(int table_id) { void AsyncExecutorThreadWorker::PullSparse(int table_id) {
...@@ -643,24 +642,24 @@ void AsyncExecutorThreadWorker::check_pull_push_memory( ...@@ -643,24 +642,24 @@ void AsyncExecutorThreadWorker::check_pull_push_memory(
const std::vector<uint64_t>& features, const std::vector<uint64_t>& features,
std::vector<std::vector<float>>& push_g, std::vector<std::vector<float>>& push_g,
int dim) { int dim) {
push_g.resize(features.size() + 1); push_g.resize(features.size() + 1);
for (auto& t : push_g) { for (auto& t : push_g) {
t.resize(dim); t.resize(dim);
} }
} }
void AsyncExecutorThreadWorker::check_pull_push_memory( void AsyncExecutorThreadWorker::check_pull_push_memory(
const std::vector<uint64_t>& features, const std::vector<uint64_t>& features,
std::vector<float*>& push_g, std::vector<float*>& push_g,
int dim) { int dim) {
if (features.size() > push_g.size()) { if (features.size() > push_g.size()) {
push_g.reserve(features.size() + 1); push_g.reserve(features.size() + 1);
auto size = features.size() - push_g.size() + 1; auto size = features.size() - push_g.size() + 1;
for (auto i = 0u; i < size; ++i) { for (auto i = 0u; i < size; ++i) {
float* ptr = new float[dim]; float* ptr = new float[dim];
push_g.push_back(ptr); push_g.push_back(ptr);
}
} }
}
} }
#endif #endif
......
...@@ -67,79 +67,79 @@ struct DensePullThreadParam { ...@@ -67,79 +67,79 @@ struct DensePullThreadParam {
class DensePullThread { class DensePullThread {
public: public:
explicit DensePullThread(const DensePullThreadParam& param) : explicit DensePullThread(const DensePullThreadParam& param) :
_running(false) { _running(false) {
_ps_client = param.ps_client; _ps_client = param.ps_client;
_threshold = param.threshold; _threshold = param.threshold;
_thread_num = param.training_thread_num; _thread_num = param.training_thread_num;
_root_scope = param.root_scope; _root_scope = param.root_scope;
_sleep_time_ms = param.sleep_time_ms; _sleep_time_ms = param.sleep_time_ms;
for (auto& t : *param.dense_params) { for (auto& t : *param.dense_params) {
_dense_variable_name[t.first].insert( _dense_variable_name[t.first].insert(
_dense_variable_name[t.first].end(), _dense_variable_name[t.first].end(),
t.second.begin(), t.second.end()); t.second.begin(), t.second.end());
_training_versions[t.first].resize(_thread_num, 0); _training_versions[t.first].resize(_thread_num, 0);
_last_versions[t.first] = 0; _last_versions[t.first] = 0;
_current_version[t.first] = 0; _current_version[t.first] = 0;
}
} }
}
int start();
int start();
void stop() {
if (_running) { void stop() {
_running = false; if (_running) {
_t.join(); _running = false;
} _t.join();
} }
}
void increase_thread_version(int thread_id, uint64_t table_id);
void reset_thread_version(uint64_t table_id); void increase_thread_version(int thread_id, uint64_t table_id);
std::future<int32_t> pull_dense(uint64_t table_id); void reset_thread_version(uint64_t table_id);
void pull_dense2(uint64_t table_id); std::future<int32_t> pull_dense(uint64_t table_id);
void wait_all(); void pull_dense2(uint64_t table_id);
void wait_all();
private: private:
void run(); void run();
bool check_update_param(uint64_t table_id); bool check_update_param(uint64_t table_id);
private: private:
std::shared_ptr<paddle::ps::PSClient> _ps_client; std::shared_ptr<paddle::ps::PSClient> _ps_client;
int _thread_num; int _thread_num;
int _threshold; int _threshold;
int _sleep_time_ms; int _sleep_time_ms;
Scope* _root_scope; Scope* _root_scope;
bool _running; bool _running;
std::map<uint64_t, uint64_t> _last_versions; std::map<uint64_t, uint64_t> _last_versions;
std::map<uint64_t, uint64_t> _current_version; std::map<uint64_t, uint64_t> _current_version;
std::mutex _mutex_for_version; std::mutex _mutex_for_version;
std::map<uint64_t, std::vector<uint64_t>> _training_versions; std::map<uint64_t, std::vector<uint64_t>> _training_versions;
std::map<uint64_t, std::vector<std::string>> _dense_variable_name; std::map<uint64_t, std::vector<std::string>> _dense_variable_name;
std::thread _t; std::thread _t;
std::vector<::std::future<int32_t>> _pull_dense_status; std::vector<::std::future<int32_t>> _pull_dense_status;
std::map<uint64_t, std::vector<paddle::ps::Region>> _regions; std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
uint32_t _pull_dense_fail_times = 0; uint32_t _pull_dense_fail_times = 0;
std::vector<float> _base_norm_param; std::vector<float> _base_norm_param;
std::vector<float> _mean; std::vector<float> _mean;
std::vector<float> _scale; std::vector<float> _scale;
float _squared_sum_epsilon = 1e-4; float _squared_sum_epsilon = 1e-4;
std::mutex _mutex_for_mean_scale; std::mutex _mutex_for_mean_scale;
float _total_batch_num = 0; float _total_batch_num = 0;
}; };
#endif #endif
class ExecutorThreadWorker { class ExecutorThreadWorker {
public: public:
ExecutorThreadWorker() ExecutorThreadWorker()
: thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {} : thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {}
virtual ~ExecutorThreadWorker() {} virtual ~ExecutorThreadWorker() {}
void CreateThreadResource(const framework::ProgramDesc& program, void CreateThreadResource(const framework::ProgramDesc& program,
const paddle::platform::Place& place); const paddle::platform::Place& place);
void SetThreadId(int tid); void SetThreadId(int tid);
...@@ -160,7 +160,7 @@ class ExecutorThreadWorker { ...@@ -160,7 +160,7 @@ class ExecutorThreadWorker {
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names); void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
#ifdef PADDLE_WITH_PSLIB #ifdef PADDLE_WITH_PSLIB
virtual void SetPSlibPtr( virtual void SetPSlibPtr(
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {}; std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {}
virtual void SetPullDenseThread( virtual void SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {} std::shared_ptr<DensePullThread> dpt) {}
virtual void SetParamConfig( virtual void SetParamConfig(
...@@ -218,32 +218,32 @@ class AsyncExecutorThreadWorker: public ExecutorThreadWorker { ...@@ -218,32 +218,32 @@ class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
void check_pull_push_memory(const std::vector<uint64_t>& features, void check_pull_push_memory(const std::vector<uint64_t>& features,
std::vector<std::vector<float>>& push_g, std::vector<std::vector<float>>& push_g,
int dim); int dim);
void collect_feasign_info(int table_id); void collect_feasign_info(int table_id);
private: private:
struct FeasignInfo { struct FeasignInfo {
uint32_t slot; uint32_t slot;
uint32_t ins; uint32_t ins;
int64_t label; int64_t label;
}; };
std::map<uint64_t, std::vector<uint64_t>> _features; std::map<uint64_t, std::vector<uint64_t>> _features;
std::map<uint64_t, std::vector<FeasignInfo>> _fea_info; std::map<uint64_t, std::vector<FeasignInfo>> _fea_info;
std::map<uint64_t, std::vector<std::vector<float>>> _feature_value; std::map<uint64_t, std::vector<std::vector<float>>> _feature_value;
std::map<uint64_t, std::vector<std::vector<float>>> _feature_push_value; std::map<uint64_t, std::vector<std::vector<float>>> _feature_push_value;
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr; std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread; std::shared_ptr<DensePullThread> _pull_dense_thread;
std::vector<::std::future<int32_t>> _pull_sparse_status; std::vector<::std::future<int32_t>> _pull_sparse_status;
std::vector<::std::future<int32_t>> _pull_dense_status; std::vector<::std::future<int32_t>> _pull_dense_status;
std::vector<::std::future<int32_t>> _push_sparse_status; std::vector<::std::future<int32_t>> _push_sparse_status;
std::vector<::std::future<int32_t>> _push_dense_status; std::vector<::std::future<int32_t>> _push_dense_status;
AsyncWorkerParamConfig* _param_config; AsyncWorkerParamConfig* _param_config;
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册