提交 0075a7b0 编写于 作者: S seiriosPlus

add save delta for large scale kv

上级 20d435b1
...@@ -246,6 +246,7 @@ struct VALUE { ...@@ -246,6 +246,7 @@ struct VALUE {
std::vector<std::string> names_; std::vector<std::string> names_;
int count_; int count_;
bool seen_after_save_;
int unseen_days_; int unseen_days_;
bool is_entry_; bool is_entry_;
std::vector<std::vector<float>> values_; std::vector<std::vector<float>> values_;
...@@ -322,6 +323,7 @@ class ValueBlock { ...@@ -322,6 +323,7 @@ class ValueBlock {
auto value = new VALUE(value_names_); auto value = new VALUE(value_names_);
value->set(values); value->set(values);
value->seen_after_save_ = true;
value->count_ = count; value->count_ = count;
values_[id] = value; values_[id] = value;
} }
...@@ -590,9 +592,9 @@ class SparseVariable { ...@@ -590,9 +592,9 @@ class SparseVariable {
} }
} }
void Save(const std::string &dirname) { void Save(const std::string &dirname, const int mode = 0) {
rwlock_->WRLock(); rwlock_->WRLock();
VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " begin"; VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " begin";
MkDirRecursively(dirname.c_str()); MkDirRecursively(dirname.c_str());
...@@ -601,22 +603,15 @@ class SparseVariable { ...@@ -601,22 +603,15 @@ class SparseVariable {
auto filename = string::Sprintf("%s/%s", dirname, value_name); auto filename = string::Sprintf("%s/%s", dirname, value_name);
filenames.push_back(filename); filenames.push_back(filename);
} }
SaveToSelectedRows(filenames, meta_.value_names);
// // save sparse to text SaveToSelectedRows(filenames, meta_.value_names, mode);
// std::vector<std::string> txt_filenames; VLOG(3) << "save " << meta_.name << " in dir: " << dirname << " done";
// for (auto &value_name : meta_.value_names) {
// auto filename = string::Sprintf("%s/%s.txt", dirname, value_name);
// txt_filenames.push_back(filename);
// }
// SaveToText(txt_filenames, meta_.value_names);
VLOG(1) << "save " << meta_.name << " in dir: " << dirname << " done";
rwlock_->UNLock(); rwlock_->UNLock();
} }
void SaveToSelectedRows(const std::vector<std::string> &filenames, void SaveToSelectedRows(const std::vector<std::string> &filenames,
const std::vector<std::string> &valuenames) { const std::vector<std::string> &valuenames,
const int mode) {
for (auto &value_name : valuenames) { for (auto &value_name : valuenames) {
auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(), auto it = std::find(meta_.value_names.begin(), meta_.value_names.end(),
value_name); value_name);
...@@ -630,14 +625,30 @@ class SparseVariable { ...@@ -630,14 +625,30 @@ class SparseVariable {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place); auto &dev_ctx = *pool.Get(place);
int64_t ids_num = 0; std::vector<int64_t> ids;
for (auto &block : shard_blocks_) { for (auto &block : shard_blocks_) {
ids_num += block->values_.size(); for (auto value : block->values_) {
bool id_need_save = false;
// save all params
if (mode == 0) {
id_need_save = true;
} else {
id_need_save = value.second.seen_after_save_;
}
if (id_need_save) {
ids.push_back(value.first);
}
value.second.seen_after_save_ = false;
} }
}
VLOG(3) << "save " << ids.size() << " feasigns for " << meta_.name
<< " with mode: " << mode;
std::vector<std::shared_ptr<framework::Variable>> variables; std::vector<std::shared_ptr<framework::Variable>> variables;
std::vector<float *> tensors; std::vector<float *> tensors;
std::vector<int64_t> ids;
std::vector<int64_t> dims; std::vector<int64_t> dims;
for (int i = 0; i < static_cast<int>(filenames.size()); i++) { for (int i = 0; i < static_cast<int>(filenames.size()); i++) {
...@@ -646,7 +657,7 @@ class SparseVariable { ...@@ -646,7 +657,7 @@ class SparseVariable {
auto *slr = var->GetMutable<framework::SelectedRows>(); auto *slr = var->GetMutable<framework::SelectedRows>();
auto *src_t = slr->mutable_value(); auto *src_t = slr->mutable_value();
src_t->Resize({ids_num, dim}); src_t->Resize({ids.size(), dim});
auto *value = src_t->mutable_data<float>(place); auto *value = src_t->mutable_data<float>(place);
dims.push_back(dim); dims.push_back(dim);
...@@ -654,21 +665,19 @@ class SparseVariable { ...@@ -654,21 +665,19 @@ class SparseVariable {
tensors.push_back(value); tensors.push_back(value);
} }
int64_t offset = 0; std::vector<std::vector<std::vector<float> *>> *values;
for (auto &block : shard_blocks_) { Get(ids, variables, values);
for (auto value : block->values_) {
ids.push_back(value.first);
std::vector<std::vector<float> *> vss = value.second->get(valuenames);
int64_t offset = 0;
for (auto *value : values) {
auto vss = value;
for (int i = 0; i < static_cast<int>(vss.size()); i++) { for (int i = 0; i < static_cast<int>(vss.size()); i++) {
auto &vs = vss[i]; auto &vs = vss[i];
std::memcpy(tensors[i] + offset * dims[i], vs->data(), std::memcpy(tensors[i] + offset * dims[i], vs->data(),
sizeof(float) * dims[i]); sizeof(float) * dims[i]);
} }
offset += 1; offset += 1;
} }
}
for (auto &var : variables) { for (auto &var : variables) {
auto *slr = var->GetMutable<framework::SelectedRows>(); auto *slr = var->GetMutable<framework::SelectedRows>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册