Commit 378037c5 authored by dongdaxiang

make s_instance_ private to ensure singleton

Parent a446d26e
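The pattern being tightened here is a lazily initialized singleton whose static instance pointer had been left in the public section of the class. A minimal, self-contained sketch of the idea (illustrative only, with a made-up class name; not the Paddle code):

```cpp
#include <memory>

// With s_instance_ public, any caller could reset or reassign the
// pointer and silently break the "single instance" guarantee. Moving
// it to private leaves GetInstance() as the only door, which is what
// this commit does for PullDenseWorker and FleetWrapper.
class Widget {
 public:
  static std::shared_ptr<Widget> GetInstance() {
    if (s_instance_ == nullptr) {
      s_instance_.reset(new Widget());
    }
    return s_instance_;
  }

 private:
  Widget() {}                                  // no direct construction
  static std::shared_ptr<Widget> s_instance_;  // no direct access
};

std::shared_ptr<Widget> Widget::s_instance_ = nullptr;
```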
......@@ -47,7 +47,6 @@ class PullDenseWorker {
   void IncreaseThreadVersion(int thread_id, uint64_t table_id);
   void ResetThreadVersion(uint64_t table_id);
   void Wait(std::vector<::std::future<int32_t>>* status_vec);
-  static std::shared_ptr<PullDenseWorker> s_instance_;
   static std::shared_ptr<PullDenseWorker> GetInstance() {
     if (NULL == s_instance_) {
       s_instance_.reset(new paddle::framework::PullDenseWorker());
......@@ -61,6 +60,7 @@ class PullDenseWorker {
   bool CheckUpdateParam(uint64_t table_id);
 
  private:
+  static std::shared_ptr<PullDenseWorker> s_instance_;
   std::shared_ptr<paddle::framework::FleetWrapper> fleet_ptr_;
   PullDenseWorkerParameter param_;
   Scope* root_scope_;
......
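A note on the accessor retained above: the unsynchronized `NULL == s_instance_` check is only safe when the first `GetInstance()` call happens before worker threads start. A race-free variant can be sketched with `std::call_once` (hypothetical names; this is not part of the commit):

```cpp
#include <memory>
#include <mutex>

class LazySingleton {
 public:
  static std::shared_ptr<LazySingleton> GetInstance() {
    static std::once_flag init_flag;
    // call_once guarantees exactly one thread runs the initializer,
    // even when GetInstance() races across threads.
    std::call_once(init_flag, [] { s_instance_.reset(new LazySingleton()); });
    return s_instance_;
  }

 private:
  LazySingleton() {}
  static std::shared_ptr<LazySingleton> s_instance_;
};

std::shared_ptr<LazySingleton> LazySingleton::s_instance_ = nullptr;
```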
......@@ -58,6 +58,8 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
     skip_ops_[i] = param_.skip_ops(i);
   }
+  skip_ops_.resize(param_.skip_ops_size());
+  fleet_ptr_ = FleetWrapper::GetInstance();
 }
 
 void DownpourWorker::CollectLabelInfo(size_t table_idx) {
......
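The `DownpourWorker` hunk shows the consumer side: `Initialize()` stores a copy of the shared handle in `fleet_ptr_`, so every worker ends up holding the same object. A tiny standalone illustration of the `shared_ptr` semantics this relies on (none of this is Paddle code):

```cpp
#include <cassert>
#include <memory>

int main() {
  auto a = std::make_shared<int>(42);
  auto b = a;                   // what fleet_ptr_ = GetInstance() does
  *b = 7;
  assert(*a == 7);              // both handles see the same instance
  assert(a.use_count() == 2);   // shared ownership, nothing copied
  return 0;
}
```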
......@@ -40,7 +40,8 @@ namespace framework {
 // Async: PullSparseVarsAsync(not implemented currently)
 // Push
 // Sync: PushSparseVarsSync
-// Async: PushSparseVarsAsync
+// Async: PushSparseVarsAsync(not implemented currently)
+// Async: PushSparseVarsWithLabelAsync(with special usage)
 // Push dense variables to server in Async mode
 // Param<in>: scope, table_id, var_names
 // Param<out>: push_sparse_status
......@@ -109,7 +110,6 @@ class FleetWrapper {
   uint64_t RunServer();
   void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
-  static std::shared_ptr<FleetWrapper> s_instance_;
   static std::shared_ptr<FleetWrapper> GetInstance() {
     if (NULL == s_instance_) {
       s_instance_.reset(new paddle::framework::FleetWrapper());
......@@ -121,6 +121,9 @@ class FleetWrapper {
   static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
 #endif
 
+ private:
+  static std::shared_ptr<FleetWrapper> s_instance_;
+
  private:
   FleetWrapper() {}
......
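With `s_instance_` private in `FleetWrapper` as well, `GetInstance()` becomes the only entry point; code that read `FleetWrapper::s_instance_` directly no longer compiles. A hypothetical call site, assuming the header path (it is not shown in this diff):

```cpp
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"  // assumed path

int main() {
  // The only supported way to reach the wrapper after this commit;
  // RunServer(), GatherServers(...) etc. go through the shared handle.
  auto fleet = paddle::framework::FleetWrapper::GetInstance();
  return fleet != nullptr ? 0 : 1;
}
```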
......@@ -20,31 +20,25 @@ namespace framework {
 std::shared_ptr<PullDenseWorker> PullDenseWorker::s_instance_ = NULL;
 
 void PullDenseWorker::Initialize(const TrainerDesc& param) {
-  LOG(WARNING) << "going to initialize pull dense worker";
   running_ = false;
   param_ = param.pull_dense_param();
   threshold_ = param_.threshold();
   thread_num_ = param_.device_num();
   sleep_time_ms_ = param_.sleep_time_ms();
-  LOG(WARNING) << "dense table size: " << param_.dense_table_size();
-  LOG(WARNING) << "thread num: " << thread_num_;
   for (size_t i = 0; i < param_.dense_table_size(); ++i) {
     // setup dense variables for each table
     int var_num = param_.dense_table(i).dense_value_name_size();
-    LOG(WARNING) << "var num: " << var_num;
     uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
     dense_value_names_[tid].resize(var_num);
     for (int j = 0; j < var_num; ++j) {
       dense_value_names_[tid][j] = param_.dense_table(i).dense_value_name(j);
-      LOG(WARNING) << "dense value names " << j << " "
-                   << dense_value_names_[tid][j];
     }
     // setup training version for each table
     training_versions_[tid].resize(thread_num_, 0);
     last_versions_[tid] = 0;
     current_version_[tid] = 0;
   }
-  LOG(WARNING) << "initialize pull dense worker done.";
+  fleet_ptr_ = FleetWrapper::GetInstance();
 }
 
 void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
......@@ -98,10 +92,7 @@ void PullDenseWorker::Run() {
 }
 
 void PullDenseWorker::IncreaseThreadVersion(int thread_id, uint64_t table_id) {
-  LOG(WARNING) << "increase thread version input: " << thread_id << " table id "
-               << table_id;
   std::lock_guard<std::mutex> lock(mutex_for_version_);
-  LOG(WARNING) << "going to increase";
   training_versions_[table_id][thread_id]++;
 }
......
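The last hunk strips debug logging from `IncreaseThreadVersion` but keeps its core: one counter per (table, thread), bumped under a mutex. A self-contained sketch of that bookkeeping, with the member names mirroring the diff and everything else assumed:

```cpp
#include <cstdint>
#include <map>
#include <mutex>
#include <vector>

class VersionTracker {
 public:
  void InitTable(uint64_t table_id, int thread_num) {
    // One counter per training thread, as in PullDenseWorker::Initialize;
    // assumed to run before worker threads start, hence no lock here.
    training_versions_[table_id].assign(thread_num, 0);
  }

  void IncreaseThreadVersion(int thread_id, uint64_t table_id) {
    // The lock serializes concurrent bumps from different workers.
    std::lock_guard<std::mutex> lock(mutex_for_version_);
    training_versions_[table_id][thread_id]++;
  }

 private:
  std::mutex mutex_for_version_;
  std::map<uint64_t, std::vector<uint64_t>> training_versions_;
};
```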