未验证 提交 36710ebc 编写于 作者: T tangwei12 提交者: GitHub

test=develop, save/load, shrink (#30625) (#31107)

* test=develop, save/load, shrink
Co-authored-by: NseiriosPlus <tangwei12@baidu.com>
Co-authored-by: N123malin <malin10@baidu.com>
上级 29543da5
...@@ -479,9 +479,15 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) { ...@@ -479,9 +479,15 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) {
} }
} }
void FleetWrapper::ShrinkSparseTable(int table_id) { void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
auto ret = pserver_ptr_->_worker_ptr->shrink(table_id); auto* communicator = Communicator::GetInstance();
auto ret =
communicator->_worker_ptr->shrink(table_id, std::to_string(threshold));
ret.wait(); ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
LOG(ERROR) << "shrink sparse table stat failed";
}
} }
void FleetWrapper::ClearModel() { void FleetWrapper::ClearModel() {
......
...@@ -207,7 +207,7 @@ class FleetWrapper { ...@@ -207,7 +207,7 @@ class FleetWrapper {
// clear one table // clear one table
void ClearOneTable(const uint64_t table_id); void ClearOneTable(const uint64_t table_id);
// shrink sparse table // shrink sparse table
void ShrinkSparseTable(int table_id); void ShrinkSparseTable(int table_id, int threshold);
// shrink dense table // shrink dense table
void ShrinkDenseTable(int table_id, Scope* scope, void ShrinkDenseTable(int table_id, Scope* scope,
std::vector<std::string> var_list, float decay, std::vector<std::string> var_list, float decay,
......
...@@ -353,8 +353,9 @@ std::future<int32_t> BrpcPsClient::send_save_cmd( ...@@ -353,8 +353,9 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
return fut; return fut;
} }
std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id) { std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id,
return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")}); const std::string threshold) {
return send_cmd(table_id, PS_SHRINK_TABLE, {threshold});
} }
std::future<int32_t> BrpcPsClient::load(const std::string &epoch, std::future<int32_t> BrpcPsClient::load(const std::string &epoch,
......
...@@ -102,7 +102,8 @@ class BrpcPsClient : public PSClient { ...@@ -102,7 +102,8 @@ class BrpcPsClient : public PSClient {
} }
virtual int32_t create_client2client_connection( virtual int32_t create_client2client_connection(
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry); int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
virtual std::future<int32_t> shrink(uint32_t table_id) override; virtual std::future<int32_t> shrink(uint32_t table_id,
const std::string threshold) override;
virtual std::future<int32_t> load(const std::string &epoch, virtual std::future<int32_t> load(const std::string &epoch,
const std::string &mode) override; const std::string &mode) override;
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch, virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
......
...@@ -460,6 +460,8 @@ int32_t BrpcPsService::save_one_table(Table *table, ...@@ -460,6 +460,8 @@ int32_t BrpcPsService::save_one_table(Table *table,
table->flush(); table->flush();
int32_t feasign_size = 0; int32_t feasign_size = 0;
VLOG(0) << "save one table " << request.params(0) << " " << request.params(1);
feasign_size = table->save(request.params(0), request.params(1)); feasign_size = table->save(request.params(0), request.params(1));
if (feasign_size < 0) { if (feasign_size < 0) {
set_response_code(response, -1, "table save failed"); set_response_code(response, -1, "table save failed");
...@@ -491,10 +493,18 @@ int32_t BrpcPsService::shrink_table(Table *table, ...@@ -491,10 +493,18 @@ int32_t BrpcPsService::shrink_table(Table *table,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
set_response_code(
response, -1,
"PsRequestMessage.datas is requeired at least 1, threshold");
return -1;
}
table->flush(); table->flush();
if (table->shrink() != 0) { if (table->shrink(request.params(0)) != 0) {
set_response_code(response, -1, "table shrink failed"); set_response_code(response, -1, "table shrink failed");
return -1;
} }
VLOG(0) << "Pserver Shrink Finished";
return 0; return 0;
} }
......
...@@ -69,7 +69,8 @@ class PSClient { ...@@ -69,7 +69,8 @@ class PSClient {
int max_retry) = 0; int max_retry) = 0;
// 触发table数据退场 // 触发table数据退场
virtual std::future<int32_t> shrink(uint32_t table_id) = 0; virtual std::future<int32_t> shrink(uint32_t table_id,
const std::string threshold) = 0;
// 全量table进行数据load // 全量table进行数据load
virtual std::future<int32_t> load(const std::string &epoch, virtual std::future<int32_t> load(const std::string &epoch,
......
...@@ -58,7 +58,7 @@ class CommonDenseTable : public DenseTable { ...@@ -58,7 +58,7 @@ class CommonDenseTable : public DenseTable {
} }
virtual int32_t flush() override { return 0; } virtual int32_t flush() override { return 0; }
virtual int32_t shrink() override { return 0; } virtual int32_t shrink(const std::string& param) override { return 0; }
virtual void clear() override { return; } virtual void clear() override { return; }
protected: protected:
......
...@@ -22,9 +22,12 @@ ...@@ -22,9 +22,12 @@
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#define PSERVER_SAVE_SUFFIX "_txt" #define PSERVER_SAVE_SUFFIX "_txt"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
enum SaveMode { all, base, delta };
struct Meta { struct Meta {
std::string param; std::string param;
int shard_id; int shard_id;
...@@ -94,12 +97,9 @@ struct Meta { ...@@ -94,12 +97,9 @@ struct Meta {
void ProcessALine(const std::vector<std::string>& columns, const Meta& meta, void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
std::vector<std::vector<float>>* values) { std::vector<std::vector<float>>* values) {
PADDLE_ENFORCE_EQ(columns.size(), 2, auto colunmn_size = columns.size();
paddle::platform::errors::InvalidArgument( auto load_values =
"The data format does not meet the requirements. It " paddle::string::split_string<std::string>(columns[colunmn_size - 1], ",");
"should look like feasign_id \t params."));
auto load_values = paddle::string::split_string<std::string>(columns[1], ",");
values->reserve(meta.names.size()); values->reserve(meta.names.size());
int offset = 0; int offset = 0;
...@@ -121,11 +121,18 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta, ...@@ -121,11 +121,18 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block, int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
const int mode) { const int mode) {
int64_t not_save_num = 0;
for (auto value : block->values_) { for (auto value : block->values_) {
if (mode == SaveMode::delta && !value.second->need_save_) {
not_save_num++;
continue;
}
auto* vs = value.second->data_.data(); auto* vs = value.second->data_.data();
std::stringstream ss; std::stringstream ss;
auto id = value.first; auto id = value.first;
ss << id << "\t"; ss << id << "\t" << value.second->count_ << "\t"
<< value.second->unseen_days_ << "\t" << value.second->is_entry_ << "\t";
for (int i = 0; i < block->value_length_; i++) { for (int i = 0; i < block->value_length_; i++) {
ss << vs[i]; ss << vs[i];
...@@ -135,9 +142,13 @@ int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block, ...@@ -135,9 +142,13 @@ int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
ss << "\n"; ss << "\n";
os->write(ss.str().c_str(), sizeof(char) * ss.str().size()); os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
if (mode == SaveMode::base || mode == SaveMode::delta) {
value.second->need_save_ = false;
}
} }
return block->values_.size(); return block->values_.size() - not_save_num;
} }
int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
...@@ -165,8 +176,21 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath, ...@@ -165,8 +176,21 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
std::vector<std::vector<float>> kvalues; std::vector<std::vector<float>> kvalues;
ProcessALine(values, meta, &kvalues); ProcessALine(values, meta, &kvalues);
// warning: need fix
block->Init(id); block->Init(id, false);
auto value_instant = block->GetValue(id);
if (values.size() == 5) {
value_instant->count_ = std::stoi(values[1]);
value_instant->unseen_days_ = std::stoi(values[2]);
value_instant->is_entry_ = static_cast<bool>(std::stoi(values[3]));
}
std::vector<float*> block_values = block->Get(id, meta.names, meta.dims);
auto blas = GetBlas<float>();
for (int x = 0; x < meta.names.size(); ++x) {
blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]);
}
} }
return 0; return 0;
...@@ -393,7 +417,7 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys, ...@@ -393,7 +417,7 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys,
for (int i = 0; i < offsets.size(); ++i) { for (int i = 0; i < offsets.size(); ++i) {
auto offset = offsets[i]; auto offset = offsets[i];
auto id = keys[offset]; auto id = keys[offset];
auto* value = block->InitFromInitializer(id); auto* value = block->Init(id);
std::copy_n(value + param_offset_, param_dim_, std::copy_n(value + param_offset_, param_dim_,
pull_values + param_dim_ * offset); pull_values + param_dim_ * offset);
} }
...@@ -488,9 +512,10 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, ...@@ -488,9 +512,10 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
for (int i = 0; i < offsets.size(); ++i) { for (int i = 0; i < offsets.size(); ++i) {
auto offset = offsets[i]; auto offset = offsets[i];
auto id = keys[offset]; auto id = keys[offset];
auto* value = block->InitFromInitializer(id); auto* value = block->Init(id, false);
std::copy_n(values + param_dim_ * offset, param_dim_, std::copy_n(values + param_dim_ * offset, param_dim_,
value + param_offset_); value + param_offset_);
block->SetEntry(id, true);
} }
return 0; return 0;
}); });
...@@ -505,10 +530,20 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys, ...@@ -505,10 +530,20 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
int32_t CommonSparseTable::flush() { return 0; } int32_t CommonSparseTable::flush() { return 0; }
int32_t CommonSparseTable::shrink() { int32_t CommonSparseTable::shrink(const std::string& param) {
VLOG(0) << "shrink coming soon"; rwlock_->WRLock();
int threshold = std::stoi(param);
VLOG(0) << "sparse table shrink: " << threshold;
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
// shrink
VLOG(0) << shard_id << " " << task_pool_size_ << " begin shrink";
shard_values_[shard_id]->Shrink(threshold);
}
rwlock_->UNLock();
return 0; return 0;
} }
void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; } void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
} // namespace distributed } // namespace distributed
......
...@@ -73,7 +73,7 @@ class CommonSparseTable : public SparseTable { ...@@ -73,7 +73,7 @@ class CommonSparseTable : public SparseTable {
virtual int32_t pour(); virtual int32_t pour();
virtual int32_t flush(); virtual int32_t flush();
virtual int32_t shrink(); virtual int32_t shrink(const std::string& param);
virtual void clear(); virtual void clear();
protected: protected:
......
...@@ -108,7 +108,7 @@ class DenseTable : public Table { ...@@ -108,7 +108,7 @@ class DenseTable : public Table {
int32_t push_dense_param(const float *values, size_t num) override { int32_t push_dense_param(const float *values, size_t num) override {
return 0; return 0;
} }
int32_t shrink() override { return 0; } int32_t shrink(const std::string &param) override { return 0; }
}; };
class BarrierTable : public Table { class BarrierTable : public Table {
...@@ -133,7 +133,7 @@ class BarrierTable : public Table { ...@@ -133,7 +133,7 @@ class BarrierTable : public Table {
int32_t push_dense_param(const float *values, size_t num) override { int32_t push_dense_param(const float *values, size_t num) override {
return 0; return 0;
} }
int32_t shrink() override { return 0; } int32_t shrink(const std::string &param) override { return 0; }
virtual void clear(){}; virtual void clear(){};
virtual int32_t flush() { return 0; }; virtual int32_t flush() { return 0; };
virtual int32_t load(const std::string &path, const std::string &param) { virtual int32_t load(const std::string &path, const std::string &param) {
......
...@@ -47,43 +47,34 @@ namespace distributed { ...@@ -47,43 +47,34 @@ namespace distributed {
enum Mode { training, infer }; enum Mode { training, infer };
template <typename T>
inline bool entry(const int count, const T threshold);
template <>
inline bool entry<std::string>(const int count, const std::string threshold) {
return true;
}
template <>
inline bool entry<int>(const int count, const int threshold) {
return count >= threshold;
}
template <>
inline bool entry<float>(const int count, const float threshold) {
UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
return uniform.GetValue() >= threshold;
}
struct VALUE { struct VALUE {
explicit VALUE(size_t length) explicit VALUE(size_t length)
: length_(length), : length_(length),
count_(1), count_(0),
unseen_days_(0), unseen_days_(0),
seen_after_last_save_(true), need_save_(false),
is_entry_(true) { is_entry_(false) {
data_.resize(length); data_.resize(length);
memset(data_.data(), 0, sizeof(float) * length);
} }
size_t length_; size_t length_;
std::vector<float> data_; std::vector<float> data_;
int count_; int count_;
int unseen_days_; int unseen_days_; // use to check knock-out
bool seen_after_last_save_; bool need_save_; // whether need to save
bool is_entry_; bool is_entry_; // whether knock-in
}; };
inline bool count_entry(std::shared_ptr<VALUE> value, int threshold) {
return value->count_ >= threshold;
}
inline bool probility_entry(std::shared_ptr<VALUE> value, float threshold) {
UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
return uniform.GetValue() >= threshold;
}
class ValueBlock { class ValueBlock {
public: public:
explicit ValueBlock(const std::vector<std::string> &value_names, explicit ValueBlock(const std::vector<std::string> &value_names,
...@@ -102,21 +93,21 @@ class ValueBlock { ...@@ -102,21 +93,21 @@ class ValueBlock {
// for Entry // for Entry
{ {
if (entry_attr == "none") { auto slices = string::split_string<std::string>(entry_attr, "&");
has_entry_ = false; if (slices[0] == "none") {
entry_func_ = std::bind(&count_entry, std::placeholders::_1, 0);
} else if (slices[0] == "count_filter") {
int threshold = std::stoi(slices[1]);
entry_func_ = std::bind(&count_entry, std::placeholders::_1, threshold);
} else if (slices[0] == "probability") {
float threshold = std::stof(slices[1]);
entry_func_ = entry_func_ =
std::bind(entry<std::string>, std::placeholders::_1, "none"); std::bind(&probility_entry, std::placeholders::_1, threshold);
} else { } else {
has_entry_ = true; PADDLE_THROW(platform::errors::InvalidArgument(
auto slices = string::split_string<std::string>(entry_attr, "&"); "Not supported Entry Type : %s, Only support [count_filter, "
if (slices[0] == "count_filter") { "probability]",
int threshold = std::stoi(slices[1]); slices[0]));
entry_func_ = std::bind(entry<int>, std::placeholders::_1, threshold);
} else if (slices[0] == "probability") {
float threshold = std::stof(slices[1]);
entry_func_ =
std::bind(entry<float>, std::placeholders::_1, threshold);
}
} }
} }
...@@ -144,58 +135,87 @@ class ValueBlock { ...@@ -144,58 +135,87 @@ class ValueBlock {
~ValueBlock() {} ~ValueBlock() {}
float *Init(const uint64_t &id) {
auto value = std::make_shared<VALUE>(value_length_);
for (int x = 0; x < value_names_.size(); ++x) {
initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
value_dims_[x]);
}
values_[id] = value;
return value->data_.data();
}
std::vector<float *> Get(const uint64_t &id, std::vector<float *> Get(const uint64_t &id,
const std::vector<std::string> &value_names) { const std::vector<std::string> &value_names,
const std::vector<int> &value_dims) {
auto pts = std::vector<float *>(); auto pts = std::vector<float *>();
pts.reserve(value_names.size()); pts.reserve(value_names.size());
auto &values = values_.at(id); auto &values = values_.at(id);
for (int i = 0; i < static_cast<int>(value_names.size()); i++) { for (int i = 0; i < static_cast<int>(value_names.size()); i++) {
PADDLE_ENFORCE_EQ(
value_dims[i], value_dims_[i],
platform::errors::InvalidArgument("value dims is not match"));
pts.push_back(values->data_.data() + pts.push_back(values->data_.data() +
value_offsets_.at(value_idx_.at(value_names[i]))); value_offsets_.at(value_idx_.at(value_names[i])));
} }
return pts; return pts;
} }
float *Get(const uint64_t &id) { // pull
auto pts = std::vector<std::vector<float> *>(); float *Init(const uint64_t &id, const bool with_update = true) {
auto &values = values_.at(id); if (!Has(id)) {
values_[id] = std::make_shared<VALUE>(value_length_);
}
auto &value = values_.at(id);
return values->data_.data(); if (with_update) {
AttrUpdate(value);
}
return value->data_.data();
} }
float *InitFromInitializer(const uint64_t &id) { void AttrUpdate(std::shared_ptr<VALUE> value) {
if (Has(id)) { // update state
if (has_entry_) { value->unseen_days_ = 0;
Update(id); ++value->count_;
if (!value->is_entry_) {
value->is_entry_ = entry_func_(value);
if (value->is_entry_) {
// initialize
for (int x = 0; x < value_names_.size(); ++x) {
initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
value_dims_[x]);
}
} }
return Get(id);
} }
return Init(id);
value->need_save_ = true;
return;
} }
// dont jude if (has(id))
float *Get(const uint64_t &id) {
auto &value = values_.at(id);
return value->data_.data();
}
// for load, to reset count, unseen_days
std::shared_ptr<VALUE> GetValue(const uint64_t &id) { return values_.at(id); }
bool GetEntry(const uint64_t &id) { bool GetEntry(const uint64_t &id) {
auto value = values_.at(id); auto &value = values_.at(id);
return value->is_entry_; return value->is_entry_;
} }
void Update(const uint64_t id) { void SetEntry(const uint64_t &id, const bool state) {
auto value = values_.at(id); auto &value = values_.at(id);
value->unseen_days_ = 0; value->is_entry_ = state;
auto count = ++value->count_; }
if (!value->is_entry_) { void Shrink(const int threshold) {
value->is_entry_ = entry_func_(count); for (auto iter = values_.begin(); iter != values_.end();) {
auto &value = iter->second;
value->unseen_days_++;
if (value->unseen_days_ >= threshold) {
iter = values_.erase(iter);
} else {
++iter;
}
} }
return;
} }
private: private:
...@@ -218,8 +238,7 @@ class ValueBlock { ...@@ -218,8 +238,7 @@ class ValueBlock {
const std::vector<int> &value_offsets_; const std::vector<int> &value_offsets_;
const std::unordered_map<std::string, int> &value_idx_; const std::unordered_map<std::string, int> &value_idx_;
bool has_entry_ = false; std::function<bool(std::shared_ptr<VALUE>)> entry_func_;
std::function<bool(uint64_t)> entry_func_;
std::vector<std::shared_ptr<Initializer>> initializers_; std::vector<std::shared_ptr<Initializer>> initializers_;
}; };
......
...@@ -76,6 +76,7 @@ class SSUM : public SparseOptimizer { ...@@ -76,6 +76,7 @@ class SSUM : public SparseOptimizer {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
for (auto x : offsets) { for (auto x : offsets) {
auto id = keys[x]; auto id = keys[x];
if (!block->GetEntry(id)) continue;
auto* value = block->Get(id); auto* value = block->Get(id);
float* param = value + param_offset; float* param = value + param_offset;
blas.VADD(update_numel, update_values + x * update_numel, param, param); blas.VADD(update_numel, update_values + x * update_numel, param, param);
...@@ -105,6 +106,7 @@ class SSGD : public SparseOptimizer { ...@@ -105,6 +106,7 @@ class SSGD : public SparseOptimizer {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
for (auto x : offsets) { for (auto x : offsets) {
auto id = keys[x]; auto id = keys[x];
if (!block->GetEntry(id)) continue;
auto* value = block->Get(id); auto* value = block->Get(id);
float learning_rate = *(global_learning_rate_) * (value + lr_offset)[0]; float learning_rate = *(global_learning_rate_) * (value + lr_offset)[0];
...@@ -161,6 +163,7 @@ class SAdam : public SparseOptimizer { ...@@ -161,6 +163,7 @@ class SAdam : public SparseOptimizer {
auto blas = GetBlas<float>(); auto blas = GetBlas<float>();
for (auto x : offsets) { for (auto x : offsets) {
auto id = keys[x]; auto id = keys[x];
if (!block->GetEntry(id)) continue;
auto* values = block->Get(id); auto* values = block->Get(id);
float lr_ = *(global_learning_rate_) * (values + lr_offset)[0]; float lr_ = *(global_learning_rate_) * (values + lr_offset)[0];
VLOG(4) << "SAdam LearningRate: " << lr_; VLOG(4) << "SAdam LearningRate: " << lr_;
......
...@@ -90,7 +90,7 @@ class Table { ...@@ -90,7 +90,7 @@ class Table {
virtual void clear() = 0; virtual void clear() = 0;
virtual int32_t flush() = 0; virtual int32_t flush() = 0;
virtual int32_t shrink() = 0; virtual int32_t shrink(const std::string &param) = 0;
//指定加载路径 //指定加载路径
virtual int32_t load(const std::string &path, virtual int32_t load(const std::string &path,
......
...@@ -51,7 +51,7 @@ class TensorTable : public Table { ...@@ -51,7 +51,7 @@ class TensorTable : public Table {
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t shrink() override { return 0; } int32_t shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *get_shard(size_t shard_idx) { return 0; }
...@@ -101,7 +101,7 @@ class DenseTensorTable : public TensorTable { ...@@ -101,7 +101,7 @@ class DenseTensorTable : public TensorTable {
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t shrink() override { return 0; } int32_t shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *get_shard(size_t shard_idx) { return 0; }
...@@ -157,7 +157,7 @@ class GlobalStepTable : public DenseTensorTable { ...@@ -157,7 +157,7 @@ class GlobalStepTable : public DenseTensorTable {
size_t num) override { size_t num) override {
return 0; return 0;
} }
int32_t shrink() override { return 0; } int32_t shrink(const std::string &param) override { return 0; }
virtual void *get_shard(size_t shard_idx) { return 0; } virtual void *get_shard(size_t shard_idx) { return 0; }
......
...@@ -62,7 +62,8 @@ void BindDistFleetWrapper(py::module* m) { ...@@ -62,7 +62,8 @@ void BindDistFleetWrapper(py::module* m) {
.def("sparse_table_stat", &FleetWrapper::PrintTableStat) .def("sparse_table_stat", &FleetWrapper::PrintTableStat)
.def("stop_server", &FleetWrapper::StopServer) .def("stop_server", &FleetWrapper::StopServer)
.def("stop_worker", &FleetWrapper::FinalizeWorker) .def("stop_worker", &FleetWrapper::FinalizeWorker)
.def("barrier", &FleetWrapper::BarrierWithTable); .def("barrier", &FleetWrapper::BarrierWithTable)
.def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable);
} }
void BindPSHost(py::module* m) { void BindPSHost(py::module* m) {
......
...@@ -63,3 +63,4 @@ set_lr = fleet.set_lr ...@@ -63,3 +63,4 @@ set_lr = fleet.set_lr
get_lr = fleet.get_lr get_lr = fleet.get_lr
state_dict = fleet.state_dict state_dict = fleet.state_dict
set_state_dict = fleet.set_state_dict set_state_dict = fleet.set_state_dict
shrink = fleet.shrink
...@@ -520,7 +520,8 @@ class Fleet(object): ...@@ -520,7 +520,8 @@ class Fleet(object):
feeded_var_names, feeded_var_names,
target_vars, target_vars,
main_program=None, main_program=None,
export_for_deployment=True): export_for_deployment=True,
mode=0):
""" """
save inference model for inference. save inference model for inference.
...@@ -543,7 +544,7 @@ class Fleet(object): ...@@ -543,7 +544,7 @@ class Fleet(object):
self._runtime_handle._save_inference_model( self._runtime_handle._save_inference_model(
executor, dirname, feeded_var_names, target_vars, main_program, executor, dirname, feeded_var_names, target_vars, main_program,
export_for_deployment) export_for_deployment, mode)
def save_persistables(self, executor, dirname, main_program=None, mode=0): def save_persistables(self, executor, dirname, main_program=None, mode=0):
""" """
...@@ -590,6 +591,9 @@ class Fleet(object): ...@@ -590,6 +591,9 @@ class Fleet(object):
self._runtime_handle._save_persistables(executor, dirname, main_program, self._runtime_handle._save_persistables(executor, dirname, main_program,
mode) mode)
def shrink(self, threshold):
self._runtime_handle._shrink(threshold)
def distributed_optimizer(self, optimizer, strategy=None): def distributed_optimizer(self, optimizer, strategy=None):
""" """
Optimizer for distributed training. Optimizer for distributed training.
......
...@@ -946,7 +946,8 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -946,7 +946,8 @@ class TheOnePSRuntime(RuntimeBase):
feeded_var_names, feeded_var_names,
target_vars, target_vars,
main_program=None, main_program=None,
export_for_deployment=True): export_for_deployment=True,
mode=0):
""" """
Prune the given `main_program` to build a new program especially for inference, Prune the given `main_program` to build a new program especially for inference,
and then save it and all related parameters to given `dirname` by the `executor`. and then save it and all related parameters to given `dirname` by the `executor`.
...@@ -983,10 +984,25 @@ class TheOnePSRuntime(RuntimeBase): ...@@ -983,10 +984,25 @@ class TheOnePSRuntime(RuntimeBase):
program = Program.parse_from_string(program_desc_str) program = Program.parse_from_string(program_desc_str)
program._copy_dist_param_info_from(fluid.default_main_program()) program._copy_dist_param_info_from(fluid.default_main_program())
self._ps_inference_save_persistables(executor, dirname, program) self._ps_inference_save_persistables(executor, dirname, program,
mode)
def _save_inference_model(self, *args, **kwargs): def _save_inference_model(self, *args, **kwargs):
self._ps_inference_save_inference_model(*args, **kwargs) self._ps_inference_save_inference_model(*args, **kwargs)
def _save_persistables(self, *args, **kwargs): def _save_persistables(self, *args, **kwargs):
self._ps_inference_save_persistables(*args, **kwargs) self._ps_inference_save_persistables(*args, **kwargs)
def _shrink(self, threshold):
import paddle.distributed.fleet as fleet
fleet.util.barrier()
if self.role_maker._is_first_worker():
sparses = self.compiled_strategy.get_the_one_recv_context(
is_dense=False,
split_dense_table=self.role_maker.
_is_heter_parameter_server_mode,
use_origin_program=True)
for id, names in sparses.items():
self._worker.shrink_sparse_table(id, threshold)
fleet.util.barrier()
...@@ -65,7 +65,7 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -65,7 +65,7 @@ class TestPSPassWithBow(unittest.TestCase):
return avg_cost return avg_cost
is_distributed = False is_distributed = False
is_sparse = True is_sparse = False
# query # query
q = fluid.layers.data( q = fluid.layers.data(
...@@ -162,7 +162,7 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -162,7 +162,7 @@ class TestPSPassWithBow(unittest.TestCase):
role = fleet.UserDefinedRoleMaker( role = fleet.UserDefinedRoleMaker(
current_id=0, current_id=0,
role=role_maker.Role.SERVER, role=role_maker.Role.WORKER,
worker_num=2, worker_num=2,
server_endpoints=endpoints) server_endpoints=endpoints)
...@@ -172,11 +172,13 @@ class TestPSPassWithBow(unittest.TestCase): ...@@ -172,11 +172,13 @@ class TestPSPassWithBow(unittest.TestCase):
strategy = paddle.distributed.fleet.DistributedStrategy() strategy = paddle.distributed.fleet.DistributedStrategy()
strategy.a_sync = True strategy.a_sync = True
strategy.a_sync_configs = {"k_steps": 100} strategy.a_sync_configs = {"launch_barrier": False}
optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(loss) optimizer.minimize(loss)
fleet.shrink(10)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册