From f382eb06f0c6acec2475d3a545cb0abe9aafad4b Mon Sep 17 00:00:00 2001 From: zhaocaibei123 <48509226+zhaocaibei123@users.noreply.github.com> Date: Tue, 19 Jul 2022 15:51:07 +0800 Subject: [PATCH] add save_cache/patch (#44420) * add save_cache/patch * add pybind * remove pybind * remove const_cast * add fleet --- .../distributed/ps/service/brpc_ps_client.cc | 11 +- .../distributed/ps/service/brpc_ps_client.h | 11 +- .../distributed/ps/service/brpc_ps_server.cc | 52 +- .../distributed/ps/service/brpc_ps_server.h | 80 ++-- .../fluid/distributed/ps/service/ps_client.h | 27 +- .../distributed/ps/service/sendrecv.proto | 2 + .../distributed/ps/table/depends/dense.h | 2 +- .../ps/table/memory_dense_table.cc | 51 +- .../ps/table/memory_sparse_table.cc | 447 +++++++++++++++--- .../ps/table/memory_sparse_table.h | 50 +- paddle/fluid/distributed/ps/table/table.h | 11 +- paddle/fluid/distributed/ps/wrapper/fleet.cc | 18 + paddle/fluid/distributed/ps/wrapper/fleet.h | 2 + .../test/memory_sparse_table_test.cc | 3 - paddle/fluid/distributed/the_one_ps.proto | 4 +- paddle/fluid/pybind/fleet_py.cc | 4 +- 16 files changed, 609 insertions(+), 166 deletions(-) diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index ec6eda07cfb..c9135f919cb 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -482,7 +482,7 @@ std::future BrpcPsClient::GetCacheThreshold(uint32_t table_id, request_call_num, [request_call_num, cmd_id, &cache_threshold](void *done) { int ret = 0; - auto *closure = (DownpourBrpcClosure *)done; + auto *closure = reinterpret_cast(done); std::vector cache_thresholds(request_call_num, 0); for (size_t i = 0; i < request_call_num; ++i) { if (closure->check_response(i, cmd_id) != 0) { @@ -530,6 +530,14 @@ std::future BrpcPsClient::Clear(uint32_t table_id) { return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {}); } +std::future BrpcPsClient::Revert() { + return SendCmd(-1, PS_REVERT, {}); +} + +std::future BrpcPsClient::CheckSavePrePatchDone() { + return SendCmd(-1, PS_CHECK_SAVE_PRE_PATCH_DONE, {}); +} + std::future BrpcPsClient::Flush() { VLOG(0) << "BrpcPsClient::flush begin"; _flushing = true; @@ -1170,6 +1178,7 @@ std::future BrpcPsClient::PullSparseParam(float **select_values, } closure->set_promise_value(ret); }); + closure->add_timer(timer); auto promise = std::make_shared>(); closure->add_promise(promise); diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h index f0735f17610..3b455a44dc0 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -178,6 +178,9 @@ class BrpcPsClient : public PSClient { std::future Clear(uint32_t table_id) override; + std::future Revert() override; + std::future CheckSavePrePatchDone() override; + std::future StopServer() override; std::future StartProfiler() override; @@ -298,16 +301,16 @@ class BrpcPsClient : public PSClient { int PushSparseAsyncShardMerge( std::vector> &task_list, // NOLINT - std::vector &request_kv_num, + std::vector &request_kv_num, // NOLINT int table_id, - int shard_idx, // NOLINT + int shard_idx, ValueAccessor *accessor); int PushSparseAsyncShardPush( std::vector> &task_list, // NOLINT - std::vector &request_kv_num, + std::vector &request_kv_num, // NOLINT int table_id, - int shard_idx, // NOLINT + int shard_idx, DownpourBrpcClosure *closure, ValueAccessor 
*accessor); diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index 7e341d5f378..c965496c68b 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -146,7 +146,7 @@ std::future BrpcPsServer::SendPServer2PServerMsg( return fut; } auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) { - auto *closure = (DownpourPServerBrpcClosure *)done; + auto *closure = reinterpret_cast(done); int32_t ret = closure->check_response(0, msg_type + 1000); closure->set_promise_value(ret); }); @@ -209,13 +209,16 @@ int32_t BrpcPsService::Initialize() { _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler; _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep; // for save cache - _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] = &BrpcPsService::SaveCacheTable; _service_handler_map[PS_GET_CACHE_THRESHOLD] = &BrpcPsService::GetCacheThreshold; _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle; + _service_handler_map[PS_REVERT] = &BrpcPsService::Revert; + _service_handler_map[PS_CHECK_SAVE_PRE_PATCH_DONE] = + &BrpcPsService::CheckSavePrePatchDone; + auto &profiler = CostProfiler::instance(); profiler.register_profiler("pserver_server_pull_dense"); profiler.register_profiler("pserver_server_push_dense"); @@ -319,9 +322,8 @@ int32_t BrpcPsService::PullDense(Table *table, table_context.pull_context.values = res_data->data(); table_context.num = num; table->Pull(table_context); - // table->PullDense(res_data->data(), num); - cntl->response_attachment().append((char *)(res_data->data()), + cntl->response_attachment().append(reinterpret_cast(res_data->data()), res_data->size() * sizeof(float)); butil::return_object(res_data); @@ -356,7 +358,6 @@ int32_t BrpcPsService::PushDenseParam(Table *table, table_context.push_context.is_param = true; table_context.num = num; - // if (table->PushDenseParam(values, num) != 0) { if (table->Push(table_context) != 0) { set_response_code(response, -1, "PushDenseParam failed"); } @@ -438,7 +439,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table, "least 1 for num of sparse_key"); return 0; } - uint32_t num = *(uint32_t *)(request.params(0).c_str()); + const uint32_t num = + *(reinterpret_cast(request.params(0).c_str())); /* Push Content: |---keysData---|---valuesData---| @@ -484,10 +486,11 @@ int32_t BrpcPsService::PullGeoParam(Table *table, // table->PullGeoParam(trainer_id, &values, &ids); uint32_t num = ids.size(); - cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); - cntl->response_attachment().append((char *)ids.data(), + cntl->response_attachment().append(reinterpret_cast(&num), + sizeof(uint32_t)); + cntl->response_attachment().append(reinterpret_cast(ids.data()), ids.size() * sizeof(uint64_t)); - cntl->response_attachment().append((char *)values.data(), + cntl->response_attachment().append(reinterpret_cast(values.data()), values.size() * sizeof(float)); return 0; } @@ -517,7 +520,8 @@ int32_t BrpcPsService::PullSparse(Table *table, } CostTimer timer("pserver_server_pull_sparse"); - uint32_t num = *(uint32_t *)(request.params(0).c_str()); + const uint32_t num = + *(reinterpret_cast(request.params(0).c_str())); auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim; thread_local std::string req_buffer; @@ -539,7 +543,7 @@ int32_t BrpcPsService::PullSparse(Table *table, table->Pull(table_context); // 
table->PullSparse(res_data->data(), value); - cntl->response_attachment().append((char *)(res_data->data()), + cntl->response_attachment().append(reinterpret_cast(res_data->data()), res_data->size() * sizeof(float)); butil::return_object(res_data); return 0; @@ -565,7 +569,8 @@ int32_t BrpcPsService::PushSparse(Table *table, return 0; } CostTimer timer("pserver_server_push_sparse"); - uint32_t num = *(uint32_t *)(request.params(0).c_str()); + const uint32_t num = + *(reinterpret_cast(request.params(0).c_str())); /* Push Content: |---keysData---|---valuesData---| @@ -767,6 +772,29 @@ int32_t BrpcPsService::GetCacheThreshold(Table *table, return 0; } +int32_t BrpcPsService::Revert(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); + for (auto &itr : table_map) { + itr.second->Flush(); + itr.second->Revert(); + } + return 0; +} + +int32_t BrpcPsService::CheckSavePrePatchDone(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + auto &table_map = *(_server->GetTable()); + for (auto &itr : table_map) { + itr.second->CheckSavePrePatchDone(); + } + return 0; +} + int32_t BrpcPsService::ShrinkTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.h b/paddle/fluid/distributed/ps/service/brpc_ps_server.h index a142afa5eff..0343b3f8c58 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.h @@ -53,12 +53,12 @@ class BrpcPsServer : public PSServer { } int32_t Port(); - virtual int32_t StartS2S() override; - virtual ::std::future SendPServer2PServerMsg( + int32_t StartS2S() override; + ::std::future SendPServer2PServerMsg( int msg_type, int to_pserver_id, const std::string &msg) override; - virtual int32_t ReceiveFromPServer(int msg_type, - int pserver_id, - const std::string &msg) override; + int32_t ReceiveFromPServer(int msg_type, + int pserver_id, + const std::string &msg) override; private: virtual int32_t Initialize(); @@ -75,118 +75,128 @@ class BrpcPsService; typedef int32_t (BrpcPsService::*serviceHandlerFunc)( Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); class BrpcPsService : public PsBaseService { public: - virtual int32_t Initialize() override; + int32_t Initialize() override; - virtual void service(::google::protobuf::RpcController *controller, - const PsRequestMessage *request, - PsResponseMessage *response, - ::google::protobuf::Closure *done) override; + void service(::google::protobuf::RpcController *controller, + const PsRequestMessage *request, + PsResponseMessage *response, + ::google::protobuf::Closure *done) override; private: int32_t InitializeShardInfo(); int32_t PullDense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PushDense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PushDenseParam(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PushSparseParam(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // 
NOLINT brpc::Controller *cntl); int32_t PullSparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PullGeoParam(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t Barrier(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PushSparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t LoadOneTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t LoadAllTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t SaveOneTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t SaveAllTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t ShrinkTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t ClearOneTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t ClearAllTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t StopServer(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t StartProfiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t StopProfiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PrintTableStat(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t PushGlobalStep(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t CacheShuffle(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t SaveCacheTable(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); int32_t GetCacheThreshold(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, + PsResponseMessage &response, // NOLINT brpc::Controller *cntl); + int32_t Revert(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, // NOLINT + brpc::Controller *cntl); + + int32_t CheckSavePrePatchDone(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, // NOLINT + brpc::Controller *cntl); + bool _is_initialize_shard_info; std::mutex _initialize_shard_mutex; std::unordered_map 
_service_handler_map; @@ -208,7 +218,7 @@ class DownpourPServerBrpcClosure : public PServerClosure { } virtual ~DownpourPServerBrpcClosure() {} - virtual void Run() override { + void Run() override { if (_waiting_num.fetch_sub(1) == 1) { _callback(this); delete this; diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 01bf29b4291..b9a6aa0390f 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -67,12 +67,12 @@ class PSClient { PSClient(PSClient &&) = delete; PSClient(const PSClient &) = delete; - virtual int32_t Configure( // NOLINT + virtual int32_t Configure( const PSParameter &config, const std::map> ®ions, - PSEnvironment &_env, - size_t client_id) final; // NOLINT + PSEnvironment &_env, // NOLINT + size_t client_id) final; virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms, int pserver_connect_timeout_ms, @@ -293,8 +293,25 @@ class PSClient { return fut; } - virtual std::future GetCacheThreshold(uint32_t table_id, - double &cache_threshold) { + virtual std::future GetCacheThreshold( + uint32_t table_id, + double &cache_threshold) { // NOLINT + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + + virtual std::future Revert() { + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + + virtual std::future CheckSavePrePatchDone() { VLOG(0) << "Did not implement"; std::promise promise; std::future fut = promise.get_future(); diff --git a/paddle/fluid/distributed/ps/service/sendrecv.proto b/paddle/fluid/distributed/ps/service/sendrecv.proto index ae6364dd837..57919b6a706 100755 --- a/paddle/fluid/distributed/ps/service/sendrecv.proto +++ b/paddle/fluid/distributed/ps/service/sendrecv.proto @@ -65,6 +65,8 @@ enum PsCmdID { PS_SAVE_WITH_SHARD = 44; PS_QUERY_WITH_SCOPE = 45; PS_QUERY_WITH_SHARD = 46; + PS_REVERT = 47; + PS_CHECK_SAVE_PRE_PATCH_DONE = 48; // pserver2pserver cmd start from 100 PS_S2S_MSG = 101; } diff --git a/paddle/fluid/distributed/ps/table/depends/dense.h b/paddle/fluid/distributed/ps/table/depends/dense.h index 0780103d5d9..d98a91750f4 100644 --- a/paddle/fluid/distributed/ps/table/depends/dense.h +++ b/paddle/fluid/distributed/ps/table/depends/dense.h @@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer { } float* summary_decay_rate; - double summary_decay_rate_d = 0.999999; + double summary_decay_rate_d = 0.9999999; float* param; }; diff --git a/paddle/fluid/distributed/ps/table/memory_dense_table.cc b/paddle/fluid/distributed/ps/table/memory_dense_table.cc index 2560ac91510..9bad2113d17 100644 --- a/paddle/fluid/distributed/ps/table/memory_dense_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_dense_table.cc @@ -339,34 +339,37 @@ int32_t MemoryDenseTable::Save(const std::string& path, _value_accesor->Converter(save_param).deconverter; bool is_write_failed = false; - std::vector> result_buffer_param( - param_dim_, std::vector()); - std::vector result_buffer_fixed_len; - result_buffer_fixed_len.reserve(fixed_len_params_dim_); - + std::vector result_buffer_param; + result_buffer_param.reserve(param_dim_); auto common = _config.common(); int size = static_cast(common.params().size()); if (_config.common().name() == "summary") { for (int x = 0; x < param_dim_; ++x) { - result_buffer_param[x].emplace_back( - 
std::to_string(values_[param_idx_][x])); + result_buffer_param.emplace_back(std::to_string(values_[param_idx_][x])); + } + } else if (_config.common().name() == "adam_d2sum") { + std::ostringstream os; + for (int y = 0; y < param_dim_; ++y) { + os.clear(); + os.str(""); + os << values_[param_col_ids_[0]][y] << " 0"; + for (int x = 2; x < param_col_ids_.size(); ++x) { + os << " "; + os << values_[param_col_ids_[x]][y]; + } + result_buffer_param.emplace_back(std::move(os.str())); } - } else { std::ostringstream os; - for (int x = 0; x < size; ++x) { - int dim = common.dims()[x]; - VLOG(3) << "MemoryDenseTable::save dim " << x << " size: " << dim; - for (int y = 0; y < dim; ++y) { - os.clear(); - os.str(""); - os << values_[x][y]; - if (dim == param_dim_) { - result_buffer_param[y].emplace_back(std::move(os.str())); - } else { - result_buffer_fixed_len.emplace_back(std::move(os.str())); - } + for (int y = 0; y < param_dim_; ++y) { + os.clear(); + os.str(""); + os << values_[param_col_ids_[0]][y]; + for (int x = 1; x < param_col_ids_.size(); ++x) { + os << " "; + os << values_[param_col_ids_[x]][y]; } + result_buffer_param.emplace_back(std::move(os.str())); } } @@ -379,12 +382,9 @@ int32_t MemoryDenseTable::Save(const std::string& path, // 40M auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); + for (auto& t : result_buffer_param) { - if (_config.common().name() == "adam_d2sum") { - t.insert(t.begin() + 1, "0"); // avg_w - } - if (0 != - write_channel->write_line(paddle::string::join_strings(t, ' '))) { + if (0 != write_channel->write_line(t)) { ++retry_num; is_write_failed = true; LOG(ERROR) << "DownpourDenseTable save failed, retry it! " @@ -395,6 +395,7 @@ int32_t MemoryDenseTable::Save(const std::string& path, } ++feasign_size; + VLOG(3) << "save begin close " << channel_config.path; write_channel->close(); if (err_no == -1) { ++retry_num; diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index 115f8bcf58e..f53954dce7c 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -12,15 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h" - #include - #include #include "glog/logging.h" #include "paddle/fluid/distributed/common/cost_timer.h" +#include "paddle/fluid/distributed/common/local_random.h" +#include "paddle/fluid/distributed/common/topk_calculator.h" +#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h" +#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/io/fs.h" + +// #include "boost/lexical_cast.hpp" #include "paddle/fluid/platform/enforce.h" DEFINE_bool(pserver_print_missed_key_num_every_push, @@ -68,6 +71,30 @@ int32_t MemorySparseTable::InitializeValue() { _local_shards.reset(new shard_type[_real_local_shard_num]); + if (_config.enable_revert()) { + // calculate merged shard number based on config param; + _shard_merge_rate = _config.has_shard_merge_rate() + ? 
_config.shard_merge_rate() + : _shard_merge_rate; + CHECK((_m_avg_local_shard_num = static_cast( + std::ceil(_avg_local_shard_num * _shard_merge_rate)), + _m_avg_local_shard_num <= _avg_local_shard_num)); + CHECK((_m_real_local_shard_num = static_cast( + std::ceil(_real_local_shard_num * _shard_merge_rate)), + _m_real_local_shard_num <= _real_local_shard_num)); + + uint32_t avg_shard_server_num = + _sparse_table_shard_num / _avg_local_shard_num; + uint32_t last_server_shard_num = + _sparse_table_shard_num - avg_shard_server_num * _avg_local_shard_num; + _m_sparse_table_shard_num = + avg_shard_server_num * _m_avg_local_shard_num + + std::ceil(last_server_shard_num * _shard_merge_rate); + LOG(INFO) << "merged shard info: [" << _m_sparse_table_shard_num << "|" + << _m_avg_local_shard_num << "|" << _m_real_local_shard_num + << "]"; + _local_shards_new.reset(new shard_type[_real_local_shard_num]); + } return 0; } @@ -93,8 +120,16 @@ int32_t MemorySparseTable::Load(const std::string& path, return -1; } + if (load_param == 5) { + return LoadPatch(file_list, load_param); + } + size_t file_start_idx = _shard_idx * _avg_local_shard_num; + if (file_start_idx >= file_list.size()) { + return 0; + } + size_t feature_value_size = _value_accesor->GetAccessorInfo().size / sizeof(float); @@ -161,30 +196,37 @@ int32_t MemorySparseTable::Load(const std::string& path, return 0; } -int32_t MemorySparseTable::LoadLocalFS(const std::string& path, - const std::string& param) { - std::string table_path = TableDir(path); - auto file_list = paddle::framework::localfs_list(table_path); - size_t expect_shard_num = _sparse_table_shard_num; - if (file_list.size() != expect_shard_num) { - LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size() - << " not equal to expect_shard_num:" << expect_shard_num; - return -1; - } - if (file_list.size() == 0) { - LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path; - return -1; +int32_t MemorySparseTable::LoadPatch(const std::vector& file_list, + int load_param) { + if (!_config.enable_revert()) { + LOG(INFO) << "MemorySparseTable should be enabled revert."; + return 0; } + // 聚合分片数据索引 + int start_idx = _shard_idx * _m_avg_local_shard_num; + int end_idx = start_idx + _m_real_local_shard_num; + // 原始分片数据索引 + int o_start_idx = _shard_idx * _avg_local_shard_num; + int o_end_idx = o_start_idx + _real_local_shard_num; - size_t file_start_idx = _shard_idx * _avg_local_shard_num; - + if (start_idx >= file_list.size()) { + return 0; + } size_t feature_value_size = _value_accesor->GetAccessorInfo().size / sizeof(float); + end_idx = + end_idx < _m_sparse_table_shard_num ? end_idx : _m_sparse_table_shard_num; + int thread_num = (end_idx - start_idx) < 15 ? (end_idx - start_idx) : 15; - int thread_num = _real_local_shard_num < 15 ? 
_real_local_shard_num : 15; omp_set_num_threads(thread_num); #pragma omp parallel for schedule(dynamic) - for (int i = 0; i < _real_local_shard_num; ++i) { + for (size_t i = start_idx; i < end_idx; ++i) { + FsChannelConfig channel_config; + channel_config.path = file_list[i]; + channel_config.converter = _value_accesor->Converter(load_param).converter; + channel_config.deconverter = + _value_accesor->Converter(load_param).deconverter; + bool is_read_failed = false; int retry_num = 0; int err_no = 0; @@ -192,31 +234,55 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path, is_read_failed = false; err_no = 0; std::string line_data; - std::ifstream file(file_list[file_start_idx + i]); + auto read_channel = _afs_client.open_r(channel_config, 0, &err_no); char* end = NULL; - auto& shard = _local_shards[i]; + int m_local_shard_id = i % _m_avg_local_shard_num; + std::unordered_set global_shard_idx; + std::string global_shard_idx_str; + for (size_t j = o_start_idx; j < o_end_idx; ++j) { + if ((j % _avg_local_shard_num) % _m_real_local_shard_num == + m_local_shard_id) { + global_shard_idx.insert(j); + global_shard_idx_str.append(std::to_string(j)).append(","); + } + } try { - while (std::getline(file, line_data) && line_data.size() > 1) { + while (read_channel->read_line(line_data) == 0 && + line_data.size() > 1) { uint64_t key = std::strtoul(line_data.data(), &end, 10); + + auto index_iter = + global_shard_idx.find(key % _sparse_table_shard_num); + if (index_iter == global_shard_idx.end()) { + LOG(WARNING) << "MemorySparseTable key:" << key + << " not match shard," + << " file_idx:" << i + << " global_shard_idx:" << global_shard_idx_str + << " shard num:" << _sparse_table_shard_num + << " file:" << channel_config.path; + continue; + } + size_t local_shard_idx = *index_iter % _avg_local_shard_num; + auto& shard = _local_shards[local_shard_idx]; + auto& value = shard[key]; value.resize(feature_value_size); int parse_size = _value_accesor->ParseFromString(++end, value.data()); value.resize(parse_size); } - file.close(); + read_channel->close(); if (err_no == -1) { ++retry_num; is_read_failed = true; LOG(ERROR) << "MemorySparseTable load failed after read, retry it! path:" - << file_list[file_start_idx + i] << " , retry_num=" << retry_num; + << channel_config.path << " , retry_num=" << retry_num; } } catch (...) { ++retry_num; is_read_failed = true; LOG(ERROR) << "MemorySparseTable load failed, retry it! 
path:" - << file_list[file_start_idx + i] - << " , retry_num=" << retry_num; + << channel_config.path << " , retry_num=" << retry_num; } if (retry_num > FLAGS_pserver_table_save_max_retry) { LOG(ERROR) << "MemorySparseTable load failed reach max limit!"; @@ -225,16 +291,44 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path, } while (is_read_failed); } LOG(INFO) << "MemorySparseTable load success, path from " - << file_list[file_start_idx] << " to " - << file_list[file_start_idx + _real_local_shard_num - 1]; + << file_list[start_idx] << " to " << file_list[end_idx - 1]; return 0; } +void MemorySparseTable::Revert() { + for (size_t i = 0; i < _real_local_shard_num; ++i) { + _local_shards_new[i].clear(); + } +} + +void MemorySparseTable::CheckSavePrePatchDone() { + _save_patch_model_thread.join(); +} + int32_t MemorySparseTable::Save(const std::string& dirname, const std::string& param) { + if (_real_local_shard_num == 0) { + _local_show_threshold = -1; + return 0; + } + VLOG(0) << "MemorySparseTable::save dirname: " << dirname; int save_param = atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 + + // patch model + if (save_param == 5) { + _local_shards_patch_model.reset(_local_shards_new.release()); + _local_shards_new.reset(new shard_type[_real_local_shard_num]); + _save_patch_model_thread = std::thread(std::bind( + &MemorySparseTable::SavePatch, this, std::string(dirname), save_param)); + return 0; + } + + // cache model + int64_t tk_size = LocalSize() * _config.sparse_table_cache_rate(); + TopkCalculator tk(_real_local_shard_num, tk_size); + std::string table_path = TableDir(dirname); _afs_client.remove(paddle::string::format_string( "%s/part-%03d-*", table_path.c_str(), _shard_idx)); @@ -274,6 +368,13 @@ int32_t MemorySparseTable::Save(const std::string& dirname, auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); for (auto it = shard.begin(); it != shard.end(); ++it) { + if (_config.enable_sparse_table_cache() && + (save_param == 1 || save_param == 2) && + _value_accesor->Save(it.value().data(), 4)) { + CostTimer timer10("sprase table top push"); + tk.push(i, _value_accesor->GetField(it.value().data(), "show")); + } + if (_value_accesor->Save(it.value().data(), save_param)) { std::string format_value = _value_accesor->ParseToString( it.value().data(), it.value().size()); @@ -310,55 +411,266 @@ int32_t MemorySparseTable::Save(const std::string& dirname, _value_accesor->UpdateStatAfterSave(it.value().data(), save_param); } LOG(INFO) << "MemorySparseTable save prefix success, path: " - << channel_config.path; + << channel_config.path << " feasign_size: " << feasign_size; } + _local_show_threshold = tk.top(); // int32 may overflow need to change return value return 0; } -int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname, - const std::string& param, - const std::string& prefix) { - int save_param = - atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 - std::string table_path = TableDir(dirname); - int feasign_cnt = 0; - size_t file_start_idx = _avg_local_shard_num * _shard_idx; +int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) { + if (!_config.enable_revert()) { + LOG(INFO) << "MemorySparseTable should be enabled revert."; + return 0; + } + size_t file_start_idx = _m_avg_local_shard_num * _shard_idx; + std::string table_path = TableDir(path); + _afs_client.remove(paddle::string::format_string( + "%s/part-%03d-*", table_path.c_str(), _shard_idx)); + int thread_num = 
_m_real_local_shard_num < 20 ? _m_real_local_shard_num : 20; - int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; std::atomic feasign_size_all{0}; omp_set_num_threads(thread_num); #pragma omp parallel for schedule(dynamic) - for (int i = 0; i < _real_local_shard_num; ++i) { - feasign_cnt = 0; - auto& shard = _local_shards[i]; - std::string file_name = - paddle::string::format_string("%s/part-%s-%03d-%05d", - table_path.c_str(), - prefix.c_str(), - _shard_idx, - file_start_idx + i); - std::ofstream os; - os.open(file_name); - for (auto it = shard.begin(); it != shard.end(); ++it) { - if (_value_accesor->Save(it.value().data(), save_param)) { - std::string format_value = - _value_accesor->ParseToString(it.value().data(), it.value().size()); - std::string out_line = paddle::string::format_string( - "%lu %s\n", it.key(), format_value.c_str()); - // VLOG(2) << out_line.c_str(); - os.write(out_line.c_str(), sizeof(char) * out_line.size()); - ++feasign_cnt; + for (size_t i = 0; i < _m_real_local_shard_num; ++i) { + FsChannelConfig channel_config; + channel_config.path = paddle::string::format_string("%s/part-%03d-%05d", + table_path.c_str(), + _shard_idx, + file_start_idx + i); + + channel_config.converter = _value_accesor->Converter(save_param).converter; + channel_config.deconverter = + _value_accesor->Converter(save_param).deconverter; + + bool is_write_failed = false; + int feasign_size = 0; + int retry_num = 0; + int err_no = 0; + do { + err_no = 0; + feasign_size = 0; + is_write_failed = false; + auto write_channel = + _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); + + for (size_t j = 0; j < _real_local_shard_num; ++j) { + if (j % _m_real_local_shard_num == i) { + auto& shard = _local_shards_patch_model[j]; + for (auto it = shard.begin(); it != shard.end(); ++it) { + if (_value_accesor->Save(it.value().data(), save_param)) { + std::string format_value = _value_accesor->ParseToString( + it.value().data(), it.value().size()); + if (0 != write_channel->write_line(paddle::string::format_string( + "%lu %s", it.key(), format_value.c_str()))) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) << "MemorySparseTable save failed, retry it! path:" + << channel_config.path + << " , retry_num=" << retry_num; + break; + } + ++feasign_size; + } + } + } + if (is_write_failed) break; } + write_channel->close(); + if (err_no == -1) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) + << "MemorySparseTable save patch failed after write, retry it! 
" + << "path:" << channel_config.path << " , retry_num=" << retry_num; + } + if (is_write_failed) { + _afs_client.remove(channel_config.path); + } + if (retry_num > FLAGS_pserver_table_save_max_retry) { + LOG(ERROR) << "MemorySparseTable save patch failed reach max limit!"; + exit(-1); + } + } while (is_write_failed); + feasign_size_all += feasign_size; + } + LOG(INFO) << "MemorySparseTable save patch success, path:" + << paddle::string::format_string("%s/%03d/part-%03d-", + path.c_str(), + _config.table_id(), + _shard_idx) + << " from " << file_start_idx << " to " + << file_start_idx + _m_real_local_shard_num - 1 + << ", feasign size: " << feasign_size_all; + return 0; +} + +int64_t MemorySparseTable::CacheShuffle( + const std::string& path, + const std::string& param, + double cache_threshold, + std::function( + int msg_type, int to_pserver_id, std::string& msg)> send_msg_func, + paddle::framework::Channel>& + shuffled_channel, + const std::vector& table_ptrs) { + LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold; + int save_param = atoi(param.c_str()); // batch_model:0 xbox:1 + if (!_config.enable_sparse_table_cache() || cache_threshold < 0) { + LOG(WARNING) + << "cache shuffle failed not enable table cache or cache threshold < 0 " + << _config.enable_sparse_table_cache() << " or " << cache_threshold; + // return -1; + } + int shuffle_node_num = _config.sparse_table_cache_file_num(); + LOG(INFO) << "Table>> shuffle node num is: " << shuffle_node_num; + // TODO(zhaocaibei123): check shuffle_node_num <= server_node_num + size_t file_start_idx = _avg_local_shard_num * _shard_idx; + int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; + + std::vector< + paddle::framework::ChannelWriter>> + writers(_real_local_shard_num); + std::vector>> datas( + _real_local_shard_num); + + int feasign_size = 0; + std::vector>> + tmp_channels; + for (size_t i = 0; i < _real_local_shard_num; ++i) { + tmp_channels.push_back( + paddle::framework::MakeChannel>()); + } + + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + paddle::framework::ChannelWriter>& writer = + writers[i]; + writer.Reset(tmp_channels[i].get()); + + for (size_t idx = 0; idx < table_ptrs.size(); idx++) { + Table* table_ptr = table_ptrs[idx]; + auto value_accesor = table_ptr->ValueAccesor(); + shard_type* shard_ptr = static_cast(table_ptr->GetShard(i)); + + for (auto it = shard_ptr->begin(); it != shard_ptr->end(); ++it) { + if (value_accesor->SaveCache( + it.value().data(), save_param, cache_threshold)) { + std::string format_value = value_accesor->ParseToString( + it.value().data(), it.value().size()); + std::pair pkv(it.key(), format_value.c_str()); + writer << pkv; + ++feasign_size; + } + } + } + writer.Flush(); + writer.channel()->Close(); + } + // LOG(INFO) << "MemorySparseTable cache KV save success to Channel feasigh + // size: " << feasign_size << " and start sparse cache data shuffle real local + // shard num: " << _real_local_shard_num; + std::vector> local_datas; + for (size_t idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) { + paddle::framework::ChannelWriter>& writer = + writers[idx_shard]; + auto channel = writer.channel(); + std::vector>& data = datas[idx_shard]; + std::vector ars(shuffle_node_num); + while (channel->Read(data)) { + for (auto& t : data) { + auto pserver_id = + paddle::distributed::local_random_engine()() % shuffle_node_num; + if (pserver_id != _shard_idx) { + 
ars[pserver_id] << t; + } else { + local_datas.emplace_back(std::move(t)); + } + } + std::vector> total_status; + std::vector send_data_size(shuffle_node_num, 0); + std::vector send_index(shuffle_node_num); + for (int i = 0; i < shuffle_node_num; ++i) { + send_index[i] = i; + } + std::random_shuffle(send_index.begin(), send_index.end()); + for (auto index = 0u; index < shuffle_node_num; ++index) { + int i = send_index[index]; + if (i == _shard_idx) { + continue; + } + if (ars[i].Length() == 0) { + continue; + } + std::string msg(ars[i].Buffer(), ars[i].Length()); + auto ret = send_msg_func(101, i, msg); + total_status.push_back(std::move(ret)); + send_data_size[i] += ars[i].Length(); + } + for (auto& t : total_status) { + t.wait(); + } + ars.clear(); + ars = std::vector(shuffle_node_num); + data = std::vector>(); } - os.close(); - LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name - << "feasign_cnt: " << feasign_cnt; } + shuffled_channel->Write(std::move(local_datas)); return 0; } +int32_t MemorySparseTable::SaveCache( + const std::string& path, + const std::string& param, + paddle::framework::Channel>& + shuffled_channel) { + if (_shard_idx >= _config.sparse_table_cache_file_num()) { + return 0; + } + int save_param = atoi(param.c_str()); // batch_model:0 xbox:1 + size_t file_start_idx = _avg_local_shard_num * _shard_idx; + std::string table_path = paddle::string::format_string( + "%s/%03d_cache/", path.c_str(), _config.table_id()); + _afs_client.remove(paddle::string::format_string( + "%s/part-%03d", table_path.c_str(), _shard_idx)); + uint32_t feasign_size = 0; + FsChannelConfig channel_config; + // not compress cache model + channel_config.path = paddle::string::format_string( + "%s/part-%03d", table_path.c_str(), _shard_idx); + channel_config.converter = _value_accesor->Converter(save_param).converter; + channel_config.deconverter = + _value_accesor->Converter(save_param).deconverter; + auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40); + std::vector> data; + bool is_write_failed = false; + shuffled_channel->Close(); + while (shuffled_channel->Read(data)) { + for (auto& t : data) { + ++feasign_size; + if (0 != write_channel->write_line(paddle::string::format_string( + "%lu %s", t.first, t.second.c_str()))) { + LOG(ERROR) << "Cache Table save failed, " + "path:" + << channel_config.path << ", retry it!"; + is_write_failed = true; + break; + } + } + data = std::vector>(); + } + if (is_write_failed) { + _afs_client.remove(channel_config.path); + } + write_channel->close(); + LOG(INFO) << "MemorySparseTable cache save success, feasign: " << feasign_size + << ", path: " << channel_config.path; + shuffled_channel->Open(); + return feasign_size; +} + int64_t MemorySparseTable::LocalSize() { int64_t local_size = 0; for (int i = 0; i < _real_local_shard_num; ++i) { @@ -548,7 +860,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values, ret = itr.value_ptr(); } int pull_data_idx = keys[i].second; - pull_values[pull_data_idx] = (char*)ret; // NOLINT + pull_values[pull_data_idx] = reinterpret_cast(ret); } return 0; }); @@ -589,6 +901,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, &task_keys]() -> int { auto& keys = task_keys[shard_id]; auto& local_shard = _local_shards[shard_id]; + auto& local_shard_new = _local_shards_new[shard_id]; float data_buffer[value_col]; // NOLINT float* data_buffer_ptr = data_buffer; for (size_t i = 0; i < keys.size(); ++i) { @@ -630,6 +943,14 @@ int32_t MemorySparseTable::PushSparse(const 
uint64_t* keys, } memcpy(value_data, data_buffer_ptr, value_size * sizeof(float)); } + if (_config.enable_revert()) { + FixedFeatureValue* feature_value_new = &(local_shard_new[key]); + auto new_size = feature_value.size(); + feature_value_new->resize(new_size); + memcpy(feature_value_new->data(), + value_data, + new_size * sizeof(float)); + } } return 0; }); diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index 1c4732a081c..9d48d530d89 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -65,17 +65,25 @@ class MemorySparseTable : public Table { int32_t InitializeShard() override { return 0; } int32_t InitializeValue(); - virtual int32_t Load(const std::string& path, - const std::string& param) override; - - virtual int32_t Save(const std::string& path, - const std::string& param) override; - - int32_t LoadLocalFS(const std::string& path, const std::string& param); - int32_t SaveLocalFS(const std::string& path, - const std::string& param, - const std::string& prefix); - + int32_t Load(const std::string& path, const std::string& param) override; + + int32_t Save(const std::string& path, const std::string& param) override; + + int32_t SaveCache( + const std::string& path, + const std::string& param, + paddle::framework::Channel>& + shuffled_channel) override; + virtual double GetCacheThreshold() { return _local_show_threshold; } + int64_t CacheShuffle( + const std::string& path, + const std::string& param, + double cache_threshold, + std::function( + int msg_type, int to_pserver_id, std::string& msg)> send_msg_func, + paddle::framework::Channel>& + shuffled_channel, + const std::vector& table_ptrs) override; int64_t LocalSize(); int64_t LocalMFSize(); @@ -89,20 +97,38 @@ class MemorySparseTable : public Table { int32_t PushSparse(const uint64_t* keys, const float** values, size_t num); int32_t Flush() override; - virtual int32_t Shrink(const std::string& param) override; + int32_t Shrink(const std::string& param) override; void Clear() override; void* GetShard(size_t shard_idx) override { return &_local_shards[shard_idx]; } + virtual void Revert(); + virtual void CheckSavePrePatchDone(); + protected: + virtual int32_t SavePatch(const std::string& path, int save_param); + virtual int32_t LoadPatch(const std::vector& file_list, + int save_param); + const int _task_pool_size = 24; int _avg_local_shard_num; int _real_local_shard_num; int _sparse_table_shard_num; std::vector> _shards_task_pool; std::unique_ptr _local_shards; + + // for patch model + int _m_avg_local_shard_num; + int _m_real_local_shard_num; + int _m_sparse_table_shard_num; + float _shard_merge_rate{1.0f}; + double _local_show_threshold{0.0}; + + std::unique_ptr _local_shards_new; + std::unique_ptr _local_shards_patch_model; + std::thread _save_patch_model_thread; }; } // namespace distributed diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index 2f6717df4e2..aee707712f6 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -71,8 +71,8 @@ class Table { virtual int32_t Initialize(const TableParameter &config, const FsClientParameter &fs_config); - virtual int32_t Pull(TableContext &context) = 0; - virtual int32_t Push(TableContext &context) = 0; + virtual int32_t Pull(TableContext &context) = 0; // NOLINT + virtual int32_t Push(TableContext &context) = 0; // NOLINT // only 
for barrier virtual int32_t Barrier(const uint32_t trainer_id, @@ -125,7 +125,8 @@ class Table { const std::string ¶m, double cache_threshold, std::function( - int msg_type, int to_pserver_id, std::string &msg)> send_msg_func, + int msg_type, int to_pserver_id, std::string &msg)> // NOLINT + send_msg_func, paddle::framework::Channel> &shuffled_channel, const std::vector &table_ptrs) { @@ -147,6 +148,10 @@ class Table { virtual void *GetShard(size_t shard_idx) = 0; virtual std::pair PrintTableStat() { return {0, 0}; } + // for patch model + virtual void Revert() {} + virtual void CheckSavePrePatchDone() {} + protected: virtual int32_t Initialize() = 0; virtual int32_t InitializeAccessor(); diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index 57ff1d3bcd4..bbefeba5599 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -853,6 +853,24 @@ int32_t FleetWrapper::SaveCache(int table_id, return feasign_cnt; } +void FleetWrapper::Revert() { + auto ret = worker_ptr_->Revert(); + ret.wait(); + if (ret.get() == -1) { + LOG(ERROR) << "table revert failed"; + exit(-1); + } +} + +void FleetWrapper::CheckSavePrePatchDone() { + auto ret = worker_ptr_->CheckSavePrePatchDone(); + ret.wait(); + if (ret.get() == -1) { + LOG(ERROR) << "table revert failed"; + exit(-1); + } +} + std::default_random_engine& FleetWrapper::LocalRandomEngine() { struct engine_wrapper_t { std::default_random_engine engine; diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.h b/paddle/fluid/distributed/ps/wrapper/fleet.h index 2e8bb704bfe..3ff6cfaf8e4 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.h +++ b/paddle/fluid/distributed/ps/wrapper/fleet.h @@ -300,6 +300,8 @@ class FleetWrapper { const int mode, const double cache_threshold); int32_t SaveCache(int table_id, const std::string& path, const int mode); + void Revert(); + void CheckSavePrePatchDone(); static std::shared_ptr pserver_ptr_; static std::shared_ptr worker_ptr_; diff --git a/paddle/fluid/distributed/test/memory_sparse_table_test.cc b/paddle/fluid/distributed/test/memory_sparse_table_test.cc index da311a7691f..391d387b76c 100644 --- a/paddle/fluid/distributed/test/memory_sparse_table_test.cc +++ b/paddle/fluid/distributed/test/memory_sparse_table_test.cc @@ -150,9 +150,6 @@ TEST(MemorySparseTable, SGD) { VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j]; } } - - MemorySparseTable *ctr_table = dynamic_cast(table); - ctr_table->SaveLocalFS("./work/table.save", "0", "test"); } } // namespace distributed diff --git a/paddle/fluid/distributed/the_one_ps.proto b/paddle/fluid/distributed/the_one_ps.proto index a78bc8cddc3..38c9f4d5eb3 100644 --- a/paddle/fluid/distributed/the_one_ps.proto +++ b/paddle/fluid/distributed/the_one_ps.proto @@ -114,12 +114,14 @@ message TableParameter { optional TensorAccessorParameter tensor = 5; optional CommonAccessorParameter common = 6; optional TableType type = 7; - optional bool compress_in_save = 8 [ default = false ]; + optional bool compress_in_save = 8 [ default = true ]; optional GraphParameter graph_parameter = 9; // for cache model optional bool enable_sparse_table_cache = 10 [ default = true ]; optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; + optional bool enable_revert = 13 [ default = true ]; + optional float shard_merge_rate = 14 [ default = 1.0 ]; } message TableAccessorParameter { diff --git 
a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc
index 65948a645f6..03de3520959 100644
--- a/paddle/fluid/pybind/fleet_py.cc
+++ b/paddle/fluid/pybind/fleet_py.cc
@@ -75,7 +75,9 @@ void BindDistFleetWrapper(py::module* m) {
       .def("client_flush", &FleetWrapper::ClientFlush)
       .def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
       .def("cache_shuffle", &FleetWrapper::CacheShuffle)
-      .def("save_cache", &FleetWrapper::SaveCache);
+      .def("save_cache", &FleetWrapper::SaveCache)
+      .def("revert", &FleetWrapper::Revert)
+      .def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
 }
 
 void BindPSHost(py::module* m) {
-- 
GitLab
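
Editor's usage sketch (not part of the patch): the snippet below illustrates how trainer-side code might drive the two FleetWrapper entry points this patch adds, Revert() and CheckSavePrePatchDone(). Obtaining the wrapper through FleetWrapper::GetInstance() and the surrounding function are assumptions for illustration only; the semantics in the comments are taken from the server-side handlers in this patch (BrpcPsService::Revert flushes and reverts every table, MemorySparseTable::CheckSavePrePatchDone joins the asynchronous patch-save thread started by Save() when save_param == 5).

#include "paddle/fluid/distributed/ps/wrapper/fleet.h"

// Hypothetical caller; only Revert()/CheckSavePrePatchDone() come from this patch.
void RollbackAfterBadPass() {
  auto fleet = paddle::distributed::FleetWrapper::GetInstance();  // assumed accessor

  // Ask every pserver table to Flush() and then Revert(), discarding the
  // per-shard "_local_shards_new" deltas accumulated since the last patch save.
  fleet->Revert();

  // Block until any asynchronous patch-model save (triggered by saving with
  // mode 5) has finished; the sparse table joins its save thread here.
  fleet->CheckSavePrePatchDone();
}

Both wrapper methods log an error and exit(-1) if the underlying RPC future resolves to -1, mirroring FleetWrapper::Revert()/CheckSavePrePatchDone() as implemented in fleet.cc above.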