diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index f8a841fecbc0aaf89baf773e1a12b9b47c577024..c8ef4ad16ea9dcb728ad8ff4c7362b832551eee1 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -357,10 +357,8 @@ if (WITH_PSCORE) include(external/libmct) # download, build, install libmct list(APPEND third_party_deps extern_libmct) - if (WITH_HETERPS) - include(external/rocksdb) # download, build, install libmct - list(APPEND third_party_deps extern_rocksdb) - endif() + include(external/rocksdb) # download, build, install libmct + list(APPEND third_party_deps extern_rocksdb) endif() if(WITH_XBYAK) diff --git a/paddle/fluid/distributed/common/topk_calculator.h b/paddle/fluid/distributed/common/topk_calculator.h new file mode 100644 index 0000000000000000000000000000000000000000..326f0f718e9bd3145e4a61a015712d6add8d8eff --- /dev/null +++ b/paddle/fluid/distributed/common/topk_calculator.h @@ -0,0 +1,70 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +namespace paddle { +namespace distributed { +class TopkCalculator { + public: + TopkCalculator(int shard_num, size_t k) + : _shard_num(shard_num), _total_max_size(k) { + _shard_max_size = _total_max_size / shard_num; + _shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1; + for (int i = 0; i < shard_num; ++i) { + _mpq.emplace(i, std::priority_queue, + std::greater>()); + } + } + ~TopkCalculator() {} + bool push(int shard_id, double value) { + if (_mpq.find(shard_id) == _mpq.end()) { + return false; + } + auto &pq = _mpq[shard_id]; + if (pq.size() < _shard_max_size) { + pq.push(value); + } else { + if (pq.top() < value) { + pq.pop(); + pq.push(value); + } + } + return true; + } + // TODO 再进行一次堆排序merge各个shard的结果 + int top() { + double total = 0; + for (const auto &item : _mpq) { + auto &pq = item.second; + if (!pq.empty()) { + total += pq.top(); + } + } + return total / _shard_num; + } + + private: + std::unordered_map, + std::greater>> + _mpq; + int _shard_num; + size_t _total_max_size; + size_t _shard_max_size; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/service/CMakeLists.txt b/paddle/fluid/distributed/ps/service/CMakeLists.txt index b8de291072a1f5ed1f1672ee9b881edbb7ee8741..f0ac7bc6a06359b952881af1200b88ff042367cc 100755 --- a/paddle/fluid/distributed/ps/service/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/service/CMakeLists.txt @@ -1,7 +1,11 @@ set(BRPC_SRCS ps_client.cc server.cc) set_source_files_properties(${BRPC_SRCS}) -set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context) +if(WITH_HETERPS) + set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context rocksdb) +else() + set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context) +endif() brpc_library(sendrecv_rpc SRCS ${BRPC_SRCS} diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index 971c448bf2714b2d99763041dd9216e03161174f..921a110984a4a7b59be8c3d573febe0dc616e67a 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -429,6 +429,82 @@ std::future BrpcPsClient::Save(uint32_t table_id, return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode}); } +std::future BrpcPsClient::CacheShuffle( + uint32_t table_id, const std::string &path, const std::string &mode, + const std::string &cache_threshold) { + VLOG(1) << "BrpcPsClient send cmd for cache shuffle"; + return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, {path, mode, cache_threshold}); +} + +std::future BrpcPsClient::CacheShuffleMultiTable( + std::vector tables, const std::string &path, const std::string &mode, + const std::string &cache_threshold) { + VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path"; + std::vector param; + param.push_back(path); + param.push_back(mode); + param.push_back(cache_threshold); + for (size_t i = 0; i < tables.size(); i++) { + param.push_back(std::to_string(tables[i])); + } + return SendSaveCmd(0, PS_CACHE_SHUFFLE, param); +} + +std::future BrpcPsClient::SaveCache(uint32_t table_id, + const std::string &path, + const std::string &mode) { + return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, {path, mode}); +} + +std::future BrpcPsClient::GetCacheThreshold(uint32_t table_id, + double &cache_threshold) { + int cmd_id = PS_GET_CACHE_THRESHOLD; + size_t request_call_num = _server_channels.size(); + DownpourBrpcClosure *closure = new DownpourBrpcClosure( + request_call_num, + [request_call_num, cmd_id, &cache_threshold](void *done) { + int ret = 0; + auto *closure = (DownpourBrpcClosure *)done; + std::vector cache_thresholds(request_call_num, 0); + for (size_t i = 0; i < request_call_num; ++i) { + if (closure->check_response(i, cmd_id) != 0) { + ret = -1; + break; + } + std::string cur_res = closure->get_response(i, cmd_id); + cache_thresholds[i] = std::stod(cur_res); + } + double sum_threshold = 0.0; + int count = 0; + for (auto t : cache_thresholds) { + if (t >= 0) { + sum_threshold += t; + ++count; + } + } + if (count == 0) { + cache_threshold = 0; + } else { + cache_threshold = sum_threshold / count; + } + VLOG(1) << "client get cache threshold: " << cache_threshold; + closure->set_promise_value(ret); + }); + auto promise = std::make_shared>(); + closure->add_promise(promise); + std::future fut = promise->get_future(); + for (size_t i = 0; i < request_call_num; ++i) { + closure->request(i)->set_cmd_id(cmd_id); + closure->request(i)->set_table_id(table_id); + closure->request(i)->set_client_id(_client_id); + PsService_Stub rpc_stub(GetCmdChannel(i)); + closure->cntl(i)->set_timeout_ms(10800000); + rpc_stub.service(closure->cntl(i), closure->request(i), + closure->response(i), closure); + } + return fut; +} + std::future BrpcPsClient::Clear() { return SendCmd(-1, PS_CLEAR_ALL_TABLE, {}); } diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.h b/paddle/fluid/distributed/ps/service/brpc_ps_client.h index f109b473ca1f455140559037f05b10d2f18d8027..e2c16d496c42c2675500c404cc300523d3c3924e 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.h @@ -219,6 +219,20 @@ class BrpcPsClient : public PSClient { virtual int32_t RecvAndSaveTable(const uint64_t table_id, const std::string &path); + std::future CacheShuffle( + uint32_t table_id, const std::string &path, const std::string &mode, + const std::string &cache_threshold) override; + + std::future CacheShuffleMultiTable( + std::vector tables, const std::string &path, const std::string &mode, + const std::string &cache_threshold); + + std::future SaveCache(uint32_t table_id, const std::string &path, + const std::string &mode) override; + + std::future GetCacheThreshold(uint32_t table_id, + double &cache_threshold) override; + void PrintQueueSize(); void PrintQueueSizeThread(); diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index d22cca91f7816dad88fa2ccc2e6fc2576f5a95a8..d0bf06d49504a430c207e373410af6bbf72f588b 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -28,6 +28,13 @@ class RpcController; } // namespace protobuf } // namespace google +DEFINE_int32(pserver_timeout_ms_s2s, 10000, + "pserver request server timeout_ms"); +DEFINE_int32(pserver_connect_timeout_ms_s2s, 10000, + "pserver connect server timeout_ms"); +DEFINE_string(pserver_connection_type_s2s, "pooled", + "pserver connection_type[pooled:single]"); + namespace paddle { namespace distributed { @@ -93,6 +100,84 @@ uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) { return host.rank; } +int32_t BrpcPsServer::StartS2S() { + brpc::ChannelOptions options; + options.protocol = "baidu_std"; + options.timeout_ms = FLAGS_pserver_timeout_ms_s2s; + options.connection_type = FLAGS_pserver_connection_type_s2s; + options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s; + options.max_retry = 3; + + std::vector pserver_list = _environment->GetPsServers(); + _pserver_channels.resize(pserver_list.size()); + VLOG(2) << "pserver start s2s server_list size: " << _pserver_channels.size(); + + std::ostringstream os; + std::string server_ip_port; + + for (size_t i = 0; i < pserver_list.size(); ++i) { + server_ip_port.assign(pserver_list[i].ip.c_str()); + server_ip_port.append(":"); + server_ip_port.append(std::to_string(pserver_list[i].port)); + _pserver_channels[i].reset(new brpc::Channel()); + if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) { + LOG(ERROR) << "pserver connect to pserver:" << server_ip_port + << " Failed!"; + } + os << server_ip_port << ","; + } + LOG(INFO) << "pserver connect success: " << os.str(); + return 0; +} + +std::future BrpcPsServer::SendPServer2PServerMsg( + int msg_type, int to_pserver_id, const std::string &msg) { + auto promise = std::make_shared>(); + std::future fut = promise->get_future(); + if (to_pserver_id >= _pserver_channels.size()) { + LOG(FATAL) << "to_pserver_id is out of range pservers, which size is " + << _pserver_channels.size(); + promise->set_value(-1); + return fut; + } + auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) { + auto *closure = (DownpourPServerBrpcClosure *)done; + int32_t ret = closure->check_response(0, msg_type + 1000); + closure->set_promise_value(ret); + }); + + closure->add_promise(promise); + closure->request(0)->set_cmd_id(101); + closure->request(0)->set_client_id(_rank); + closure->request(0)->set_table_id(0); + closure->request(0)->set_data(msg); + PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get()); + rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0), + closure); + return fut; +} + +int32_t BrpcPsServer::ReceiveFromPServer(int msg_type, int pserver_id, + const std::string &msg) { + if (msg.length() == 0) { + LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response"; + return 0; + } + paddle::framework::BinaryArchive ar; + ar.SetReadBuffer(const_cast(msg.c_str()), msg.length(), nullptr); + if (ar.Cursor() == ar.Finish()) { + LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response"; + return 0; + } + std::vector> data; + while (ar.Cursor() < ar.Finish()) { + data.push_back(ar.Get>()); + } + CHECK(ar.Cursor() == ar.Finish()); + this->_shuffled_ins->Write(std::move(data)); + return 0; +} + int32_t BrpcPsServer::Port() { return _server.listen_address().port; } int32_t BrpcPsService::Initialize() { @@ -117,6 +202,14 @@ int32_t BrpcPsService::Initialize() { _service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler; _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler; _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep; + // for save cache + + _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] = + &BrpcPsService::SaveCacheTable; + _service_handler_map[PS_GET_CACHE_THRESHOLD] = + &BrpcPsService::GetCacheThreshold; + _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle; + auto &profiler = CostProfiler::instance(); profiler.register_profiler("pserver_server_pull_dense"); profiler.register_profiler("pserver_server_push_dense"); @@ -168,19 +261,29 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, response->set_err_msg(""); auto *table = _server->GetTable(request->table_id()); brpc::Controller *cntl = static_cast(cntl_base); - auto itr = _service_handler_map.find(request->cmd_id()); - if (itr == _service_handler_map.end()) { - std::string err_msg( - "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); - err_msg.append(std::to_string(request->cmd_id())); - set_response_code(*response, -1, err_msg.c_str()); - return; - } - serviceHandlerFunc handler_func = itr->second; - int service_ret = (this->*handler_func)(table, *request, *response, cntl); - if (service_ret != 0) { - response->set_err_code(service_ret); - response->set_err_msg("server internal error"); + + if (request->cmd_id() < 100) { + auto itr = _service_handler_map.find(request->cmd_id()); + if (itr == _service_handler_map.end()) { + std::string err_msg( + "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:"); + err_msg.append(std::to_string(request->cmd_id())); + set_response_code(*response, -1, err_msg.c_str()); + return; + } + serviceHandlerFunc handler_func = itr->second; + int service_ret = (this->*handler_func)(table, *request, *response, cntl); + if (service_ret != 0) { + response->set_err_code(service_ret); + response->set_err_msg("server internal error"); + } + } else { + int service_ret = _server->HandlePServer2PServerMsg( + request->cmd_id(), request->client_id(), request->data()); + if (service_ret != 0) { + response->set_err_code(-1); + response->set_err_msg("handle_pserver2pserver_msg failed"); + } } } @@ -561,6 +664,90 @@ int32_t BrpcPsService::SaveAllTable(Table *table, return 0; } +int32_t BrpcPsService::SaveCacheTable(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 2) { + set_response_code( + response, -1, + "PsRequestMessage.datas is requeired at least 3, path&mode"); + return -1; + } + table->Flush(); + int32_t feasign_size = 0; + // if (_server->_shuffled_ins->size() <= 0) { + // LOG(WARNING) << "shuffled ins size <= 0"; + //} + feasign_size = table->SaveCache(request.params(0), request.params(1), + _server->_shuffled_ins); + if (feasign_size < 0) { + set_response_code(response, -1, "table save failed"); + return -1; + } + return feasign_size; +} + +int32_t BrpcPsService::CacheShuffle(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + // start cache shuffle + CHECK_TABLE_EXIST(table, request, response) + if (request.params_size() < 3) { + set_response_code(response, -1, + "PsRequestMessage.datas is requeired at least 3, " + "path&mode&cache_threshold"); + return -1; + } + table->Flush(); + double cache_threshold = std::stod(request.params(2)); + LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold; + // auto shuffled_ins = paddle::ps::make_channel>(); + // shuffled_ins->set_block_size(80000); + _server->StartS2S(); + std::function(int msg_type, int to_pserver_id, + const std::string &msg)> + send_msg_func = [this](int msg_type, int to_pserver_id, + const std::string &msg) -> std::future { + return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg); + }; + + std::vector table_ptrs; + for (size_t i = 3; i < request.params_size(); ++i) { + int table_id = std::stoi(request.params(i)); + Table *table_ptr = _server->GetTable(table_id); + table_ptrs.push_back(table_ptr); + } + if (table_ptrs.empty()) { + table_ptrs.push_back(table); + } + + table->CacheShuffle(request.params(0), request.params(1), cache_threshold, + send_msg_func, _server->_shuffled_ins, table_ptrs); + return 0; +} + +int32_t BrpcPsService::GetCacheThreshold(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { + CHECK_TABLE_EXIST(table, request, response) + table->Flush(); + double cache_threshold = 0.0; + cache_threshold = table->GetCacheThreshold(); + if (cache_threshold < 0) { + LOG(WARNING) << "wrong threshold: " << cache_threshold; + } + std::stringstream ss; + ss << std::setprecision(15) << cache_threshold; + std::string cache_threshold_str = ss.str(); + response.set_data(cache_threshold_str); + return 0; +} + int32_t BrpcPsService::ShrinkTable(Table *table, const PsRequestMessage &request, PsResponseMessage &response, diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.h b/paddle/fluid/distributed/ps/service/brpc_ps_server.h index 250f465d84253731df3198ca92baca022864974b..40ed652ec6be331acdd08612d7f0319b8c870b79 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.h +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.h @@ -53,6 +53,12 @@ class BrpcPsServer : public PSServer { } int32_t Port(); + virtual int32_t StartS2S() override; + virtual ::std::future SendPServer2PServerMsg( + int msg_type, int to_pserver_id, const std::string &msg) override; + virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id, + const std::string &msg) override; + private: virtual int32_t Initialize(); mutable std::mutex mutex_; @@ -123,6 +129,16 @@ class BrpcPsService : public PsBaseService { int32_t PushGlobalStep(Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); + int32_t CacheShuffle(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + int32_t SaveCacheTable(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, brpc::Controller *cntl); + + int32_t GetCacheThreshold(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl); + bool _is_initialize_shard_info; std::mutex _initialize_shard_mutex; std::unordered_map _service_handler_map; diff --git a/paddle/fluid/distributed/ps/service/ps_client.h b/paddle/fluid/distributed/ps/service/ps_client.h index 6f27b0eb046245c722100bcfdb2e6b89d92ec488..0d3d23be4e8d137d241adfaaadcd985c7d272708 100644 --- a/paddle/fluid/distributed/ps/service/ps_client.h +++ b/paddle/fluid/distributed/ps/service/ps_client.h @@ -198,6 +198,7 @@ class PSClient { _msg_handler_map[msg_type] = handler; return 0; } + virtual int HandleClient2ClientMsg(int msg_type, int from_client_id, const std::string &msg) { auto itr = _msg_handler_map.find(msg_type); @@ -239,6 +240,46 @@ class PSClient { const float **update_values, size_t num) = 0; + // for save cache + virtual std::future CacheShuffle( + uint32_t table_id, const std::string &path, const std::string &mode, + const std::string &cache_threshold) { + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + + virtual std::future CacheShuffleMultiTable( + std::vector tables, const std::string &path, const std::string &mode, + const std::string &cache_threshold) { + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + + virtual std::future SaveCache(uint32_t table_id, + const std::string &path, + const std::string &mode) { + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + + virtual std::future GetCacheThreshold(uint32_t table_id, + double &cache_threshold) { + VLOG(0) << "Did not implement"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + protected: virtual int32_t Initialize() = 0; size_t _client_id; diff --git a/paddle/fluid/distributed/ps/service/sendrecv.proto b/paddle/fluid/distributed/ps/service/sendrecv.proto index 580f411c28c07ce1eb6afa14a7cc49e1052a83ef..46dcc2058f4b8754ef1dc27634ff31782291f324 100755 --- a/paddle/fluid/distributed/ps/service/sendrecv.proto +++ b/paddle/fluid/distributed/ps/service/sendrecv.proto @@ -65,6 +65,8 @@ enum PsCmdID { PS_SAVE_WITH_SHARD = 44; PS_QUERY_WITH_SCOPE = 45; PS_QUERY_WITH_SHARD = 46; + // pserver2pserver cmd start from 100 + PS_S2S_MSG = 101; } message PsRequestMessage { diff --git a/paddle/fluid/distributed/ps/service/server.cc b/paddle/fluid/distributed/ps/service/server.cc index 65f7ae821cef1ace041711cb1bab9794935c6dfa..a6e0f39474b060b879cc6ca25d060e9b66cd6b41 100644 --- a/paddle/fluid/distributed/ps/service/server.cc +++ b/paddle/fluid/distributed/ps/service/server.cc @@ -67,6 +67,8 @@ int32_t PSServer::Configure( _config = config.server_param(); _rank = server_rank; _environment = &env; + _shuffled_ins = + paddle::framework::MakeChannel>(); size_t shard_num = env.GetPsServers().size(); const auto &downpour_param = _config.downpour_server_param(); diff --git a/paddle/fluid/distributed/ps/service/server.h b/paddle/fluid/distributed/ps/service/server.h index 5da819326b05260630c22d73f074d75915130211..c044e82884604b3a7104dfaa94696fb61eb7b049 100644 --- a/paddle/fluid/distributed/ps/service/server.h +++ b/paddle/fluid/distributed/ps/service/server.h @@ -89,6 +89,45 @@ class PSServer { return &_table_map; } + // for cache + virtual int32_t StartS2S() { return 0; } + + virtual ::std::future SendPServer2PServerMsg( + int msg_type, int to_pserver_id, const std::string &msg) { + LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg"; + std::promise promise; + std::future fut = promise.get_future(); + promise.set_value(-1); + return fut; + } + + typedef std::function MsgHandlerFunc; + virtual int RegistePServer2PServerMsgHandler(int msg_type, + MsgHandlerFunc handler) { + _msg_handler_map[msg_type] = handler; + return 0; + } + virtual int HandlePServer2PServerMsg(int msg_type, int from_pserver_id, + const std::string &msg) { + auto itr = _msg_handler_map.find(msg_type); + if (itr == _msg_handler_map.end()) { + if (msg_type == 101) { + return ReceiveFromPServer(msg_type, from_pserver_id, msg); + } else { + LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type; + return -1; + } + } + return itr->second(msg_type, from_pserver_id, msg); + } + virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id, + const std::string &msg) { + LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer"; + return -1; + } + + paddle::framework::Channel> _shuffled_ins; + protected: virtual int32_t Initialize() = 0; @@ -97,6 +136,7 @@ class PSServer { ServerParameter _config; PSEnvironment *_environment; std::unordered_map> _table_map; + std::unordered_map _msg_handler_map; protected: std::shared_ptr scope_; diff --git a/paddle/fluid/distributed/ps/table/CMakeLists.txt b/paddle/fluid/distributed/ps/table/CMakeLists.txt index bb6725b08425a1a4ac93e75fc4d88cb370eec8ba..f2b9eb71f5a640af6375d5092e123e76df63512b 100644 --- a/paddle/fluid/distributed/ps/table/CMakeLists.txt +++ b/paddle/fluid/distributed/ps/table/CMakeLists.txt @@ -18,17 +18,12 @@ include_directories(${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmc set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") -set(EXTERN_DEP "") -if(WITH_HETERPS) - set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc) - set(EXTERN_DEP rocksdb) -else() - set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc) -endif() +set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc) +#set(EXTERN_DEP rocksdb) cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS} ${RPC_DEPS} graph_edge graph_node device_context string_helper -simple_threadpool xxhash generator ${EXTERN_DEP}) +simple_threadpool xxhash generator) set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) @@ -41,13 +36,13 @@ set_source_files_properties(ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DI set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) +set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto) cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule) -cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table) - -set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) -cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table) +cc_library(sparse_table SRCS memory_sparse_table.cc ssd_sparse_table.cc memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table rocksdb) -cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) +cc_library(table SRCS table.cc DEPS sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost) target_link_libraries(table -fopenmp) diff --git a/paddle/fluid/distributed/ps/table/accessor.h b/paddle/fluid/distributed/ps/table/accessor.h index 024af327a33afcae4056ceb77a82b5c1828f7e50..7713c2bda295fad1ab1c5289d3e2dd893b054591 100644 --- a/paddle/fluid/distributed/ps/table/accessor.h +++ b/paddle/fluid/distributed/ps/table/accessor.h @@ -117,6 +117,11 @@ class ValueAccessor { virtual bool Save(float* value, int param) = 0; // update delta_score and unseen_days after save virtual void UpdateStatAfterSave(float* value, int param) {} + // 判断该value是否保存到ssd + virtual bool SaveSSD(float* value) = 0; + // + virtual bool SaveCache(float* value, int param, + double global_cache_threshold) = 0; // keys不存在时,为values生成随机值 virtual int32_t Create(float** value, size_t num) = 0; diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 863c397b08ad2612fa249b1c7c7a35a5c6d7bafd..df0d8b2d3a8abf22095c0aef20ba7fdbd3f90245 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -38,13 +38,13 @@ #include #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/common_table.h" -#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h" #include "paddle/fluid/distributed/ps/table/graph/class_macro.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/utils/rw_lock.h" #ifdef PADDLE_WITH_HETERPS +#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #endif namespace paddle { diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_accessor.cc index 715abe270e52b5793c7c63537b4e2cf237c040ff..ef7311824faa6d6aad7247e6f6a71732cbf6445b 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.cc @@ -34,6 +34,8 @@ int CtrCommonAccessor::Initialize() { common_feature_value.embedx_dim = _config.embedx_dim(); common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim(); _show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate(); + _ssd_unseenday_threshold = + _config.ctr_accessor_param().ssd_unseenday_threshold(); if (_config.ctr_accessor_param().show_scale()) { _show_scale = true; @@ -77,6 +79,25 @@ bool CtrCommonAccessor::Shrink(float* value) { return false; } +bool CtrCommonAccessor::SaveCache(float* value, int param, + double global_cache_threshold) { + auto base_threshold = _config.ctr_accessor_param().base_threshold(); + auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); + if (ShowClickScore(common_feature_value.Show(value), + common_feature_value.Click(value)) >= base_threshold && + common_feature_value.UnseenDays(value) <= delta_keep_days) { + return common_feature_value.Show(value) > global_cache_threshold; + } + return false; +} + +bool CtrCommonAccessor::SaveSSD(float* value) { + if (common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold) { + return true; + } + return false; +} + bool CtrCommonAccessor::Save(float* value, int param) { auto base_threshold = _config.ctr_accessor_param().base_threshold(); auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.h b/paddle/fluid/distributed/ps/table/ctr_accessor.h index a599bfca7f6d290cdec4b9d5e4ac5ac36e2e61bb..327c4cea760ebed7f5abd152fc0bbd7e887e0c97 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.h @@ -148,6 +148,9 @@ class CtrCommonAccessor : public ValueAccessor { // param = 1, save delta feature // param = 2, save xbox base feature bool Save(float* value, int param) override; + bool SaveCache(float* value, int param, + double global_cache_threshold) override; + bool SaveSSD(float* value) override; // update delta_score and unseen_days after save void UpdateStatAfterSave(float* value, int param) override; // keys不存在时,为values生成随机值 diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc index f0d9426343d7bf523a131717a1ee57d3b2fed612..4b84b7e8c36c309ec4fe3f2c65fcea09b85d90e0 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc @@ -74,25 +74,26 @@ bool CtrDoubleAccessor::Shrink(float* value) { } return false; } + bool CtrDoubleAccessor::SaveSSD(float* value) { if (CtrDoubleFeatureValue::UnseenDays(value) > _ssd_unseenday_threshold) { return true; } return false; } -// bool CtrDoubleAccessor::save_cache( -// float* value, int param, double global_cache_threshold) { -// auto base_threshold = _config.ctr_accessor_param().base_threshold(); -// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); -// if (ShowClickScore(CtrDoubleFeatureValue::Show(value), -// CtrDoubleFeatureValue::Click(value)) >= base_threshold -// && CtrDoubleFeatureValue::UnseenDays(value) <= -// delta_keep_days) { -// return CtrDoubleFeatureValue::Show(value) > -// global_cache_threshold; -// } -// return false; -// } + +bool CtrDoubleAccessor::SaveCache(float* value, int param, + double global_cache_threshold) { + auto base_threshold = _config.ctr_accessor_param().base_threshold(); + auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days(); + if (ShowClickScore(CtrDoubleFeatureValue::Show(value), + CtrDoubleFeatureValue::Click(value)) >= base_threshold && + CtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) { + return CtrDoubleFeatureValue::Show(value) > global_cache_threshold; + } + return false; +} + bool CtrDoubleAccessor::Save(float* value, int param) { // auto base_threshold = _config.ctr_accessor_param().base_threshold(); // auto delta_threshold = _config.ctr_accessor_param().delta_threshold(); diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h index c58602065036fd157b8d98c102b638cf575d6fd0..5b781b2621c5bc9c864c41abe0d028ccdbe25052 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.h +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.h @@ -167,6 +167,8 @@ class CtrDoubleAccessor : public ValueAccessor { // param = 1, save delta feature // param = 3, save all feature with time decay virtual bool Save(float* value, int param) override; + bool SaveCache(float* value, int param, + double global_cache_threshold) override; // update delta_score and unseen_days after save virtual void UpdateStatAfterSave(float* value, int param) override; // 判断该value是否保存到ssd diff --git a/paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h b/paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h index ff2271d468e39fad874ea0b73da345906a335b37..223c8fafd26ab7f84bff8c088d00c71ce29bb342 100644 --- a/paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h +++ b/paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h @@ -11,9 +11,8 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - #pragma once -#ifdef PADDLE_WITH_HETERPS + #include #include #include @@ -154,6 +153,5 @@ class RocksDBHandler { std::vector _handles; rocksdb::DB* _db; }; -} -} -#endif +} // distributed +} // paddle diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index e6c52e0b9b0c8e0f0ea675254130d4a6f34a49ea..ee6a801fa91834b0eb8ae795caf3f1b7a579b7ef 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -23,14 +23,17 @@ #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" +DEFINE_bool(pserver_print_missed_key_num_every_push, false, + "pserver_print_missed_key_num_every_push"); +DEFINE_bool(pserver_create_value_when_push, true, + "pserver create value when push"); +DEFINE_bool(pserver_enable_create_feasign_randomly, false, + "pserver_enable_create_feasign_randomly"); +DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry"); + namespace paddle { namespace distributed { -// TODO(zhaocaibei123): configure -bool FLAGS_pserver_create_value_when_push = true; -int FLAGS_pserver_table_save_max_retry = 3; -bool FLAGS_pserver_enable_create_feasign_randomly = false; - int32_t MemorySparseTable::Initialize() { _shards_task_pool.resize(_task_pool_size); for (int i = 0; i < _shards_task_pool.size(); ++i) { @@ -142,7 +145,7 @@ int32_t MemorySparseTable::Load(const std::string& path, LOG(ERROR) << "MemorySparseTable load failed, retry it! path:" << channel_config.path << " , retry_num=" << retry_num; } - if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) { + if (retry_num > FLAGS_pserver_table_save_max_retry) { LOG(ERROR) << "MemorySparseTable load failed reach max limit!"; exit(-1); } @@ -213,7 +216,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path, << file_list[file_start_idx + i] << " , retry_num=" << retry_num; } - if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) { + if (retry_num > FLAGS_pserver_table_save_max_retry) { LOG(ERROR) << "MemorySparseTable load failed reach max limit!"; exit(-1); } @@ -293,7 +296,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname, if (is_write_failed) { _afs_client.remove(channel_config.path); } - if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) { + if (retry_num > FLAGS_pserver_table_save_max_retry) { LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!"; exit(-1); } diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index 87a73bd22fa2f972eb795728fc9ff9722e639bb6..ec86239ffb161f0b7718e46c572ac0e1f447b593 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -62,9 +62,11 @@ class MemorySparseTable : public Table { int32_t InitializeShard() override { return 0; } int32_t InitializeValue(); - int32_t Load(const std::string& path, const std::string& param) override; + virtual int32_t Load(const std::string& path, + const std::string& param) override; - int32_t Save(const std::string& path, const std::string& param) override; + virtual int32_t Save(const std::string& path, + const std::string& param) override; int32_t LoadLocalFS(const std::string& path, const std::string& param); int32_t SaveLocalFS(const std::string& path, const std::string& param, @@ -83,7 +85,7 @@ class MemorySparseTable : public Table { int32_t PushSparse(const uint64_t* keys, const float** values, size_t num); int32_t Flush() override; - int32_t Shrink(const std::string& param) override; + virtual int32_t Shrink(const std::string& param) override; void Clear() override; void* GetShard(size_t shard_idx) override { diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.h b/paddle/fluid/distributed/ps/table/sparse_accessor.h index 5ca5d21707a2b5dd1322c5783692d05dc0c0e080..875904847b2ea113570f2a3268dfd3fdc4bce64b 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.h +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.h @@ -135,6 +135,11 @@ class SparseAccessor : public ValueAccessor { // param = 1, save delta feature // param = 2, save xbox base feature bool Save(float* value, int param) override; + + bool SaveCache(float* value, int param, double global_cache_threshold) { + return false; + } + bool SaveSSD(float* value) { return false; } // update delta_score and unseen_days after save void UpdateStatAfterSave(float* value, int param) override; // keys不存在时,为values生成随机值 diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc new file mode 100644 index 0000000000000000000000000000000000000000..b1359d1323d8972d856953b2f9435556a7879195 --- /dev/null +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc @@ -0,0 +1,759 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h" +#include "paddle/fluid/distributed/common/cost_timer.h" +#include "paddle/fluid/distributed/common/local_random.h" +#include "paddle/fluid/distributed/common/topk_calculator.h" +#include "paddle/fluid/framework/archive.h" +#include "paddle/utils/string/string_helper.h" + +DECLARE_bool(pserver_print_missed_key_num_every_push); +DECLARE_bool(pserver_create_value_when_push); +DECLARE_bool(pserver_enable_create_feasign_randomly); +DEFINE_bool(pserver_open_strict_check, false, "pserver_open_strict_check"); +DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file"); +DEFINE_int32(pserver_load_batch_size, 5000, "load batch size for ssd"); + +namespace paddle { +namespace distributed { + +int32_t SSDSparseTable::Initialize() { + MemorySparseTable::Initialize(); + _db = paddle::distributed::RocksDBHandler::GetInstance(); + _db->initialize(FLAGS_rocksdb_path, _real_local_shard_num); + return 0; +} + +int32_t SSDSparseTable::InitializeShard() { return 0; } + +int32_t SSDSparseTable::PullSparse(float* pull_values, const uint64_t* keys, + size_t num) { + CostTimer timer("pserver_downpour_sparse_select_all"); + size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_size = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); + size_t select_value_size = + _value_accesor->GetAccessorInfo().select_size / sizeof(float); + + { // 从table取值 or create + std::vector> tasks(_real_local_shard_num); + std::vector>> task_keys( + _real_local_shard_num); + for (size_t i = 0; i < num; ++i) { + int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num; + task_keys[shard_id].push_back({keys[i], i}); + } + + std::atomic missed_keys{0}; + for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = + _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue( + [this, shard_id, &task_keys, value_size, mf_value_size, + select_value_size, pull_values, keys, &missed_keys]() -> int { + auto& keys = task_keys[shard_id]; + auto& local_shard = _local_shards[shard_id]; + float data_buffer[value_size]; + float* data_buffer_ptr = data_buffer; + for (int i = 0; i < keys.size(); ++i) { + uint64_t key = keys[i].first; + auto itr = local_shard.find(key); + size_t data_size = value_size - mf_value_size; + if (itr == local_shard.end()) { + // pull rocksdb + std::string tmp_string(""); + if (_db->get(shard_id, (char*)&key, sizeof(uint64_t), + tmp_string) > 0) { + ++missed_keys; + if (FLAGS_pserver_create_value_when_push) { + memset(data_buffer, 0, sizeof(float) * data_size); + } else { + auto& feature_value = local_shard[key]; + feature_value.resize(data_size); + float* data_ptr = + const_cast(feature_value.data()); + _value_accesor->Create(&data_buffer_ptr, 1); + memcpy(data_ptr, data_buffer_ptr, + data_size * sizeof(float)); + } + } else { + data_size = tmp_string.size() / sizeof(float); + memcpy(data_buffer_ptr, + paddle::string::str_to_float(tmp_string), + data_size * sizeof(float)); + // from rocksdb to mem + auto& feature_value = local_shard[key]; + feature_value.resize(data_size); + memcpy(const_cast(feature_value.data()), + data_buffer_ptr, data_size * sizeof(float)); + _db->del_data(shard_id, (char*)&key, sizeof(uint64_t)); + } + } else { + data_size = itr.value().size(); + memcpy(data_buffer_ptr, itr.value().data(), + data_size * sizeof(float)); + } + for (int mf_idx = data_size; mf_idx < value_size; ++mf_idx) { + data_buffer[mf_idx] = 0.0; + } + int pull_data_idx = keys[i].second; + float* select_data = + pull_values + pull_data_idx * select_value_size; + _value_accesor->Select(&select_data, + (const float**)&data_buffer_ptr, 1); + } + return 0; + }); + } + for (size_t i = 0; i < _real_local_shard_num; ++i) { + tasks[i].wait(); + } + if (FLAGS_pserver_print_missed_key_num_every_push) { + LOG(WARNING) << "total pull keys:" << num + << " missed_keys:" << missed_keys.load(); + } + } + return 0; +} + +int32_t SSDSparseTable::PushSparse(const uint64_t* keys, const float* values, + size_t num) { + CostTimer timer("pserver_downpour_sparse_update_all"); + // 构造value push_value的数据指针 + size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_col = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); + size_t update_value_col = + _value_accesor->GetAccessorInfo().update_size / sizeof(float); + { + std::vector> tasks(_real_local_shard_num); + std::vector>> task_keys( + _real_local_shard_num); + for (size_t i = 0; i < num; ++i) { + int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num; + task_keys[shard_id].push_back({keys[i], i}); + } + for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { + tasks[shard_id] = + _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue( + [this, shard_id, value_col, mf_value_col, update_value_col, + values, &task_keys]() -> int { + auto& keys = task_keys[shard_id]; + auto& local_shard = _local_shards[shard_id]; + float data_buffer[value_col]; + float* data_buffer_ptr = data_buffer; + for (int i = 0; i < keys.size(); ++i) { + uint64_t key = keys[i].first; + uint64_t push_data_idx = keys[i].second; + const float* update_data = + values + push_data_idx * update_value_col; + auto itr = local_shard.find(key); + if (itr == local_shard.end()) { + if (FLAGS_pserver_enable_create_feasign_randomly && + !_value_accesor->CreateValue(1, update_data)) { + continue; + } + auto value_size = value_col - mf_value_col; + auto& feature_value = local_shard[key]; + feature_value.resize(value_size); + _value_accesor->Create(&data_buffer_ptr, 1); + memcpy(const_cast(feature_value.data()), + data_buffer_ptr, value_size * sizeof(float)); + itr = local_shard.find(key); + } + auto& feature_value = itr.value(); + float* value_data = const_cast(feature_value.data()); + size_t value_size = feature_value.size(); + + if (value_size == + value_col) { // 已拓展到最大size, 则就地update + _value_accesor->Update(&value_data, &update_data, 1); + } else { // 拷入buffer区进行update,然后再回填,不需要的mf则回填时抛弃了 + memcpy(data_buffer_ptr, value_data, + value_size * sizeof(float)); + _value_accesor->Update(&data_buffer_ptr, &update_data, 1); + if (_value_accesor->NeedExtendMF(data_buffer)) { + feature_value.resize(value_col); + value_data = const_cast(feature_value.data()); + _value_accesor->Create(&value_data, 1); + } + memcpy(value_data, data_buffer_ptr, + value_size * sizeof(float)); + } + } + return 0; + }); + } + for (size_t i = 0; i < _real_local_shard_num; ++i) { + tasks[i].wait(); + } + } + /* + //update && value 的转置 + thread_local Eigen::MatrixXf update_matrix; + float* transposed_update_data[update_value_col]; + make_matrix_with_eigen(num, update_value_col, update_matrix, + transposed_update_data); + copy_array_to_eigen(values, update_matrix); + + thread_local Eigen::MatrixXf value_matrix; + float* transposed_value_data[value_col]; + make_matrix_with_eigen(num, value_col, value_matrix, transposed_value_data); + copy_matrix_to_eigen((const float**)(value_ptrs->data()), value_matrix); + + //批量update + { + CostTimer accessor_timer("pslib_downpour_sparse_update_accessor"); + _value_accesor->update(transposed_value_data, (const + float**)transposed_update_data, num); + } + copy_eigen_to_matrix(value_matrix, value_ptrs->data()); + */ + return 0; +} + +int32_t SSDSparseTable::Shrink(const std::string& param) { + int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + uint64_t mem_count = 0; + uint64_t ssd_count = 0; + + LOG(INFO) << "SSDSparseTable begin shrink shard:" << i; + auto& shard = _local_shards[i]; + for (auto it = shard.begin(); it != shard.end();) { + if (_value_accesor->Shrink(it.value().data())) { + it = shard.erase(it); + mem_count++; + } else { + ++it; + } + } + auto* it = _db->get_iterator(i); + for (it->SeekToFirst(); it->Valid(); it->Next()) { + if (_value_accesor->Shrink( + paddle::string::str_to_float(it->value().data()))) { + _db->del_data(i, it->key().data(), it->key().size()); + ssd_count++; + } else { + _db->put(i, it->key().data(), it->key().size(), it->value().data(), + it->value().size()); + } + } + delete it; + LOG(INFO) << "SSDSparseTable shrink success. shard:" << i << " delete MEM[" + << mem_count << "] SSD[" << ssd_count << "]"; + //_db->flush(i); + } + return 0; +} + +int32_t SSDSparseTable::UpdateTable() { + // TODO implement with multi-thread + int count = 0; + for (size_t i = 0; i < _real_local_shard_num; ++i) { + auto& shard = _local_shards[i]; + // from mem to ssd + for (auto it = shard.begin(); it != shard.end();) { + if (_value_accesor->SaveSSD(it.value().data())) { + _db->put(i, (char*)&it.key(), sizeof(uint64_t), + (char*)it.value().data(), it.value().size() * sizeof(float)); + count++; + it = shard.erase(it); + } else { + ++it; + } + } + _db->flush(i); + } + LOG(INFO) << "Table>> update count: " << count; + return 0; +} + +int64_t SSDSparseTable::LocalSize() { + int64_t local_size = 0; + for (size_t i = 0; i < _real_local_shard_num; ++i) { + local_size += _local_shards[i].size(); + } + // TODO rocksdb size + uint64_t ssd_size = 0; + // _db->get_estimate_key_num(ssd_size); + // return local_size + ssd_size; + return local_size; +} + +int32_t SSDSparseTable::Save(const std::string& path, + const std::string& param) { + if (_real_local_shard_num == 0) { + _local_show_threshold = -1; + return 0; + } + int save_param = atoi(param.c_str()); // batch_model:0 xbox:1 + // if (save_param == 5) { + // return save_patch(path, save_param); + // } + + // LOG(INFO) << "table cache rate is: " << _config.sparse_table_cache_rate(); + LOG(INFO) << "table cache rate is: " << _config.sparse_table_cache_rate(); + LOG(INFO) << "enable_sparse_table_cache: " + << _config.enable_sparse_table_cache(); + LOG(INFO) << "LocalSize: " << LocalSize(); + if (_config.enable_sparse_table_cache()) { + LOG(INFO) << "Enable sparse table cache, top n:" << _cache_tk_size; + } + _cache_tk_size = LocalSize() * _config.sparse_table_cache_rate(); + TopkCalculator tk(_real_local_shard_num, _cache_tk_size); + size_t file_start_idx = _avg_local_shard_num * _shard_idx; + std::string table_path = TableDir(path); + _afs_client.remove(paddle::string::format_string( + "%s/part-%03d-*", table_path.c_str(), _shard_idx)); + int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; + + // std::atomic feasign_size; + std::atomic feasign_size_all{0}; + // feasign_size = 0; + + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + FsChannelConfig channel_config; + if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) { + channel_config.path = paddle::string::format_string( + "%s/part-%03d-%05d.gz", table_path.c_str(), _shard_idx, + file_start_idx + i); + } else { + channel_config.path = + paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(), + _shard_idx, file_start_idx + i); + } + channel_config.converter = _value_accesor->Converter(save_param).converter; + channel_config.deconverter = + _value_accesor->Converter(save_param).deconverter; + int err_no = 0; + int retry_num = 0; + bool is_write_failed = false; + int feasign_size = 0; + auto& shard = _local_shards[i]; + do { + err_no = 0; + feasign_size = 0; + is_write_failed = false; + auto write_channel = + _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); + for (auto it = shard.begin(); it != shard.end(); ++it) { + if (_config.enable_sparse_table_cache() && + (save_param == 1 || save_param == 2) && + _value_accesor->Save(it.value().data(), 4)) { + // tk.push(i, it.value().data()[2]); + tk.push(i, _value_accesor->GetField(it.value().data(), "show")); + } + if (_value_accesor->Save(it.value().data(), save_param)) { + std::string format_value = _value_accesor->ParseToString( + it.value().data(), it.value().size()); + if (0 != + write_channel->write_line(paddle::string::format_string( + "%lu %s", it.key(), format_value.c_str()))) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) << "SSDSparseTable save failed, retry it! path:" + << channel_config.path << ", retry_num=" << retry_num; + break; + } + ++feasign_size; + } + } + + if (err_no == -1 && !is_write_failed) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) << "SSDSparseTable save failed after write, retry it! " + << "path:" << channel_config.path + << " , retry_num=" << retry_num; + } + if (is_write_failed) { + _afs_client.remove(channel_config.path); + continue; + } + + // delta and cache and revert is all in mem, base in rocksdb + if (save_param != 1) { + auto* it = _db->get_iterator(i); + for (it->SeekToFirst(); it->Valid(); it->Next()) { + bool need_save = _value_accesor->Save( + paddle::string::str_to_float(it->value().data()), save_param); + _value_accesor->UpdateStatAfterSave( + paddle::string::str_to_float(it->value().data()), save_param); + if (need_save) { + std::string format_value = _value_accesor->ParseToString( + paddle::string::str_to_float(it->value().data()), + it->value().size() / sizeof(float)); + if (0 != + write_channel->write_line(paddle::string::format_string( + "%lu %s", *((uint64_t*)const_cast(it->key().data())), + format_value.c_str()))) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) << "SSDSparseTable save failed, retry it! path:" + << channel_config.path << ", retry_num=" << retry_num; + break; + } + if (save_param == 3) { + _db->put(i, it->key().data(), it->key().size(), + it->value().data(), it->value().size()); + } + ++feasign_size; + } + } + delete it; + } + + write_channel->close(); + if (err_no == -1) { + ++retry_num; + is_write_failed = true; + LOG(ERROR) << "SSDSparseTable save failed after write, retry it! " + << "path:" << channel_config.path + << " , retry_num=" << retry_num; + } + if (is_write_failed) { + _afs_client.remove(channel_config.path); + } + } while (is_write_failed); + feasign_size_all += feasign_size; + for (auto it = shard.begin(); it != shard.end(); ++it) { + _value_accesor->UpdateStatAfterSave(it.value().data(), save_param); + } + } + if (save_param == 3) { + UpdateTable(); + _cache_tk_size = LocalSize() * _config.sparse_table_cache_rate(); + LOG(INFO) << "SSDSparseTable update success."; + } + LOG(INFO) << "SSDSparseTable save success, path:" + << paddle::string::format_string("%s/%03d/part-%03d-", path.c_str(), + _config.table_id(), _shard_idx) + << " from " << file_start_idx << " to " + << file_start_idx + _real_local_shard_num - 1; + // return feasign_size_all; + _local_show_threshold = tk.top(); + LOG(INFO) << "local cache threshold: " << _local_show_threshold; + // int32 may overflow need to change return value + return 0; +} + +int64_t SSDSparseTable::CacheShuffle( + const std::string& path, const std::string& param, double cache_threshold, + std::function(int msg_type, int to_pserver_id, + std::string& msg)> + send_msg_func, + paddle::framework::Channel>& + shuffled_channel, + const std::vector& table_ptrs) { + LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold + << " param:" << param; + int save_param = atoi(param.c_str()); // batch_model:0 xbox:1 + if (!_config.enable_sparse_table_cache() || cache_threshold < 0) { + LOG(WARNING) + << "cache shuffle failed not enable table cache or cache threshold < 0 " + << _config.enable_sparse_table_cache() << " or " << cache_threshold; + // return -1; + } + int shuffle_node_num = _config.sparse_table_cache_file_num(); + LOG(INFO) << "Table>> shuffle node num is: " << shuffle_node_num; + size_t file_start_idx = _avg_local_shard_num * _shard_idx; + int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20; + + std::vector< + paddle::framework::ChannelWriter>> + writers(_real_local_shard_num); + std::vector>> datas( + _real_local_shard_num); + + int feasign_size = 0; + std::vector>> + tmp_channels; + for (size_t i = 0; i < _real_local_shard_num; ++i) { + tmp_channels.push_back( + paddle::framework::MakeChannel>()); + } + + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = 0; i < _real_local_shard_num; ++i) { + paddle::framework::ChannelWriter>& writer = + writers[i]; + // std::shared_ptr>> tmp_chan = + // paddle::framework::MakeChannel>(); + writer.Reset(tmp_channels[i].get()); + + auto& shard = _local_shards[i]; + for (auto it = shard.begin(); it != shard.end(); ++it) { + if (_value_accesor->SaveCache(it.value().data(), save_param, + cache_threshold)) { + std::string format_value = + _value_accesor->ParseToString(it.value().data(), it.value().size()); + std::pair pkv(it.key(), format_value.c_str()); + writer << pkv; + ++feasign_size; + } + } + + writer.Flush(); + writer.channel()->Close(); + } + LOG(INFO) << "SSDSparseTable cache KV save success to Channel feasigh size: " + << feasign_size + << " and start sparse cache data shuffle real local shard num: " + << _real_local_shard_num; + std::vector> local_datas; + for (size_t idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) { + paddle::framework::ChannelWriter>& writer = + writers[idx_shard]; + auto channel = writer.channel(); + std::vector>& data = datas[idx_shard]; + std::vector ars(shuffle_node_num); + while (channel->Read(data)) { + for (auto& t : data) { + auto pserver_id = + paddle::distributed::local_random_engine()() % shuffle_node_num; + if (pserver_id != _shard_idx) { + ars[pserver_id] << t; + } else { + local_datas.emplace_back(std::move(t)); + } + } + std::vector> total_status; + std::vector send_data_size(shuffle_node_num, 0); + std::vector send_index(shuffle_node_num); + for (int i = 0; i < shuffle_node_num; ++i) { + send_index[i] = i; + } + std::random_shuffle(send_index.begin(), send_index.end()); + for (auto index = 0u; index < shuffle_node_num; ++index) { + int i = send_index[index]; + if (i == _shard_idx) { + continue; + } + if (ars[i].Length() == 0) { + continue; + } + std::string msg(ars[i].Buffer(), ars[i].Length()); + auto ret = send_msg_func(101, i, msg); + total_status.push_back(std::move(ret)); + send_data_size[i] += ars[i].Length(); + } + for (auto& t : total_status) { + t.wait(); + } + ars.clear(); + ars = std::vector(shuffle_node_num); + data = std::vector>(); + } + } + shuffled_channel->Write(std::move(local_datas)); + LOG(INFO) << "cache shuffle finished"; + return 0; +} + +int32_t SSDSparseTable::SaveCache( + const std::string& path, const std::string& param, + paddle::framework::Channel>& + shuffled_channel) { + if (_shard_idx >= _config.sparse_table_cache_file_num()) { + return 0; + } + int save_param = atoi(param.c_str()); // batch_model:0 xbox:1 + size_t file_start_idx = _avg_local_shard_num * _shard_idx; + std::string table_path = paddle::string::format_string( + "%s/%03d_cache/", path.c_str(), _config.table_id()); + _afs_client.remove(paddle::string::format_string( + "%s/part-%03d", table_path.c_str(), _shard_idx)); + uint32_t feasign_size = 0; + FsChannelConfig channel_config; + // not compress cache model + channel_config.path = paddle::string::format_string( + "%s/part-%03d", table_path.c_str(), _shard_idx); + channel_config.converter = _value_accesor->Converter(save_param).converter; + channel_config.deconverter = + _value_accesor->Converter(save_param).deconverter; + auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40); + std::vector> data; + bool is_write_failed = false; + shuffled_channel->Close(); + while (shuffled_channel->Read(data)) { + for (auto& t : data) { + ++feasign_size; + if (0 != + write_channel->write_line(paddle::string::format_string( + "%lu %s", t.first, t.second.c_str()))) { + LOG(ERROR) << "Cache Table save failed, " + "path:" + << channel_config.path << ", retry it!"; + is_write_failed = true; + break; + } + } + data = std::vector>(); + } + if (is_write_failed) { + _afs_client.remove(channel_config.path); + } + write_channel->close(); + LOG(INFO) << "SSDSparseTable cache save success, feasign: " << feasign_size + << ", path: " << channel_config.path; + shuffled_channel->Open(); + return feasign_size; +} + +int32_t SSDSparseTable::Load(const std::string& path, + const std::string& param) { + return MemorySparseTable::Load(path, param); +} + +//加载path目录下数据[start_idx, end_idx) +int32_t SSDSparseTable::Load(size_t start_idx, size_t end_idx, + const std::vector& file_list, + const std::string& param) { + if (start_idx >= file_list.size()) { + return 0; + } + int load_param = atoi(param.c_str()); + size_t feature_value_size = + _value_accesor->GetAccessorInfo().size / sizeof(float); + size_t mf_value_size = + _value_accesor->GetAccessorInfo().mf_size / sizeof(float); + + end_idx = + end_idx < _sparse_table_shard_num ? end_idx : _sparse_table_shard_num; + int thread_num = (end_idx - start_idx) < 20 ? (end_idx - start_idx) : 20; + omp_set_num_threads(thread_num); +#pragma omp parallel for schedule(dynamic) + for (size_t i = start_idx; i < end_idx; ++i) { + FsChannelConfig channel_config; + channel_config.path = file_list[i]; + channel_config.converter = _value_accesor->Converter(load_param).converter; + channel_config.deconverter = + _value_accesor->Converter(load_param).deconverter; + + int retry_num = 0; + int err_no = 0; + bool is_read_failed = false; + std::vector> ssd_keys; + std::vector> ssd_values; + std::vector tmp_key; + ssd_keys.reserve(FLAGS_pserver_load_batch_size); + ssd_values.reserve(FLAGS_pserver_load_batch_size); + tmp_key.reserve(FLAGS_pserver_load_batch_size); + do { + ssd_keys.clear(); + ssd_values.clear(); + tmp_key.clear(); + err_no = 0; + is_read_failed = false; + std::string line_data; + auto read_channel = _afs_client.open_r(channel_config, 0, &err_no); + char* end = NULL; + int local_shard_id = i % _avg_local_shard_num; + auto& shard = _local_shards[local_shard_id]; + float data_buffer[FLAGS_pserver_load_batch_size * feature_value_size]; + float* data_buffer_ptr = data_buffer; + uint64_t mem_count = 0; + uint64_t ssd_count = 0; + uint64_t mem_mf_count = 0; + uint64_t ssd_mf_count = 0; + try { + while (read_channel->read_line(line_data) == 0 && + line_data.size() > 1) { + uint64_t key = std::strtoul(line_data.data(), &end, 10); + if (FLAGS_pserver_open_strict_check) { + if (key % _sparse_table_shard_num != i) { + LOG(WARNING) << "SSDSparseTable key:" << key + << " not match shard," + << " file_idx:" << i + << " shard num:" << _sparse_table_shard_num + << " file:" << channel_config.path; + continue; + } + } + int value_size = + _value_accesor->ParseFromString(++end, data_buffer_ptr); + // ssd or mem + if (_value_accesor->SaveSSD(data_buffer_ptr)) { + tmp_key.emplace_back(key); + ssd_keys.emplace_back( + std::make_pair((char*)&tmp_key.back(), sizeof(uint64_t))); + ssd_values.emplace_back(std::make_pair((char*)data_buffer_ptr, + value_size * sizeof(float))); + data_buffer_ptr += feature_value_size; + if (ssd_keys.size() == FLAGS_pserver_load_batch_size) { + _db->put_batch(local_shard_id, ssd_keys, ssd_values, + ssd_keys.size()); + ssd_keys.clear(); + ssd_values.clear(); + tmp_key.clear(); + data_buffer_ptr = data_buffer; + } + ssd_count++; + if (value_size > feature_value_size - mf_value_size) { + ssd_mf_count++; + } + } else { + auto& value = shard[key]; + value.resize(value_size); + _value_accesor->ParseFromString(end, value.data()); + mem_count++; + if (value_size > feature_value_size - mf_value_size) { + mem_mf_count++; + } + } + } + // last batch + if (ssd_keys.size() > 0) { + _db->put_batch(local_shard_id, ssd_keys, ssd_values, ssd_keys.size()); + } + read_channel->close(); + if (err_no == -1) { + ++retry_num; + is_read_failed = true; + LOG(ERROR) << "SSDSparseTable load failed after read, retry it! path:" + << channel_config.path << " , retry_num=" << retry_num; + continue; + } + + _db->flush(local_shard_id); + LOG(INFO) << "Table>> load done. ALL[" << mem_count + ssd_count + << "] MEM[" << mem_count << "] MEM_MF[" << mem_mf_count + << "] SSD[" << ssd_count << "] SSD_MF[" << ssd_mf_count + << "]."; + } catch (...) { + ++retry_num; + is_read_failed = true; + LOG(ERROR) << "SSDSparseTable load failed after read, retry it! path:" + << channel_config.path << " , retry_num=" << retry_num; + } + } while (is_read_failed); + } + LOG(INFO) << "load num:" << LocalSize(); + LOG(INFO) << "SSDSparseTable load success, path from " << file_list[start_idx] + << " to " << file_list[end_idx - 1]; + + _cache_tk_size = LocalSize() * _config.sparse_table_cache_rate(); + return 0; +} + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.h b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h new file mode 100644 index 0000000000000000000000000000000000000000..2a43a27c229d12298455afab01a7b112d7c2b1d9 --- /dev/null +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.h @@ -0,0 +1,94 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "gflags/gflags.h" +#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h" +#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h" + +namespace paddle { +namespace distributed { + +class SSDSparseTable : public MemorySparseTable { + public: + typedef SparseTableShard shard_type; + SSDSparseTable() {} + virtual ~SSDSparseTable() {} + + int32_t Initialize() override; + int32_t InitializeShard() override; + + // exchange data + int32_t UpdateTable(); + + int32_t Pull(TableContext& context) override { + CHECK(context.value_type == Sparse); + float* pull_values = context.pull_context.values; + const PullSparseValue& pull_value = context.pull_context.pull_value; + return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_); + } + + int32_t Push(TableContext& context) override { + const uint64_t* keys = context.push_context.keys; + const float* values = context.push_context.values; + size_t num = context.num; + return PushSparse(keys, values, num); + } + + virtual int32_t PullSparse(float* pull_values, const uint64_t* keys, + size_t num); + virtual int32_t PushSparse(const uint64_t* keys, const float* values, + size_t num); + + int32_t Flush() override { return 0; } + virtual int32_t Shrink(const std::string& param) override; + virtual void Clear() override { + for (size_t i = 0; i < _real_local_shard_num; ++i) { + _local_shards[i].clear(); + } + } + + virtual int32_t Save(const std::string& path, + const std::string& param) override; + virtual int32_t SaveCache( + const std::string& path, const std::string& param, + paddle::framework::Channel>& + shuffled_channel) override; + virtual double GetCacheThreshold() override { return _local_show_threshold; } + virtual int64_t CacheShuffle( + const std::string& path, const std::string& param, double cache_threshold, + std::function(int msg_type, int to_pserver_id, + std::string& msg)> + send_msg_func, + paddle::framework::Channel>& + shuffled_channel, + const std::vector& table_ptrs) override; + //加载path目录下数据 + virtual int32_t Load(const std::string& path, + const std::string& param) override; + //加载path目录下数据[start_idx, end_idx) + virtual int32_t Load(size_t start_idx, size_t end_idx, + const std::vector& file_list, + const std::string& param); + int64_t LocalSize(); + + private: + RocksDBHandler* _db; + int64_t _cache_tk_size; + double _local_show_threshold{0.0}; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/ps/table/table.cc b/paddle/fluid/distributed/ps/table/table.cc index 333008482f167cb1365c7e7db33becc83ea3e264..5eb38d9c400b01d8ac80419f7fee1c02eb957e83 100644 --- a/paddle/fluid/distributed/ps/table/table.cc +++ b/paddle/fluid/distributed/ps/table/table.cc @@ -25,6 +25,7 @@ #include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h" #include "paddle/fluid/distributed/ps/table/memory_sparse_table.h" #include "paddle/fluid/distributed/ps/table/sparse_accessor.h" +#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h" #include "paddle/fluid/distributed/ps/table/tensor_accessor.h" #include "paddle/fluid/distributed/ps/table/tensor_table.h" @@ -37,6 +38,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable); REGISTER_PSCORE_CLASS(Table, DenseTensorTable); REGISTER_PSCORE_CLASS(Table, GlobalStepTable); REGISTER_PSCORE_CLASS(Table, MemorySparseTable); +REGISTER_PSCORE_CLASS(Table, SSDSparseTable); REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable); REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor); diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index c515e03e3fa4854ba44586afe9446644713e27a5..48fda782d489fff33e18ebfc902dade58cabc752 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -24,6 +24,7 @@ #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h" +#include "paddle/fluid/framework/channel.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/device_context.h" @@ -107,6 +108,26 @@ class Table { // 指定保存路径 virtual int32_t Save(const std::string &path, const std::string &converter) = 0; + // for cache + virtual int32_t SaveCache( + const std::string &path, const std::string ¶m, + paddle::framework::Channel> + &shuffled_channel) { + return 0; + } + + virtual int64_t CacheShuffle( + const std::string &path, const std::string ¶m, double cache_threshold, + std::function(int msg_type, int to_pserver_id, + std::string &msg)> + send_msg_func, + paddle::framework::Channel> + &shuffled_channel, + const std::vector
&table_ptrs) { + return 0; + } + + virtual double GetCacheThreshold() { return 0.0; } virtual int32_t SetShard(size_t shard_idx, size_t shard_num) { _shard_idx = shard_idx; diff --git a/paddle/fluid/distributed/ps/table/tensor_accessor.h b/paddle/fluid/distributed/ps/table/tensor_accessor.h index 60951598482ad7dff172468c82fad324e870a4f6..fad31d5df7f47f707d31e36c25642cf7795362d3 100644 --- a/paddle/fluid/distributed/ps/table/tensor_accessor.h +++ b/paddle/fluid/distributed/ps/table/tensor_accessor.h @@ -38,6 +38,12 @@ class CommMergeAccessor : public ValueAccessor { // param作为参数用于标识save阶段,如downpour的xbox与batch_model virtual bool Save(float * /*value*/, int /*param*/); + bool SaveCache(float *value, int param, double global_cache_threshold) { + return false; + } + + bool SaveSSD(float *value) { return false; } + // keys不存在时,为values生成随机值 virtual int32_t Create(float **value, size_t num); // 从values中选取到select_values中 diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.cc b/paddle/fluid/distributed/ps/wrapper/fleet.cc index 7bc50a868104a0d8d459a96b768e4f426630afe7..955ba75e672d17663f54d744f5e2516409a822b6 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.cc +++ b/paddle/fluid/distributed/ps/wrapper/fleet.cc @@ -754,6 +754,46 @@ std::future FleetWrapper::SendClientToClientMsg( return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg); } +double FleetWrapper::GetCacheThreshold(int table_id) { + double cache_threshold = 0.0; + auto ret = worker_ptr_->Flush(); + ret.wait(); + ret = worker_ptr_->GetCacheThreshold(table_id, cache_threshold); + ret.wait(); + if (cache_threshold < 0) { + LOG(ERROR) << "get cache threshold failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } + return cache_threshold; +} + +void FleetWrapper::CacheShuffle(int table_id, const std::string& path, + const int mode, const double cache_threshold) { + auto ret = worker_ptr_->CacheShuffle(table_id, path, std::to_string(mode), + std::to_string(cache_threshold)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "cache shuffle failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } +} + +int32_t FleetWrapper::SaveCache(int table_id, const std::string& path, + const int mode) { + auto ret = worker_ptr_->SaveCache(table_id, path, std::to_string(mode)); + ret.wait(); + int32_t feasign_cnt = ret.get(); + if (feasign_cnt == -1) { + LOG(ERROR) << "table save cache failed"; + sleep(sleep_seconds_before_fail_exit_); + exit(-1); + } + return feasign_cnt; +} + std::default_random_engine& FleetWrapper::LocalRandomEngine() { struct engine_wrapper_t { std::default_random_engine engine; diff --git a/paddle/fluid/distributed/ps/wrapper/fleet.h b/paddle/fluid/distributed/ps/wrapper/fleet.h index e6ec09a12637d9d8d6da18ce45d0fd70dd45db7c..ce109b63cce9c73b1ccf4af39d632871345eca19 100644 --- a/paddle/fluid/distributed/ps/wrapper/fleet.h +++ b/paddle/fluid/distributed/ps/wrapper/fleet.h @@ -259,6 +259,11 @@ class FleetWrapper { // for init worker void InitGFlag(const std::string& gflags); + double GetCacheThreshold(int table_id); + void CacheShuffle(int table_id, const std::string& path, const int mode, + const double cache_threshold); + int32_t SaveCache(int table_id, const std::string& path, const int mode); + static std::shared_ptr pserver_ptr_; static std::shared_ptr worker_ptr_; diff --git a/paddle/fluid/distributed/the_one_ps.proto b/paddle/fluid/distributed/the_one_ps.proto index 32bf9eaa5aa06b77f731ab420f1b3bad951cdf88..1b20aca85422c0c6368f0245ee12da4f131326cc 100644 --- a/paddle/fluid/distributed/the_one_ps.proto +++ b/paddle/fluid/distributed/the_one_ps.proto @@ -116,6 +116,10 @@ message TableParameter { optional TableType type = 7; optional bool compress_in_save = 8 [ default = false ]; optional GraphParameter graph_parameter = 9; + // for cache model + optional bool enable_sparse_table_cache = 10 [ default = true ]; + optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; + optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; } message TableAccessorParameter { diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc index 8d8301689521b80136b415ba253e12e9a88a6902..d35419e87f3a5f87d466a2add23a70a5ea09d29a 100644 --- a/paddle/fluid/pybind/fleet_py.cc +++ b/paddle/fluid/pybind/fleet_py.cc @@ -78,7 +78,11 @@ void BindDistFleetWrapper(py::module* m) { .def("set_clients", &FleetWrapper::SetClients) .def("get_client_info", &FleetWrapper::GetClientsInfo) .def("create_client2client_connection", - &FleetWrapper::CreateClient2ClientConnection); + &FleetWrapper::CreateClient2ClientConnection) + .def("client_flush", &FleetWrapper::ClientFlush) + .def("get_cache_threshold", &FleetWrapper::GetCacheThreshold) + .def("cache_shuffle", &FleetWrapper::CacheShuffle) + .def("save_cache", &FleetWrapper::SaveCache); } void BindPSHost(py::module* m) { diff --git a/paddle/utils/string/string_helper.h b/paddle/utils/string/string_helper.h index a02b313ef0eba61682188f65d3d6a03d432dc7fb..e6cb2e90b8fa1ab8ad14fd5601d04ad992c1bb94 100644 --- a/paddle/utils/string/string_helper.h +++ b/paddle/utils/string/string_helper.h @@ -100,6 +100,14 @@ inline int str_to_float(const char* str, float* v) { return index; } +inline float* str_to_float(std::string& str) { + return (float*)const_cast(str.c_str()); +} + +inline float* str_to_float(const char* str) { + return (float*)const_cast(str); +} + // checks whether the test string is a suffix of the input string. bool ends_with(std::string const& input, std::string const& test); diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py index 3186df7db581a54d0417b40892ea5f3e6c91721c..ef0fff82833612c08173f14154162c3fe4f77cce 100644 --- a/python/paddle/distributed/fleet/__init__.py +++ b/python/paddle/distributed/fleet/__init__.py @@ -77,6 +77,7 @@ stop_worker = fleet.stop_worker distributed_optimizer = fleet.distributed_optimizer save_inference_model = fleet.save_inference_model save_persistables = fleet.save_persistables +save_cache_model = fleet.save_cache_model load_model = fleet.load_model minimize = fleet.minimize distributed_model = fleet.distributed_model diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py index 4e975e74bdb14e741fff5787b51df9fbd7e61f14..a1c967ab0639c512b20b87d43e7c53791dc30b16 100755 --- a/python/paddle/distributed/fleet/base/fleet_base.py +++ b/python/paddle/distributed/fleet/base/fleet_base.py @@ -869,6 +869,11 @@ class Fleet(object): self._runtime_handle._save_persistables(executor, dirname, main_program, mode) + @is_non_distributed_check + @inited_runtime_handler + def save_cache_model(self, dirname, **configs): + return self._runtime_handle._save_cache_model(dirname, **configs) + def shrink(self, threshold=None): self._runtime_handle._shrink(threshold) diff --git a/python/paddle/distributed/ps/the_one_ps.py b/python/paddle/distributed/ps/the_one_ps.py index 5be739785ff44ad8f2e72f69c39af46326cd5e06..c6df7559a22e8114d085af4de14a4db98d3de3e5 100755 --- a/python/paddle/distributed/ps/the_one_ps.py +++ b/python/paddle/distributed/ps/the_one_ps.py @@ -1315,6 +1315,30 @@ class TheOnePSRuntime(RuntimeBase): def _save_persistables(self, *args, **kwargs): self._ps_inference_save_persistables(*args, **kwargs) + def _save_cache_model(self, dirname, **kwargs): + mode = kwargs.get("mode", 0) + table_id = kwargs.get("table_id", 0) + self._worker.client_flush() + fleet.util.barrier() + cache_threshold = 0.0 + + if self.role_maker._is_first_worker(): + cache_threshold = self._worker.get_cache_threshold(table_id) + #check cache threshold right or not + fleet.util.barrier() + + if self.role_maker._is_first_worker(): + self._worker.cache_shuffle(table_id, dirname, mode, cache_threshold) + + fleet.util.barrier() + + feasign_num = -1 + if self.role_maker._is_first_worker(): + feasign_num = self._worker.save_cache(table_id, dirname, mode) + + fleet.util.barrier() + return feasign_num + def _load_sparse_params(self, dirname, context, main_program, mode): distributed_varnames = get_sparse_tablenames(self.origin_main_programs, True) diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 2bd397b0ef3f531a30ac45288689d0897a310b23..be5118f0acc18ff055243cbbfd9aadeea7073099 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -339,5 +339,9 @@ class TestDistCTR2x2(FleetDistRunnerBase): if dirname: fleet.save_persistables(exe, dirname=dirname) + cache_dirname = os.getenv("SAVE_CACHE_DIRNAME", None) + if cache_dirname: + fleet.save_cache_model(cache_dirname) + if __name__ == "__main__": runtime_main(TestDistCTR2x2) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 59d196fdf55e57b3175b3deb6036f4b88b565d34..09d64a318d6d8d3f2fae108ce652129b1a681ec7 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -39,6 +39,8 @@ class TestDistMnistAsyncInMemoryDataset2x2(TestFleetBase): "http_proxy": "", "CPU_NUM": "2", "LOG_DIRNAME": "/tmp", + "SAVE_CACHE_DIRNAME": + "/tmp/TestDistMnistAsyncInMemoryDataset2x2/cache_model", "LOG_PREFIX": self.__class__.__name__, }