提交 378037c5 编写于 作者: D dongdaxiang

make s_instance_ private to ensure singleton

上级 a446d26e
...@@ -47,7 +47,6 @@ class PullDenseWorker { ...@@ -47,7 +47,6 @@ class PullDenseWorker {
void IncreaseThreadVersion(int thread_id, uint64_t table_id); void IncreaseThreadVersion(int thread_id, uint64_t table_id);
void ResetThreadVersion(uint64_t table_id); void ResetThreadVersion(uint64_t table_id);
void Wait(std::vector<::std::future<int32_t>>* status_vec); void Wait(std::vector<::std::future<int32_t>>* status_vec);
static std::shared_ptr<PullDenseWorker> s_instance_;
static std::shared_ptr<PullDenseWorker> GetInstance() { static std::shared_ptr<PullDenseWorker> GetInstance() {
if (NULL == s_instance_) { if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::PullDenseWorker()); s_instance_.reset(new paddle::framework::PullDenseWorker());
...@@ -61,6 +60,7 @@ class PullDenseWorker { ...@@ -61,6 +60,7 @@ class PullDenseWorker {
bool CheckUpdateParam(uint64_t table_id); bool CheckUpdateParam(uint64_t table_id);
private: private:
static std::shared_ptr<PullDenseWorker> s_instance_;
std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_; std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
PullDenseWorkerParameter param_; PullDenseWorkerParameter param_;
Scope* root_scope_; Scope* root_scope_;
......
...@@ -58,6 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { ...@@ -58,6 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
skip_ops_[i] = param_.skip_ops(i); skip_ops_[i] = param_.skip_ops(i);
} }
skip_ops_.resize(param_.skip_ops_size()); skip_ops_.resize(param_.skip_ops_size());
fleet_ptr_ = FleetWrapper::GetInstance();
} }
void DownpourWorker::CollectLabelInfo(size_t table_idx) { void DownpourWorker::CollectLabelInfo(size_t table_idx) {
......
...@@ -40,7 +40,8 @@ namespace framework { ...@@ -40,7 +40,8 @@ namespace framework {
// Async: PullSparseVarsAsync(not implemented currently) // Async: PullSparseVarsAsync(not implemented currently)
// Push // Push
// Sync: PushSparseVarsSync // Sync: PushSparseVarsSync
// Async: PushSparseVarsAsync // Async: PushSparseVarsAsync(not implemented currently)
// Async: PushSparseVarsWithLabelAsync(with special usage)
// Push dense variables to server in Async mode // Push dense variables to server in Async mode
// Param<in>: scope, table_id, var_names // Param<in>: scope, table_id, var_names
// Param<out>: push_sparse_status // Param<out>: push_sparse_status
...@@ -109,7 +110,6 @@ class FleetWrapper { ...@@ -109,7 +110,6 @@ class FleetWrapper {
uint64_t RunServer(); uint64_t RunServer();
void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num); void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
static std::shared_ptr<FleetWrapper> s_instance_;
static std::shared_ptr<FleetWrapper> GetInstance() { static std::shared_ptr<FleetWrapper> GetInstance() {
if (NULL == s_instance_) { if (NULL == s_instance_) {
s_instance_.reset(new paddle::framework::FleetWrapper()); s_instance_.reset(new paddle::framework::FleetWrapper());
...@@ -121,6 +121,9 @@ class FleetWrapper { ...@@ -121,6 +121,9 @@ class FleetWrapper {
static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_; static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
#endif #endif
private:
static std::shared_ptr<FleetWrapper> s_instance_;
private: private:
FleetWrapper() {} FleetWrapper() {}
......
...@@ -20,31 +20,25 @@ namespace framework { ...@@ -20,31 +20,25 @@ namespace framework {
std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL; std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;
void PullDenseWorker::Initialize(const TrainerDesc& param) { void PullDenseWorker::Initialize(const TrainerDesc& param) {
LOG(WARNING) << "going to initialize pull dense worker";
running_ = false; running_ = false;
param_ = param.pull_dense_param(); param_ = param.pull_dense_param();
threshold_ = param_.threshold(); threshold_ = param_.threshold();
thread_num_ = param_.device_num(); thread_num_ = param_.device_num();
sleep_time_ms_ = param_.sleep_time_ms(); sleep_time_ms_ = param_.sleep_time_ms();
LOG(WARNING) << "dense table size: " << param_.dense_table_size();
LOG(WARNING) << "thread num: " << thread_num_;
for (size_t i = 0; i < param_.dense_table_size(); ++i) { for (size_t i = 0; i < param_.dense_table_size(); ++i) {
// setup dense variables for each table // setup dense variables for each table
int var_num = param_.dense_table(i).dense_value_name_size(); int var_num = param_.dense_table(i).dense_value_name_size();
LOG(WARNING) << "var num: " << var_num;
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id()); uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
dense_value_names_[tid].resize(var_num); dense_value_names_[tid].resize(var_num);
for (int j = 0; j < var_num; ++j) { for (int j = 0; j < var_num; ++j) {
dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j); dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j);
LOG(WARNING) << "dense value names " << j << " "
<< dense_value_names_[tid][j];
} }
// setup training version for each table // setup training version for each table
training_versions_[tid].resize(thread_num_, 0); training_versions_[tid].resize(thread_num_, 0);
last_versions_[tid] = 0; last_versions_[tid] = 0;
current_version_[tid] = 0; current_version_[tid] = 0;
} }
LOG(WARNING) << "initialize pull dense worker done."; fleet_ptr_ = FleetWrapper::GetInstance();
} }
void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) { void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
...@@ -98,10 +92,7 @@ void PullDenseWorker::Run() { ...@@ -98,10 +92,7 @@ void PullDenseWorker::Run() {
} }
void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) { void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
LOG(WARNING) << "increase thread version input: " << thread_id << " table id "
<< table_id;
std::lock_guard<std::mutex> lock(mutex_for_version_); std::lock_guard<std::mutex> lock(mutex_for_version_);
LOG(WARNING) << "going to increase";
training_versions_[table_id][thread_id]++; training_versions_[table_id][thread_id]++;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册