Unverified commit 5e7e7c9f, authored by tangwei12, committed by GitHub

[Optimize]SparseKV speedup and memory save (#32048)


Change-Id: Ie35a09772e46f7d90cb68ca82c1d18b9201d1abe

* large scale kv store optimize

Change-Id: I582cc661afdaa20749ec7493eae1b88c32b967f7

* replace std::unordered_map with robin_hood map

Change-Id: I48ee0efef38853876c92d982cdfcac6603c52c88

* remove license

* fix cpp lint

Change-Id: Ia21fafa65adc09bb9094f7dbc987e31d5af2686e
Parent 186682fe
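The gist of the change, read from the diff below: sparse-table entries move out of a `std::unordered_map<uint64_t, std::shared_ptr<VALUE>>` into a vendored robin_hood-style flat map (`thirdparty/round_robin.h`) that stores `VALUE` by value, removing one heap allocation and one pointer indirection per key. A minimal before/after sketch, with type names copied from the diff and alias names that are purely illustrative:

```cpp
#include <cstdint>
#include <memory>
#include <unordered_map>

struct VALUE;  // sparse-table entry, defined in the diff below

// Before: one node allocation per key (std::unordered_map) plus a
// separate shared_ptr-managed allocation and control block per VALUE;
// every access goes through two pointer hops.
using OldValues = std::unordered_map<uint64_t, std::shared_ptr<VALUE>>;

// After: the vendored round_robin.h map can keep VALUE inline in its
// open-addressing table, so lookups touch contiguous memory and there
// is no per-entry control block. (Commented out here because the
// header is not part of the standard library.)
// using NewValues = robin_hood::unordered_map<uint64_t, VALUE>;
```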
@@ -146,6 +146,44 @@ void FleetWrapper::CreateClient2ClientConnection() {
                                       client2client_max_retry_);
 }
 
+std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
+    const Scope& scope, const uint64_t table_id,
+    const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
+    std::vector<std::vector<float>>* fea_values, int fea_value_dim) {
+  fea_keys->clear();
+  fea_keys->resize(0);
+  fea_keys->reserve(MAX_FEASIGN_NUM);
+  for (auto name : var_names) {
+    Variable* var = scope.FindVar(name);
+    if (var == nullptr) {
+      continue;
+    }
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    CHECK(tensor != nullptr) << "tensor of var " << name << " is null";
+    int64_t* ids = tensor->data<int64_t>();
+    size_t len = tensor->numel();
+    for (auto i = 0u; i < len; ++i) {
+      if (ids[i] == 0u) {
+        continue;
+      }
+      fea_keys->push_back(static_cast<uint64_t>(ids[i]));
+    }
+  }
+
+  fea_values->resize(fea_keys->size() + 1);
+  for (auto& t : *fea_values) {
+    t.resize(fea_value_dim);
+  }
+  std::vector<float*> pull_result_ptr;
+  for (auto& t : *fea_values) {
+    pull_result_ptr.push_back(t.data());
+  }
+
+  bool training = true;
+
+  return pserver_ptr_->_worker_ptr->pull_sparse(pull_result_ptr.data(),
+                                                table_id, fea_keys->data(),
+                                                fea_keys->size(), training);
+}
+
 void FleetWrapper::PullSparseVarsSync(
     const Scope& scope, const uint64_t table_id,
     const std::vector<std::string>& var_names, std::vector<uint64_t>* fea_keys,
...
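A hedged usage sketch for the new call: kick off the pull, overlap other work, then block on the returned future before touching `fea_values`. The `fleet` handle, `DoSomethingElse`, and the surrounding setup are assumptions for illustration; none of them appear in this diff.

```cpp
// Sketch only. Assumes a FleetWrapper instance plus an initialized
// Scope and sparse table; that setup is not shown in this diff.
std::vector<uint64_t> fea_keys;
std::vector<std::vector<float>> fea_values;

std::future<int32_t> status = fleet->PullSparseVarsAsync(
    scope, table_id, var_names, &fea_keys, &fea_values, fea_value_dim);

DoSomethingElse();  // overlap communication with local work

status.wait();
CHECK_EQ(status.get(), 0) << "async sparse pull failed";
// fea_values[i] now holds fea_value_dim floats for fea_keys[i].
```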
@@ -84,6 +84,15 @@ class FleetWrapper {
                               int fea_dim,
                               const std::vector<std::string>& var_emb_names);
 
+  // Pull sparse variables from server in async mode
+  // Param<in>: scope, table_id, var_names, fea_keys, fea_dim
+  // Param<out>: fea_values; the returned std::future carries the status
+  std::future<int32_t> PullSparseVarsAsync(
+      const Scope& scope, const uint64_t table_id,
+      const std::vector<std::string>& var_names,
+      std::vector<uint64_t>* fea_keys,
+      std::vector<std::vector<float>>* fea_values, int fea_dim);
+
   // Pull sparse variables from server in sync mode
   // pull immediately to tensors
   // is_training is true means training, false means inference, the behavior is
...
@@ -126,17 +126,17 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
 int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
                    const int mode) {
   int64_t not_save_num = 0;
-  for (auto value : block->values_) {
-    if (mode == SaveMode::delta && !value.second->need_save_) {
+  for (auto& value : block->values_) {
+    if (mode == SaveMode::delta && !value.second.need_save_) {
       not_save_num++;
       continue;
     }
 
-    auto* vs = value.second->data_.data();
+    auto* vs = value.second.data_;
     std::stringstream ss;
     auto id = value.first;
-    ss << id << "\t" << value.second->count_ << "\t"
-       << value.second->unseen_days_ << "\t" << value.second->is_entry_ << "\t";
+    ss << id << "\t" << value.second.count_ << "\t" << value.second.unseen_days_
+       << "\t" << value.second.is_entry_ << "\t";
 
     for (int i = 0; i < block->value_length_; i++) {
       ss << vs[i];
@@ -148,7 +148,7 @@ int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
     os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
 
     if (mode == SaveMode::base || mode == SaveMode::delta) {
-      value.second->need_save_ = false;
+      value.second.need_save_ = false;
     }
   }
...
@@ -28,6 +28,7 @@
 #include "paddle/fluid/distributed/common/utils.h"
 #include "paddle/fluid/distributed/table/depends/initializers.h"
+#include "paddle/fluid/distributed/thirdparty/round_robin.h"
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/rw_lock.h"
@@ -54,23 +55,53 @@ struct VALUE {
         unseen_days_(0),
         need_save_(false),
         is_entry_(false) {
-    data_.resize(length);
-    memset(data_.data(), 0, sizeof(float) * length);
+    data_ = new float[length];
+    memset(data_, 0, sizeof(float) * length);
+  }
+
+  VALUE(const VALUE &value) {
+    length_ = value.length_;
+    count_ = value.count_;
+    unseen_days_ = value.unseen_days_;
+    need_save_ = value.need_save_;
+    is_entry_ = value.is_entry_;
+
+    data_ = new float[length_];
+    memcpy(data_, value.data_, sizeof(float) * length_);
+  }
+
+  VALUE &operator=(const VALUE &value) {
+    if (this != &value) {
+      delete[] data_;
+      length_ = value.length_;
+      count_ = value.count_;
+      unseen_days_ = value.unseen_days_;
+      need_save_ = value.need_save_;
+      is_entry_ = value.is_entry_;
+      data_ = new float[length_];
+      memcpy(data_, value.data_, sizeof(float) * length_);
+    }
+    return *this;
+  }
+
+  ~VALUE() {
+    delete[] data_;
+    data_ = nullptr;
   }
 
   size_t length_;
-  std::vector<float> data_;
   int count_;
   int unseen_days_;  // use to check knock-out
   bool need_save_;   // whether need to save
   bool is_entry_;    // whether knock-in
+  float *data_;
 };
 
-inline bool count_entry(std::shared_ptr<VALUE> value, int threshold) {
+inline bool count_entry(VALUE *value, int threshold) {
   return value->count_ >= threshold;
 }
 
-inline bool probility_entry(std::shared_ptr<VALUE> value, float threshold) {
+inline bool probility_entry(VALUE *value, float threshold) {
   UniformInitializer uniform = UniformInitializer({"uniform", "0", "0", "1"});
   return uniform.GetValue() >= threshold;
 }
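Since `VALUE` now owns a raw `float*` instead of a `std::vector<float>`, the patch has to supply the full rule of three: with the compiler-generated copy operations, two `VALUE`s would alias one buffer and the second destructor would free it twice, and the by-value `robin_hood` map introduced below can copy entries when it rehashes. A stripped-down, self-contained illustration of the pattern, with hypothetical names:

```cpp
#include <cstddef>
#include <cstring>

// Minimal stand-in for the patched VALUE struct.
struct Owned {
  explicit Owned(std::size_t n) : n_(n), data_(new float[n]()) {}

  // Deep copy: each object gets its own buffer.
  Owned(const Owned &o) : n_(o.n_), data_(new float[o.n_]) {
    std::memcpy(data_, o.data_, sizeof(float) * n_);
  }

  Owned &operator=(const Owned &o) {
    if (this != &o) {  // guard against self-assignment before delete[]
      delete[] data_;
      n_ = o.n_;
      data_ = new float[n_];
      std::memcpy(data_, o.data_, sizeof(float) * n_);
    }
    return *this;
  }

  ~Owned() { delete[] data_; }

  std::size_t n_;
  float *data_;
};

int main() {
  Owned a(4);
  Owned b = a;        // deep copy: b.data_ != a.data_
  b.data_[0] = 1.0f;  // leaves a untouched; the default (shallow) copy
                      // would alias one buffer and double-free on exit
  return 0;
}
```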
@@ -150,7 +181,7 @@ class ValueBlock {
       PADDLE_ENFORCE_EQ(
           value_dims[i], value_dims_[i],
           platform::errors::InvalidArgument("value dims is not match"));
-      pts.push_back(values->data_.data() +
+      pts.push_back(values.data_ +
                     value_offsets_.at(value_idx_.at(value_names[i])));
     }
     return pts;
@@ -160,34 +191,35 @@ class ValueBlock {
   float *Init(const uint64_t &id, const bool with_update = true,
               const int counter = 1) {
     if (!Has(id)) {
-      values_[id] = std::make_shared<VALUE>(value_length_);
+      values_.emplace(std::make_pair(id, VALUE(value_length_)));
     }
 
     auto &value = values_.at(id);
 
     if (with_update) {
-      AttrUpdate(value, counter);
+      AttrUpdate(&value, counter);
     }
 
-    return value->data_.data();
+    return value.data_;
   }
 
   VALUE *InitGet(const uint64_t &id, const bool with_update = true,
                  const int counter = 1) {
     if (!Has(id)) {
-      values_[id] = std::make_shared<VALUE>(value_length_);
+      values_.emplace(std::make_pair(id, VALUE(value_length_)));
     }
 
     auto &value = values_.at(id);
 
     if (with_update) {
-      AttrUpdate(value, counter);
+      AttrUpdate(&value, counter);
     }
 
-    return value.get();
+    return &value;
   }
 
-  void AttrUpdate(std::shared_ptr<VALUE> value, const int counter) {
+  void AttrUpdate(VALUE *value, const int counter) {
     // update state
     value->unseen_days_ = 0;
     value->count_ += counter;
@@ -197,7 +229,7 @@ class ValueBlock {
     if (value->is_entry_) {
       // initialize
       for (size_t x = 0; x < value_names_.size(); ++x) {
-        initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
+        initializers_[x]->GetValue(value->data_ + value_offsets_[x],
                                    value_dims_[x]);
       }
       value->need_save_ = true;
@@ -212,27 +244,27 @@ class ValueBlock {
   // dont jude if (has(id))
   float *Get(const uint64_t &id) {
     auto &value = values_.at(id);
-    return value->data_.data();
+    return value.data_;
   }
 
   // for load, to reset count, unseen_days
-  std::shared_ptr<VALUE> GetValue(const uint64_t &id) { return values_.at(id); }
+  VALUE *GetValue(const uint64_t &id) { return &values_.at(id); }
 
   bool GetEntry(const uint64_t &id) {
     auto &value = values_.at(id);
-    return value->is_entry_;
+    return value.is_entry_;
   }
 
   void SetEntry(const uint64_t &id, const bool state) {
     auto &value = values_.at(id);
-    value->is_entry_ = state;
+    value.is_entry_ = state;
   }
 
   void Shrink(const int threshold) {
     for (auto iter = values_.begin(); iter != values_.end();) {
       auto &value = iter->second;
-      value->unseen_days_++;
-      if (value->unseen_days_ >= threshold) {
+      value.unseen_days_++;
+      if (value.unseen_days_ >= threshold) {
         iter = values_.erase(iter);
       } else {
         ++iter;
@@ -254,7 +286,7 @@ class ValueBlock {
   }
 
  public:
-  std::unordered_map<uint64_t, std::shared_ptr<VALUE>> values_;
+  robin_hood::unordered_map<uint64_t, VALUE> values_;
   size_t value_length_ = 0;
 
  private:
@@ -263,10 +295,11 @@ class ValueBlock {
   const std::vector<int> &value_offsets_;
   const std::unordered_map<std::string, int> &value_idx_;
 
-  std::function<bool(std::shared_ptr<VALUE>)> entry_func_;
+  std::function<bool(VALUE *)> entry_func_;
   std::vector<std::shared_ptr<Initializer>> initializers_;
   float threshold_;
 };
 
 }  // namespace distributed
 }  // namespace paddle
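`Shrink` above relies on `erase(iterator)` returning the next valid iterator, so the loop survives deletions. Below is a self-contained sketch of the same pattern, written against `std::unordered_map` on the assumption that the vendored robin_hood map mirrors this part of the standard API; the struct and function names are illustrative:

```cpp
#include <cstdint>
#include <unordered_map>

struct Entry {
  int unseen_days = 0;
};

// Age every entry; evict those unseen for `threshold` passes or more.
void ShrinkLike(std::unordered_map<uint64_t, Entry> *values, int threshold) {
  for (auto iter = values->begin(); iter != values->end();) {
    Entry &value = iter->second;
    value.unseen_days++;
    if (value.unseen_days >= threshold) {
      iter = values->erase(iter);  // erase returns the next valid iterator
    } else {
      ++iter;
    }
  }
}
```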
This diff is collapsed.