From 36710ebcae7deea11cfed2f0437417f2f9293288 Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Tue, 23 Feb 2021 14:11:57 +0800
Subject: [PATCH] test=develop, save/load, shrink (#30625) (#31107)

* test=develop, save/load, shrink

Co-authored-by: seiriosPlus
Co-authored-by: 123malin
---
 paddle/fluid/distributed/fleet.cc             |  10 +-
 paddle/fluid/distributed/fleet.h              |   2 +-
 .../distributed/service/brpc_ps_client.cc     |   5 +-
 .../distributed/service/brpc_ps_client.h      |   3 +-
 .../distributed/service/brpc_ps_server.cc     |  12 +-
 paddle/fluid/distributed/service/ps_client.h  |   3 +-
 .../distributed/table/common_dense_table.h    |   2 +-
 .../distributed/table/common_sparse_table.cc  |  63 +++++--
 .../distributed/table/common_sparse_table.h   |   2 +-
 paddle/fluid/distributed/table/common_table.h |   4 +-
 .../table/depends/large_scale_kv.h            | 155 ++++++++++--------
 .../fluid/distributed/table/depends/sparse.h  |   3 +
 paddle/fluid/distributed/table/table.h        |   2 +-
 paddle/fluid/distributed/table/tensor_table.h |   6 +-
 paddle/fluid/pybind/fleet_py.cc               |   3 +-
 python/paddle/distributed/fleet/__init__.py   |   1 +
 .../distributed/fleet/base/fleet_base.py      |   8 +-
 .../distributed/fleet/runtime/the_one_ps.py   |  20 ++-
 .../tests/unittests/test_dist_fleet_ps3.py    |   8 +-
 19 files changed, 206 insertions(+), 106 deletions(-)

diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc
index 8db32c5cc4d..e635a730b03 100644
--- a/paddle/fluid/distributed/fleet.cc
+++ b/paddle/fluid/distributed/fleet.cc
@@ -479,9 +479,15 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) {
   }
 }
 
