Unverified · Commit ac3603bf authored by tangwei12, committed by GitHub

add async save for sparse table (#33072)

* add async save for sparse table
* add load profiler for sparse table
* add load info for sparse table
Parent: dc72ffa5
@@ -14,6 +14,8 @@
 #pragma once
 
+#include <sys/time.h>
+
 #include <functional>
 #include <memory>
 #include <string>
@@ -83,5 +85,11 @@ std::string to_string(const std::vector<T>& vec) {
   }
   return ss.str();
 }
+
+inline double GetCurrentUS() {
+  struct timeval time;
+  gettimeofday(&time, NULL);
+  return 1e+6 * time.tv_sec + time.tv_usec;
+}
 }  // namespace distributed
 }  // namespace paddle
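The hunk above adds GetCurrentUS(), which returns the wall-clock time in microseconds; the load/save profiling later in this commit subtracts two such timestamps and divides by 1e+6 to report seconds. A minimal, self-contained usage sketch (the timed work is a placeholder, not part of the patch):

#include <sys/time.h>

#include <iostream>

// Same helper as in the patch: wall-clock time in microseconds.
inline double GetCurrentUS() {
  struct timeval time;
  gettimeofday(&time, NULL);
  return 1e+6 * time.tv_sec + time.tv_usec;
}

int main() {
  auto begin = GetCurrentUS();
  // ... work to profile, e.g. loading or saving a sparse table shard ...
  auto end = GetCurrentUS();
  std::cout << "elapsed: " << (end - begin) / 1e+6 << " seconds" << std::endl;
  return 0;
}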
@@ -134,10 +134,23 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
   }
 }
 
-int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
-                   const int mode) {
-  int64_t save_num = 0;
+void SaveMetaToText(std::ostream* os, const CommonAccessorParameter& common,
+                    const size_t shard_idx, const int64_t total) {
+  // save meta
+  std::stringstream stream;
+  stream << "param=" << common.table_name() << "\n";
+  stream << "shard_id=" << shard_idx << "\n";
+  stream << "row_names=" << paddle::string::join_strings(common.params(), ',')
+         << "\n";
+  stream << "row_dims=" << paddle::string::join_strings(common.dims(), ',')
+         << "\n";
+  stream << "count=" << total << "\n";
+  os->write(stream.str().c_str(), sizeof(char) * stream.str().size());
+}
+
+int64_t SaveValueToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
+                        std::shared_ptr<::ThreadPool> pool, const int mode) {
+  int64_t save_num = 0;
   for (auto& table : block->values_) {
     for (auto& value : table) {
       if (mode == SaveMode::delta && !value.second->need_save_) {
@@ -334,16 +347,24 @@ int32_t CommonSparseTable::set_global_lr(float* lr) {
 int32_t CommonSparseTable::load(const std::string& path,
                                 const std::string& param) {
+  auto begin = GetCurrentUS();
   rwlock_->WRLock();
+  VLOG(3) << "sparse table load with " << path << " with meta " << param;
   LoadFromText(path, param, _shard_idx, _shard_num, task_pool_size_,
                &shard_values_);
   rwlock_->UNLock();
+  auto end = GetCurrentUS();
+
+  auto varname = _config.common().table_name();
+  VLOG(0) << "load " << varname << " with value: " << path
+          << " , meta: " << param
+          << " using: " << std::to_string((end - begin) / 1e+6) << " seconds";
   return 0;
 }
 
 int32_t CommonSparseTable::save(const std::string& dirname,
                                 const std::string& param) {
+  auto begin = GetCurrentUS();
   rwlock_->WRLock();
   int mode = std::stoi(param);
   VLOG(3) << "sparse table save: " << dirname << " mode: " << mode;
@@ -356,36 +377,33 @@ int32_t CommonSparseTable::save(const std::string& dirname,
   VLOG(3) << "save " << varname << " in dir: " << var_store << " begin";
   std::vector<std::string> params(_config.common().params().begin(),
                                   _config.common().params().end());
 
   std::string shard_var_pre =
       string::Sprintf("%s.block%d", varname, _shard_idx);
 
   std::string value_ = string::Sprintf("%s/%s.txt", var_store, shard_var_pre);
-  std::unique_ptr<std::ofstream> value_out(new std::ofstream(value_));
+  std::unique_ptr<std::ofstream> vs(new std::ofstream(value_));
 
   int64_t total_ins = 0;
   for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
     // save values
-    total_ins += SaveToText(value_out.get(), shard_values_[shard_id], mode);
+    auto shard_save_num = SaveValueToText(vs.get(), shard_values_[shard_id],
+                                          _shards_task_pool[shard_id], mode);
+    total_ins += shard_save_num;
   }
-  value_out->close();
-
-  // save meta
-  std::stringstream stream;
-  stream << "param=" << _config.common().table_name() << "\n";
-  stream << "shard_id=" << _shard_idx << "\n";
-  stream << "row_names="
-         << paddle::string::join_strings(_config.common().params(), ',')
-         << "\n";
-  stream << "row_dims="
-         << paddle::string::join_strings(_config.common().dims(), ',') << "\n";
-  stream << "count=" << total_ins << "\n";
+  vs->close();
 
   std::string meta_ = string::Sprintf("%s/%s.meta", var_store, shard_var_pre);
-  std::unique_ptr<std::ofstream> meta_out(new std::ofstream(meta_));
-  meta_out->write(stream.str().c_str(), sizeof(char) * stream.str().size());
-  meta_out->close();
-  VLOG(3) << "save " << varname << " in dir: " << var_store << " done";
+  std::unique_ptr<std::ofstream> ms(new std::ofstream(meta_));
+  SaveMetaToText(ms.get(), _config.common(), _shard_idx, total_ins);
+  ms->close();
 
+  auto end = GetCurrentUS();
   rwlock_->UNLock();
+  VLOG(0) << "save " << varname << " with path: " << value_
+          << " using: " << std::to_string((end - begin) / 1e+6) << " seconds";
   return 0;
 }
@@ -403,8 +421,6 @@ std::pair<int64_t, int64_t> CommonSparseTable::print_table_stat() {
 }
 
 int32_t CommonSparseTable::pour() {
-  rwlock_->RDLock();
-
   std::vector<float> values;
   std::vector<uint64_t> keys;
@@ -421,14 +437,11 @@ int32_t CommonSparseTable::pour() {
   _push_sparse(keys.data(), values.data(), pull_reservoir_.size());
 
   pull_reservoir_.clear();
-  rwlock_->UNLock();
   return 0;
 }
 
 int32_t CommonSparseTable::pull_sparse(float* pull_values,
                                        const PullSparseValue& pull_value) {
-  rwlock_->RDLock();
-
   auto shard_num = task_pool_size_;
   std::vector<std::future<int>> tasks(shard_num);
@@ -464,7 +477,6 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values,
   for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
     tasks[shard_id].wait();
   }
-  rwlock_->UNLock();
   return 0;
 }
@@ -507,7 +519,6 @@ int32_t CommonSparseTable::pull_sparse_ptr(char** pull_values,
 int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
                                         const float* values, size_t num) {
-  rwlock_->RDLock();
   std::vector<std::vector<uint64_t>> offset_bucket;
   offset_bucket.resize(task_pool_size_);
@@ -531,7 +542,6 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
   for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
     tasks[shard_id].wait();
   }
-  rwlock_->UNLock();
   return 0;
 }
@@ -569,7 +579,6 @@ int32_t CommonSparseTable::push_sparse(const uint64_t* keys,
 int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
                                         const float** values, size_t num) {
-  rwlock_->RDLock();
   std::vector<std::vector<uint64_t>> offset_bucket;
   offset_bucket.resize(task_pool_size_);
@@ -596,14 +605,11 @@ int32_t CommonSparseTable::_push_sparse(const uint64_t* keys,
   for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
     tasks[shard_id].wait();
   }
-  rwlock_->UNLock();
   return 0;
 }
 
 int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
                                              const float* values, size_t num) {
-  rwlock_->RDLock();
   std::vector<std::vector<uint64_t>> offset_bucket;
   offset_bucket.resize(task_pool_size_);
@@ -635,14 +641,12 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
   for (size_t shard_id = 0; shard_id < tasks.size(); ++shard_id) {
     tasks[shard_id].wait();
   }
-  rwlock_->UNLock();
   return 0;
 }
 
 int32_t CommonSparseTable::flush() { return 0; }
 
 int32_t CommonSparseTable::shrink(const std::string& param) {
-  rwlock_->WRLock();
   int threshold = std::stoi(param);
   VLOG(3) << "sparse table shrink: " << threshold;
@@ -651,7 +655,6 @@ int32_t CommonSparseTable::shrink(const std::string& param) {
     VLOG(4) << shard_id << " " << task_pool_size_ << " begin shrink";
     shard_values_[shard_id]->Shrink(threshold);
   }
-  rwlock_->UNLock();
   return 0;
 }
......
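The hunks above remove the coarse reader/writer locks from the pull/push/shrink paths and route per-shard saving through SaveValueToText, which now receives a per-shard ::ThreadPool (presumably the hook for the asynchronous save named in the commit title). A minimal, self-contained sketch of the per-shard save-and-sum pattern, using std::async in place of the table's thread pool (SaveShard, kShardNum, and the string streams are illustrative placeholders, not the patch's API):

#include <cstdint>
#include <future>
#include <ostream>
#include <sstream>
#include <vector>

// Placeholder for per-shard serialization; returns the number of rows written.
int64_t SaveShard(std::ostream* os, int shard_id) {
  *os << "rows of shard " << shard_id << "\n";
  return 1;
}

int main() {
  const int kShardNum = 4;
  std::vector<std::ostringstream> outs(kShardNum);
  std::vector<std::future<int64_t>> tasks;

  // One save task per shard; the summed row count corresponds to the
  // "count=" field that SaveMetaToText writes into the .meta file.
  for (int i = 0; i < kShardNum; ++i) {
    tasks.push_back(std::async(std::launch::async,
                               [&outs, i] { return SaveShard(&outs[i], i); }));
  }

  int64_t total_ins = 0;
  for (auto& t : tasks) {
    total_ins += t.get();
  }
  return total_ins == kShardNum ? 0 : 1;
}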
@@ -36,7 +36,7 @@ class Table {
   Table() {}
   virtual ~Table() {}
   virtual int32_t initialize(const TableParameter &config,
-                             const FsClientParameter &fs_config) final;
+                             const FsClientParameter &fs_config);
 
   virtual int32_t pull_dense(float *values, size_t num) = 0;
   virtual int32_t push_dense(const float *values, size_t num) = 0;
@@ -58,7 +58,9 @@ class Table {
   virtual int32_t push_sparse(const uint64_t *keys, const float *values,
                               size_t num) = 0;
   virtual int32_t push_sparse(const uint64_t *keys, const float **values,
-                              size_t num){};
+                              size_t num) {
+    return 0;
+  }
   virtual int32_t push_sparse_param(const uint64_t *keys, const float *values,
                                     size_t num) {
     return 0;
@@ -108,7 +110,7 @@ class Table {
   virtual int32_t save(const std::string &path,
                        const std::string &converter) = 0;
 
-  virtual int32_t set_shard(size_t shard_idx, size_t shard_num) final {
+  virtual int32_t set_shard(size_t shard_idx, size_t shard_num) {
     _shard_idx = shard_idx;
     _shard_num = shard_num;
     return initialize_shard();
@@ -123,7 +125,7 @@ class Table {
  protected:
   virtual int32_t initialize() = 0;
-  virtual int32_t initialize_accessor() final;
+  virtual int32_t initialize_accessor();
   virtual int32_t initialize_shard() = 0;
   virtual std::string table_dir(const std::string &model_dir) {
     return paddle::string::format_string("%s/%03d/", model_dir.c_str(),
......
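The hunks above relax the Table base class: final is dropped from initialize(), set_shard(), and initialize_accessor(), and the const float** overload of push_sparse() gains an explicit "return 0;" body instead of an empty brace. With final gone, derived tables may override these hooks; a minimal sketch of what that permits (the trimmed Table and MySparseTable below are illustrative stand-ins, not the real classes):

#include <cstddef>
#include <cstdint>

// Trimmed-down stand-in for the real Table base class.
class Table {
 public:
  virtual ~Table() {}
  // No longer declared final, so subclasses can customize shard setup.
  virtual int32_t set_shard(size_t shard_idx, size_t shard_num) {
    _shard_idx = shard_idx;
    _shard_num = shard_num;
    return 0;
  }

 protected:
  size_t _shard_idx = 0;
  size_t _shard_num = 1;
};

class MySparseTable : public Table {
 public:
  int32_t set_shard(size_t shard_idx, size_t shard_num) override {
    // Custom bookkeeping could go here before delegating to the base class.
    return Table::set_shard(shard_idx, shard_num);
  }
};

int main() {
  MySparseTable table;
  return table.set_shard(/*shard_idx=*/0, /*shard_num=*/2);
}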
@@ -847,8 +847,6 @@ class TheOnePSRuntime(RuntimeBase):
         dirname = os.path.normpath(dirname)
         pserver_id = self.role_maker._role_id()
 
-        import time
-        begin = time.time()
         for var_name in load_varnames:
             table_id = sparse_table_maps[var_name]
             path = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
@@ -856,9 +854,6 @@ class TheOnePSRuntime(RuntimeBase):
             meta = os.path.join(dirname, var_name + PSERVER_SAVE_SUFFIX,
                                 "{}.block{}.meta".format(var_name, pserver_id))
             self._server.load_sparse(path, meta, table_id)
-        end = time.time()
-        print("init sparse variables: {} cost time: {}".format(load_varnames,
-                                                               end - begin))
 
     def _run_server(self):
         if self.role_maker._is_heter_worker():
......