From 378037c535caf1b14a92b60f60b43eea1229f0a4 Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Sat, 2 Feb 2019 12:54:18 +0800 Subject: [PATCH] make s_instance_ private to ensure singleton --- paddle/fluid/framework/device_worker.h | 2 +- paddle/fluid/framework/downpour_worker.cc | 2 ++ paddle/fluid/framework/fleet/fleet_wrapper.h | 7 +++++-- paddle/fluid/framework/pull_dense_worker.cc | 11 +---------- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index bb6fcdbd7b4..f663fa89f99 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -47,7 +47,6 @@ class PullDenseWorker { void IncreaseThreadVersion(int thread_id, uint64_t table_id); void ResetThreadVersion(uint64_t table_id); void Wait(std::vector<::std::future>* status_vec); - static std::shared_ptr s_instance_; static std::shared_ptr GetInstance() { if (NULL == s_instance_) { s_instance_.reset(new paddle::framework::PullDenseWorker()); @@ -61,6 +60,7 @@ class PullDenseWorker { bool CheckUpdateParam(uint64_t table_id); private: + static std::shared_ptr s_instance_; std::shared_ptr fleet_ptr_; PullDenseWorkerParameter param_; Scope* root_scope_; diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index f790fc7d695..ff2fc3f89ad 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -58,6 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) { skip_ops_[i] = param_.skip_ops(i); } skip_ops_.resize(param_.skip_ops_size()); + + fleet_ptr_ = FleetWrapper::GetInstance(); } void DownpourWorker::CollectLabelInfo(size_t table_idx) { diff --git a/paddle/fluid/framework/fleet/fleet_wrapper.h b/paddle/fluid/framework/fleet/fleet_wrapper.h index 945600daff0..8151d196bec 100644 --- a/paddle/fluid/framework/fleet/fleet_wrapper.h +++ b/paddle/fluid/framework/fleet/fleet_wrapper.h @@ -40,7 +40,8 @@ namespace framework { // Async: PullSparseVarsAsync(not implemented currently) // Push // Sync: PushSparseVarsSync -// Async: PushSparseVarsAsync +// Async: PushSparseVarsAsync(not implemented currently) +// Async: PushSparseVarsWithLabelAsync(with special usage) // Push dense variables to server in Async mode // Param: scope, table_id, var_names // Param: push_sparse_status @@ -109,7 +110,6 @@ class FleetWrapper { uint64_t RunServer(); void GatherServers(const std::vector& host_sign_list, int node_num); - static std::shared_ptr s_instance_; static std::shared_ptr GetInstance() { if (NULL == s_instance_) { s_instance_.reset(new paddle::framework::FleetWrapper()); @@ -121,6 +121,9 @@ class FleetWrapper { static std::shared_ptr pslib_ptr_; #endif + private: + static std::shared_ptr s_instance_; + private: FleetWrapper() {} diff --git a/paddle/fluid/framework/pull_dense_worker.cc b/paddle/fluid/framework/pull_dense_worker.cc index 7d94b5254d9..556424311a2 100644 --- a/paddle/fluid/framework/pull_dense_worker.cc +++ b/paddle/fluid/framework/pull_dense_worker.cc @@ -20,31 +20,25 @@ namespace framework { std::shared_ptr PullDenseWorker::s_instance_ = NULL; void PullDenseWorker::Initialize(const TrainerDesc& param) { - LOG(WARNING) << "going to initialize pull dense worker"; running_ = false; param_ = param.pull_dense_param(); threshold_ = param_.threshold(); thread_num_ = param_.device_num(); sleep_time_ms_ = param_.sleep_time_ms(); - LOG(WARNING) << "dense table size: " << param_.dense_table_size(); - LOG(WARNING) << "thread num: " << thread_num_; for (size_t i = 0; i < param_.dense_table_size(); ++i) { // setup dense variables for each table int var_num = param_.dense_table(i).dense_value_name_size(); - LOG(WARNING) << "var num: " << var_num; uint64_t tid = static_cast(param_.dense_table(i).table_id()); dense_value_names_[tid].resize(var_num); for (int j = 0; j < var_num; ++j) { dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j); - LOG(WARNING) << "dense value names " << j << " " - << dense_value_names_[tid][j]; } // setup training version for each table training_versions_[tid].resize(thread_num_, 0); last_versions_[tid] = 0; current_version_[tid] = 0; } - LOG(WARNING) << "initialize pull dense worker done."; + fleet_ptr_ = FleetWrapper::GetInstance(); } void PullDenseWorker::Wait(std::vector<::std::future>* status_vec) { @@ -98,10 +92,7 @@ void PullDenseWorker::Run() { } void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) { - LOG(WARNING) << "increase thread version input: " << thread_id << " table id " - << table_id; std::lock_guard lock(mutex_for_version_); - LOG(WARNING) << "going to increase"; training_versions_[table_id][thread_id]++; } -- GitLab