From ac3603bf3b8af73b541cdabef8c087ca087022ed Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 25 May 2021 13:43:50 +0800 Subject: [PATCH] add async save for sparse table (#33072) * add async save for sparse table * add load profiler for sparse table * add load info for sparse table --- paddle/fluid/distributed/common/utils.h | 10 ++- .../distributed/table/common_sparse_table.cc | 75 ++++++++++--------- paddle/fluid/distributed/table/table.h | 10 ++- .../distributed/fleet/runtime/the_one_ps.py | 5 -- 4 files changed, 54 insertions(+), 46 deletions(-) diff --git a/paddle/fluid/distributed/common/utils.h b/paddle/fluid/distributed/common/utils.h index f81f84b1e11..2305001ad6f 100644 --- a/paddle/fluid/distributed/common/utils.h +++ b/paddle/fluid/distributed/common/utils.h @@ -14,6 +14,8 @@ #pragma once +#include + #include #include #include @@ -83,5 +85,11 @@ std::string to_string(const std::vector& vec) { } return ss.str(); } + +inline double GetCurrentUS() { + struct timeval time; + gettimeofday(&time, NULL); + return 1e+6 * time.tv_sec + time.tv_usec; } -} +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc index a4f672c2963..b667aec186f 100644 --- a/paddle/fluid/distributed/table/common_sparse_table.cc +++ b/paddle/fluid/distributed/table/common_sparse_table.cc @@ -134,10 +134,23 @@ void ProcessALine(const std::vector& columns, const Meta& meta, } } -int64_t SaveToText(std::ostream* os, std::shared_ptr block, - const int mode) { - int64_t save_num = 0; +void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common, + const size_t shard_idx, const int64_t total) { + // save meta + std::stringstream stream; + stream << "param=" << common.table_name() << "\n"; + stream << "shard_id=" << shard_idx << "\n"; + stream << "row_names=" << paddle::string::join_strings(common.params(), ',') + << "\n"; + stream << "row_dims=" << paddle::string::join_strings(common.dims(), ',') + << "\n"; + stream << "count=" << total << "\n"; + os->write(stream.str().c_str(), sizeof(char) * stream.str().size()); +} +int64_t SaveValueToText(std::ostream* os, std::shared_ptr block, + std::shared_ptr<::ThreadPool> pool, const int mode) { + int64_t save_num = 0; for (auto& table : block->values_) { for (auto& value : table) { if (mode == SaveMode::delta && !value.second->need_save_) { @@ -334,16 +347,24 @@ int32_t CommonSparseTable::set_global_lr(float* lr) { int32_t CommonSparseTable::load(const std::string& path, const std::string& param) { + auto begin = GetCurrentUS(); rwlock_->WRLock(); - VLOG(3) << "sparse table load with " << path << " with meta " << param; LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_, &shard_values_); rwlock_->UNLock(); + auto end = GetCurrentUS(); + + auto varname = _config.common().table_name(); + VLOG(0) << "load " << varname << " with value: " << path + << " , meta: " << param + << " using: " << std::to_string((end - begin) / 1e+6) << " seconds"; + return 0; } int32_t CommonSparseTable::save(const std::string& dirname, const std::string& param) { + auto begin = GetCurrentUS(); rwlock_->WRLock(); int mode = std::stoi(param); VLOG(3) << "sparse table save: " << dirname << " mode: " << mode; @@ -356,36 +377,33 @@ int32_t CommonSparseTable::save(const std::string& dirname, VLOG(3) << "save " << varname << " in dir: " << var_store << " begin"; std::vector params(_config.common().params().begin(), _config.common().params().end()); + std::string shard_var_pre = string::Sprintf("%s.block%d", varname, _shard_idx); std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre); - std::unique_ptr value_out(new std::ofstream(value_)); + std::unique_ptr vs(new std::ofstream(value_)); int64_t total_ins = 0; for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) { // save values - total_ins += SaveToText(value_out.get(), shard_values_[shard_id], mode); + auto shard_save_num = SaveValueToText(vs.get(), shard_values_[shard_id], + _shards_task_pool[shard_id], mode); + total_ins += shard_save_num; } - value_out->close(); + vs->close(); - // save meta - std::stringstream stream; - stream << "param=" << _config.common().table_name() << "\n"; - stream << "shard_id=" << _shard_idx << "\n"; - stream << "row_names=" - << paddle::string::join_strings(_config.common().params(), ',') - << "\n"; - stream << "row_dims=" - << paddle::string::join_strings(_config.common().dims(), ',') << "\n"; - stream << "count=" << total_ins << "\n"; std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre); - std::unique_ptr meta_out(new std::ofstream(meta_)); - meta_out->write(stream.str().c_str(), sizeof(char) * stream.str().size()); - meta_out->close(); - VLOG(3) << "save " << varname << " in dir: " << var_store << " done"; + std::unique_ptr ms(new std::ofstream(meta_)); + SaveMetaToText(ms.get(), _config.common(), _shard_idx, total_ins); + ms->close(); + + auto end = GetCurrentUS(); rwlock_->UNLock(); + VLOG(0) << "save " << varname << " with path: " << value_ + << " using: " << std::to_string((end - begin) / 1e+6) << " seconds"; + return 0; } @@ -403,8 +421,6 @@ std::pair CommonSparseTable::print_table_stat() { } int32_t CommonSparseTable::pour() { - rwlock_->RDLock(); - std::vector values; std::vector keys; @@ -421,14 +437,11 @@ int32_t CommonSparseTable::pour() { _push_sparse(keys.data(), values.data(), pull_reservoir_.size()); pull_reservoir_.clear(); - rwlock_->UNLock(); return 0; } int32_t CommonSparseTable::pull_sparse(float* pull_values, const PullSparseValue& pull_value) { - rwlock_->RDLock(); - auto shard_num = task_pool_size_; std::vector> tasks(shard_num); @@ -464,7 +477,6 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { tasks[shard_id].wait(); } - rwlock_->UNLock(); return 0; } @@ -507,7 +519,6 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values, int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, const float* values, size_t num) { - rwlock_->RDLock(); std::vector> offset_bucket; offset_bucket.resize(task_pool_size_); @@ -531,7 +542,6 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { tasks[shard_id].wait(); } - rwlock_->UNLock(); return 0; } @@ -569,7 +579,6 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys, int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, const float** values, size_t num) { - rwlock_->RDLock(); std::vector> offset_bucket; offset_bucket.resize(task_pool_size_); @@ -596,14 +605,11 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys, for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { tasks[shard_id].wait(); } - rwlock_->UNLock(); return 0; } int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, const float* values, size_t num) { - rwlock_->RDLock(); - std::vector> offset_bucket; offset_bucket.resize(task_pool_size_); @@ -635,14 +641,12 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) { tasks[shard_id].wait(); } - rwlock_->UNLock(); return 0; } int32_t CommonSparseTable::flush() { return 0; } int32_t CommonSparseTable::shrink(const std::string& param) { - rwlock_->WRLock(); int threshold = std::stoi(param); VLOG(3) << "sparse table shrink: " << threshold; @@ -651,7 +655,6 @@ int32_t CommonSparseTable::shrink(const std::string& param) { VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink"; shard_values_[shard_id]->Shrink(threshold); } - rwlock_->UNLock(); return 0; } diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h index 81a1ff5eced..55fc92c9b57 100644 --- a/paddle/fluid/distributed/table/table.h +++ b/paddle/fluid/distributed/table/table.h @@ -36,7 +36,7 @@ class Table { Table() {} virtual ~Table() {} virtual int32_t initialize(const TableParameter &config, - const FsClientParameter &fs_config) final; + const FsClientParameter &fs_config); virtual int32_t pull_dense(float *values, size_t num) = 0; virtual int32_t push_dense(const float *values, size_t num) = 0; @@ -58,7 +58,9 @@ class Table { virtual int32_t push_sparse(const uint64_t *keys, const float *values, size_t num) = 0; virtual int32_t push_sparse(const uint64_t *keys, const float **values, - size_t num){}; + size_t num) { + return 0; + } virtual int32_t push_sparse_param(const uint64_t *keys, const float *values, size_t num) { return 0; @@ -108,7 +110,7 @@ class Table { virtual int32_t save(const std::string &path, const std::string &converter) = 0; - virtual int32_t set_shard(size_t shard_idx, size_t shard_num) final { + virtual int32_t set_shard(size_t shard_idx, size_t shard_num) { _shard_idx = shard_idx; _shard_num = shard_num; return initialize_shard(); @@ -123,7 +125,7 @@ class Table { protected: virtual int32_t initialize() = 0; - virtual int32_t initialize_accessor() final; + virtual int32_t initialize_accessor(); virtual int32_t initialize_shard() = 0; virtual std::string table_dir(const std::string &model_dir) { return paddle::string::format_string("%s/%03d/", model_dir.c_str(), diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index d31fa549ad5..f18b82eaecd 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -847,8 +847,6 @@ class TheOnePSRuntime(RuntimeBase): dirname = os.path.normpath(dirname) pserver_id = self.role_maker._role_id() - import time - begin = time.time() for var_name in load_varnames: table_id = sparse_table_maps[var_name] path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX, @@ -856,9 +854,6 @@ class TheOnePSRuntime(RuntimeBase): meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX, "{}.block{}.meta".format(var_name, pserver_id)) self._server.load_sparse(path, meta, table_id) - end = time.time() - print("init sparse variables: {} cost time: {}".format(load_varnames, - end - begin)) def _run_server(self): if self.role_maker._is_heter_worker(): -- GitLab