-void FleetWrapper::ShrinkSparseTable(int table_id) {
-  auto ret = pserver_ptr_->_worker_ptr->shrink(table_id);
+void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
+  auto* communicator = Communicator::GetInstance();
+  auto ret =
+      communicator->_worker_ptr->shrink(table_id, std::to_string(threshold));
   ret.wait();
+  int32_t err_code = ret.get();
+  if (err_code == -1) {
+    LOG(ERROR) << "shrink sparse table stat failed";
+  }
 }
 
 void FleetWrapper::ClearModel() {
diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h
index 03d915c5005..3e07248fa68 100644
--- a/paddle/fluid/distributed/fleet.h
+++ b/paddle/fluid/distributed/fleet.h
@@ -207,7 +207,7 @@ class FleetWrapper {
   // clear one table
   void ClearOneTable(const uint64_t table_id);
   // shrink sparse table
-  void ShrinkSparseTable(int table_id);
+  void ShrinkSparseTable(int table_id, int threshold);
   // shrink dense table
   void ShrinkDenseTable(int table_id, Scope* scope,
                         std::vector<std::string> var_list, float decay,
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc
index e781cc4bcf4..c991ffb1139 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_client.cc
@@ -353,8 +353,9 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
   return fut;
 }
 
-std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id) {
-  return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")});
+std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id,
+                                          const std::string threshold) {
+  return send_cmd(table_id, PS_SHRINK_TABLE, {threshold});
 }
 
 std::future<int32_t> BrpcPsClient::load(const std::string &epoch,
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h
index 82f772c2d5a..60e6354fa35 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.h
+++ b/paddle/fluid/distributed/service/brpc_ps_client.h
@@ -102,7 +102,8 @@ class BrpcPsClient : public PSClient {
   }
   virtual int32_t create_client2client_connection(
       int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
-  virtual std::future<int32_t> shrink(uint32_t table_id) override;
+  virtual std::future<int32_t> shrink(uint32_t table_id,
+                                      const std::string threshold) override;
   virtual std::future<int32_t> load(const std::string &epoch,
                                     const std::string &mode) override;
   virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc
index ef497d3222a..02b4087de38 100644
--- a/paddle/fluid/distributed/service/brpc_ps_server.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_server.cc
@@ -460,6 +460,8 @@ int32_t BrpcPsService::save_one_table(Table *table,
   table->flush();
 
   int32_t feasign_size = 0;
+
+  VLOG(0) << "save one table " << request.params(0) << " " << request.params(1);
   feasign_size = table->save(request.params(0), request.params(1));
   if (feasign_size < 0) {
     set_response_code(response, -1, "table save failed");
@@ -491,10 +493,18 @@ int32_t BrpcPsService::shrink_table(Table *table,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
   CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 1) {
+    set_response_code(
+        response, -1,
+        "PsRequestMessage.datas is requeired at least 1, threshold");
+    return -1;
+  }
   table->flush();
-  if (table->shrink() != 0) {
+  if (table->shrink(request.params(0)) != 0) {
     set_response_code(response, -1, "table shrink failed");
+    return -1;
   }
+  VLOG(0) << "Pserver Shrink Finished";
   return 0;
 }
 
diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h
index a23a06c46e0..f3b09e3e597 100644
--- a/paddle/fluid/distributed/service/ps_client.h
+++ b/paddle/fluid/distributed/service/ps_client.h
@@ -69,7 +69,8 @@ class PSClient {
                                                   int max_retry) = 0;
 
   // 触发table数据退场
-  virtual std::future<int32_t> shrink(uint32_t table_id) = 0;
+  virtual std::future<int32_t> shrink(uint32_t table_id,
+                                      const std::string threshold) = 0;
 
   // 全量table进行数据load
   virtual std::future<int32_t> load(const std::string &epoch,
diff --git a/paddle/fluid/distributed/table/common_dense_table.h b/paddle/fluid/distributed/table/common_dense_table.h
index c32e6e194de..40a53bb53fb 100644
--- a/paddle/fluid/distributed/table/common_dense_table.h
+++ b/paddle/fluid/distributed/table/common_dense_table.h
@@ -58,7 +58,7 @@ class CommonDenseTable : public DenseTable {
   }
 
   virtual int32_t flush() override { return 0; }
-  virtual int32_t shrink() override { return 0; }
+  virtual int32_t shrink(const std::string& param) override { return 0; }
   virtual void clear() override { return; }
 
  protected:
diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc
index 98db14e0eca..3ab3a398cdb 100644
--- a/paddle/fluid/distributed/table/common_sparse_table.cc
+++ b/paddle/fluid/distributed/table/common_sparse_table.cc
@@ -22,9 +22,12 @@
 #include "paddle/fluid/string/string_helper.h"
 
 #define PSERVER_SAVE_SUFFIX "_txt"
+
 namespace paddle {
 namespace distributed {
 
+enum SaveMode { all, base, delta };
+
 struct Meta {
   std::string param;
   int shard_id;
@@ -94,12 +97,9 @@
 
 void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
                   std::vector<std::vector<float>>* values) {
-  PADDLE_ENFORCE_EQ(columns.size(), 2,
-                    paddle::platform::errors::InvalidArgument(
-                        "The data format does not meet the requirements. It "
-                        "should look like feasign_id \t params."));
-
-  auto load_values = paddle::string::split_string(columns[1], ",");
+  auto colunmn_size = columns.size();
+  auto load_values =
+      paddle::string::split_string(columns[colunmn_size - 1], ",");
   values->reserve(meta.names.size());
 
   int offset = 0;
@@ -121,11 +121,18 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
 
 int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
                    const int mode) {
+  int64_t not_save_num = 0;
   for (auto value : block->values_) {
+    if (mode == SaveMode::delta && !value.second->need_save_) {
+      not_save_num++;
+      continue;
+    }
+
     auto* vs = value.second->data_.data();
     std::stringstream ss;
     auto id = value.first;
-    ss << id << "\t";
+    ss << id << "\t" << value.second->count_ << "\t"
+       << value.second->unseen_days_ << "\t" << value.second->is_entry_ << "\t";
 
     for (int i = 0; i < block->value_length_; i++) {
       ss << vs[i];
@@ -135,9 +142,13 @@ int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
     ss << "\n";
     os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
+
+    if (mode == SaveMode::base || mode == SaveMode::delta) {
+      value.second->need_save_ = false;
+    }
   }
 
-  return block->values_.size();
+  return block->values_.size() - not_save_num;
 }
 
 int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
@@ -165,8 +176,21 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
     std::vector<std::vector<float>> kvalues;
     ProcessALine(values, meta, &kvalues);
-    // warning: need fix
-    block->Init(id);
+
+    block->Init(id, false);
+
+    auto value_instant = block->GetValue(id);
+    if (values.size() == 5) {
+      value_instant->count_ = std::stoi(values[1]);
+      value_instant->unseen_days_ = std::stoi(values[2]);
+      value_instant->is_entry_ = static_cast<bool>(std::stoi(values[3]));
+    }
+
+    std::vector<float*> block_values = block->Get(id, meta.names, meta.dims);
+    auto blas = GetBlas<float>();
+    for (int x = 0; x < meta.names.size(); ++x) {
+      blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]);
+    }
   }
 
   return 0;
 }
@@ -393,7 +417,7 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys,
       for (int i = 0; i < offsets.size(); ++i) {
         auto offset = offsets[i];
         auto id = keys[offset];
-        auto* value = block->InitFromInitializer(id);
+        auto* value = block->Init(id);
         std::copy_n(value + param_offset_, param_dim_,
                     pull_values + param_dim_ * offset);
       }
@@ -488,9 +512,10 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
       for (int i = 0; i < offsets.size(); ++i) {
         auto offset = offsets[i];
         auto id = keys[offset];
-        auto* value = block->InitFromInitializer(id);
+        auto* value = block->Init(id, false);
         std::copy_n(values + param_dim_ * offset, param_dim_,
                     value + param_offset_);
+        block->SetEntry(id, true);
       }
       return 0;
     });
@@ -505,10 +530,20 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
 
 int32_t CommonSparseTable::flush() { return 0; }
 
-int32_t CommonSparseTable::shrink() {
-  VLOG(0) << "shrink coming soon";
+int32_t CommonSparseTable::shrink(const std::string& param) {
+  rwlock_->WRLock();
+  int threshold = std::stoi(param);
+  VLOG(0) << "sparse table shrink: " << threshold;
+
+  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
+    // shrink
+    VLOG(0) << shard_id << " " << task_pool_size_ << " begin shrink";
+    shard_values_[shard_id]->Shrink(threshold);
+  }
+  rwlock_->UNLock();
   return 0;
 }
+
 void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/table/common_sparse_table.h b/paddle/fluid/distributed/table/common_sparse_table.h
index e74a6bac44e..51dbe3c65ca 100644
--- a/paddle/fluid/distributed/table/common_sparse_table.h
+++ b/paddle/fluid/distributed/table/common_sparse_table.h
@@ -73,7 +73,7 @@ class CommonSparseTable : public SparseTable {
   virtual int32_t pour();
 
   virtual int32_t flush();
-  virtual int32_t shrink();
+  virtual int32_t shrink(const std::string& param);
   virtual void clear();
 
  protected:
diff --git a/paddle/fluid/distributed/table/common_table.h b/paddle/fluid/distributed/table/common_table.h
index d37e6677e63..034769e0212 100644
--- a/paddle/fluid/distributed/table/common_table.h
+++ b/paddle/fluid/distributed/table/common_table.h
@@ -108,7 +108,7 @@ class DenseTable : public Table {
   int32_t push_dense_param(const float *values, size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 };
 
 class BarrierTable : public Table {
@@ -133,7 +133,7 @@ class BarrierTable : public Table {
   int32_t push_dense_param(const float *values, size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
   virtual void clear(){};
   virtual int32_t flush() { return 0; };
   virtual int32_t load(const std::string &path, const std::string &param) {
diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h
index 79a4c4700a9..5cbf48548fb 100644
--- a/paddle/fluid/distributed/table/depends/large_scale_kv.h
+++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h
@@ -47,43 +47,34 @@ namespace distributed {
 
 enum Mode { training, infer };
 
-template <typename T>
-inline bool entry(const int count, const T threshold);
-
-template <>
-inline bool entry(const int count, const std::string threshold) {
-  return true;
-}
-
-template <>
-inline bool entry(const int count, const int threshold) {
-  return count >= threshold;
-}
-
-template <>
-inline bool entry(const int count, const float threshold) {
-  UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
-  return uniform.GetValue() >= threshold;
-}
-
 struct VALUE {
   explicit VALUE(size_t length)
       : length_(length),
-        count_(1),
+        count_(0),
         unseen_days_(0),
-        seen_after_last_save_(true),
-        is_entry_(true) {
+        need_save_(false),
+        is_entry_(false) {
     data_.resize(length);
+    memset(data_.data(), 0, sizeof(float) * length);
   }
 
   size_t length_;
   std::vector<float> data_;
   int count_;
-  int unseen_days_;
-  bool seen_after_last_save_;
-  bool is_entry_;
+  int unseen_days_;  // use to check knock-out
+  bool need_save_;   // whether need to save
+  bool is_entry_;    // whether knock-in
 };
 
+inline bool count_entry(std::shared_ptr<VALUE> value, int threshold) {
+  return value->count_ >= threshold;
+}
+
+inline bool probility_entry(std::shared_ptr<VALUE> value, float threshold) {
+  UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
+  return uniform.GetValue() >= threshold;
+}
+
 class ValueBlock {
  public:
  explicit ValueBlock(const std::vector<std::string> &value_names,
@@ -102,21 +93,21 @@ class ValueBlock {
 
     // for Entry
     {
-      if (entry_attr == "none") {
-        has_entry_ = false;
+      auto slices = string::split_string(entry_attr, "&");
+      if (slices[0] == "none") {
+        entry_func_ = std::bind(&count_entry, std::placeholders::_1, 0);
+      } else if (slices[0] == "count_filter") {
+        int threshold = std::stoi(slices[1]);
+        entry_func_ = std::bind(&count_entry, std::placeholders::_1, threshold);
+      } else if (slices[0] == "probability") {
+        float threshold = std::stof(slices[1]);
         entry_func_ =
-            std::bind(entry<std::string>, std::placeholders::_1, "none");
+            std::bind(&probility_entry, std::placeholders::_1, threshold);
       } else {
-        has_entry_ = true;
-        auto slices = string::split_string(entry_attr, "&");
-        if (slices[0] == "count_filter") {
-          int threshold = std::stoi(slices[1]);
-          entry_func_ = std::bind(entry<int>, std::placeholders::_1, threshold);
-        } else if (slices[0] == "probability") {
-          float threshold = std::stof(slices[1]);
-          entry_func_ =
-              std::bind(entry<float>, std::placeholders::_1, threshold);
-        }
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Not supported Entry Type : %s, Only support [count_filter, "
+            "probability]",
+            slices[0]));
       }
     }
 
@@ -144,58 +135,87 @@ class ValueBlock {
 
   ~ValueBlock() {}
 
-  float *Init(const uint64_t &id) {
-    auto value = std::make_shared<VALUE>(value_length_);
-    for (int x = 0; x < value_names_.size(); ++x) {
-      initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
-                                 value_dims_[x]);
-    }
-    values_[id] = value;
-    return value->data_.data();
-  }
 
   std::vector<float *> Get(const uint64_t &id,
-                           const std::vector<std::string> &value_names) {
+                           const std::vector<std::string> &value_names,
+                           const std::vector<int> &value_dims) {
    auto pts = std::vector<float *>();
     pts.reserve(value_names.size());
     auto &values = values_.at(id);
     for (int i = 0; i < static_cast<int>(value_names.size()); i++) {
+      PADDLE_ENFORCE_EQ(
+          value_dims[i], value_dims_[i],
+          platform::errors::InvalidArgument("value dims is not match"));
       pts.push_back(values->data_.data() +
                     value_offsets_.at(value_idx_.at(value_names[i])));
     }
     return pts;
   }
 
-  float *Get(const uint64_t &id) {
-    auto pts = std::vector<std::vector<float> *>();
-    auto &values = values_.at(id);
+  // pull
+  float *Init(const uint64_t &id, const bool with_update = true) {
+    if (!Has(id)) {
+      values_[id] = std::make_shared<VALUE>(value_length_);
+    }
+
+    auto &value = values_.at(id);
 
-    return values->data_.data();
+    if (with_update) {
+      AttrUpdate(value);
+    }
+
+    return value->data_.data();
   }
 
-  float *InitFromInitializer(const uint64_t &id) {
-    if (Has(id)) {
-      if (has_entry_) {
-        Update(id);
+  void AttrUpdate(std::shared_ptr<VALUE> value) {
+    // update state
+    value->unseen_days_ = 0;
+    ++value->count_;
+
+    if (!value->is_entry_) {
+      value->is_entry_ = entry_func_(value);
+      if (value->is_entry_) {
+        // initialize
+        for (int x = 0; x < value_names_.size(); ++x) {
+          initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
+                                     value_dims_[x]);
+        }
       }
-      return Get(id);
     }
-    return Init(id);
+
+    value->need_save_ = true;
+    return;
   }
 
+  // dont jude if (has(id))
+  float *Get(const uint64_t &id) {
+    auto &value = values_.at(id);
+    return value->data_.data();
+  }
+
+  // for load, to reset count, unseen_days
+  std::shared_ptr<VALUE> GetValue(const uint64_t &id) { return values_.at(id); }
+
   bool GetEntry(const uint64_t &id) {
-    auto value = values_.at(id);
+    auto &value = values_.at(id);
     return value->is_entry_;
   }
 
-  void Update(const uint64_t id) {
-    auto value = values_.at(id);
-    value->unseen_days_ = 0;
-    auto count = ++value->count_;
+  void SetEntry(const uint64_t &id, const bool state) {
+    auto &value = values_.at(id);
+    value->is_entry_ = state;
+  }
 
-    if (!value->is_entry_) {
-      value->is_entry_ = entry_func_(count);
+  void Shrink(const int threshold) {
+    for (auto iter = values_.begin(); iter != values_.end();) {
+      auto &value = iter->second;
+      value->unseen_days_++;
+      if (value->unseen_days_ >= threshold) {
+        iter = values_.erase(iter);
+      } else {
+        ++iter;
+      }
     }
+    return;
   }
 
  private:
@@ -218,8 +238,7 @@ class ValueBlock {
   const std::vector<int> &value_offsets_;
   const std::unordered_map<std::string, int> &value_idx_;
 
-  bool has_entry_ = false;
-  std::function<bool(int)> entry_func_;
+  std::function<bool(std::shared_ptr<VALUE>)> entry_func_;
   std::vector<std::shared_ptr<Initializer>> initializers_;
 };
 
diff --git a/paddle/fluid/distributed/table/depends/sparse.h b/paddle/fluid/distributed/table/depends/sparse.h
index 1900da32155..38ae03777c8 100644
--- a/paddle/fluid/distributed/table/depends/sparse.h
+++ b/paddle/fluid/distributed/table/depends/sparse.h
@@ -76,6 +76,7 @@ class SSUM : public SparseOptimizer {
     auto blas = GetBlas<float>();
     for (auto x : offsets) {
       auto id = keys[x];
+      if (!block->GetEntry(id)) continue;
       auto* value = block->Get(id);
       float* param = value + param_offset;
       blas.VADD(update_numel, update_values + x * update_numel, param, param);
@@ -105,6 +106,7 @@ class SSGD : public SparseOptimizer {
     auto blas = GetBlas<float>();
     for (auto x : offsets) {
       auto id = keys[x];
+      if (!block->GetEntry(id)) continue;
       auto* value = block->Get(id);
 
       float learning_rate = *(global_learning_rate_) * (value + lr_offset)[0];
@@ -161,6 +163,7 @@ class SAdam : public SparseOptimizer {
     auto blas = GetBlas<float>();
     for (auto x : offsets) {
       auto id = keys[x];
+      if (!block->GetEntry(id)) continue;
       auto* values = block->Get(id);
       float lr_ = *(global_learning_rate_) * (values + lr_offset)[0];
       VLOG(4) << "SAdam LearningRate: " << lr_;
diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h
index 1bfedb53ab8..65c99d2bbd4 100644
--- a/paddle/fluid/distributed/table/table.h
+++ b/paddle/fluid/distributed/table/table.h
@@ -90,7 +90,7 @@ class Table {
   virtual void clear() = 0;
   virtual int32_t flush() = 0;
-  virtual int32_t shrink() = 0;
+  virtual int32_t shrink(const std::string &param) = 0;
 
   //指定加载路径
   virtual int32_t load(const std::string &path,
diff --git a/paddle/fluid/distributed/table/tensor_table.h b/paddle/fluid/distributed/table/tensor_table.h
index 58680145a43..f89e2e9e730 100644
--- a/paddle/fluid/distributed/table/tensor_table.h
+++ b/paddle/fluid/distributed/table/tensor_table.h
@@ -51,7 +51,7 @@ class TensorTable : public Table {
                     size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 
   virtual void *get_shard(size_t shard_idx) { return 0; }
 
@@ -101,7 +101,7 @@ class DenseTensorTable : public TensorTable {
                     size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 
   virtual void *get_shard(size_t shard_idx) { return 0; }
 
@@ -157,7 +157,7 @@ class GlobalStepTable : public DenseTensorTable {
                     size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 
   virtual void *get_shard(size_t shard_idx) { return 0; }
 
diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc
index 4777951d82c..ba716fb3b55 100644
--- a/paddle/fluid/pybind/fleet_py.cc
+++ b/paddle/fluid/pybind/fleet_py.cc
@@ -62,7 +62,8 @@ void BindDistFleetWrapper(py::module* m) {
       .def("sparse_table_stat", &FleetWrapper::PrintTableStat)
       .def("stop_server", &FleetWrapper::StopServer)
       .def("stop_worker", &FleetWrapper::FinalizeWorker)
-      .def("barrier", &FleetWrapper::BarrierWithTable);
+      .def("barrier", &FleetWrapper::BarrierWithTable)
+      .def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable);
 }
 
 void BindPSHost(py::module* m) {
diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py
index 0b7e8da101b..bd8492ecfa7 100644
--- a/python/paddle/distributed/fleet/__init__.py
+++ b/python/paddle/distributed/fleet/__init__.py
@@ -63,3 +63,4 @@ set_lr = fleet.set_lr
 get_lr = fleet.get_lr
 state_dict = fleet.state_dict
 set_state_dict = fleet.set_state_dict
+shrink = fleet.shrink
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index 0e4559e6bc6..f4703a47cb7 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -520,7 +520,8 @@ class Fleet(object):
                              feeded_var_names,
                              target_vars,
                              main_program=None,
-                             export_for_deployment=True):
+                             export_for_deployment=True,
+                             mode=0):
         """
         save inference model for inference.
 
@@ -543,7 +544,7 @@ class Fleet(object):
 
         self._runtime_handle._save_inference_model(
             executor, dirname, feeded_var_names, target_vars, main_program,
-            export_for_deployment)
+            export_for_deployment, mode)
 
     def save_persistables(self, executor, dirname, main_program=None, mode=0):
         """
@@ -590,6 +591,9 @@ class Fleet(object):
         self._runtime_handle._save_persistables(executor, dirname,
                                                 main_program, mode)
 
+    def shrink(self, threshold):
+        self._runtime_handle._shrink(threshold)
+
     def distributed_optimizer(self, optimizer, strategy=None):
         """
         Optimizer for distributed training.
diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py
index dc78e1ce485..91a70bd3f39 100644
--- a/python/paddle/distributed/fleet/runtime/the_one_ps.py
+++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py
@@ -946,7 +946,8 @@ class TheOnePSRuntime(RuntimeBase):
                                             feeded_var_names,
                                             target_vars,
                                             main_program=None,
-                                            export_for_deployment=True):
+                                            export_for_deployment=True,
+                                            mode=0):
         """
         Prune the given `main_program` to build a new program especially for inference,
         and then save it and all related parameters to given `dirname` by the `executor`.
@@ -983,10 +984,25 @@ class TheOnePSRuntime(RuntimeBase):
             program = Program.parse_from_string(program_desc_str)
             program._copy_dist_param_info_from(fluid.default_main_program())
 
-            self._ps_inference_save_persistables(executor, dirname, program)
+            self._ps_inference_save_persistables(executor, dirname, program,
+                                                 mode)
 
     def _save_inference_model(self, *args, **kwargs):
         self._ps_inference_save_inference_model(*args, **kwargs)
 
     def _save_persistables(self, *args, **kwargs):
         self._ps_inference_save_persistables(*args, **kwargs)
+
+    def _shrink(self, threshold):
+        import paddle.distributed.fleet as fleet
+        fleet.util.barrier()
+        if self.role_maker._is_first_worker():
+            sparses = self.compiled_strategy.get_the_one_recv_context(
+                is_dense=False,
+                split_dense_table=self.role_maker.
+                _is_heter_parameter_server_mode,
+                use_origin_program=True)
+
+            for id, names in sparses.items():
+                self._worker.shrink_sparse_table(id, threshold)
+        fleet.util.barrier()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py
index d1740f9d96f..aa7975d2b8b 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py
@@ -65,7 +65,7 @@ class TestPSPassWithBow(unittest.TestCase):
             return avg_cost
 
         is_distributed = False
-        is_sparse = True
+        is_sparse = False
 
         # query
         q = fluid.layers.data(
@@ -162,7 +162,7 @@ class TestPSPassWithBow(unittest.TestCase):
 
         role = fleet.UserDefinedRoleMaker(
             current_id=0,
-            role=role_maker.Role.SERVER,
+            role=role_maker.Role.WORKER,
            worker_num=2,
             server_endpoints=endpoints)
 
@@ -172,11 +172,13 @@ class TestPSPassWithBow(unittest.TestCase):
 
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.a_sync = True
-        strategy.a_sync_configs = {"k_steps": 100}
+        strategy.a_sync_configs = {"launch_barrier": False}
 
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(loss)
 
+        fleet.shrink(10)
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab