Unverified commit cca57c4a, authored by zhaocaibei123, committed by GitHub

SSD sparse table (#41812)

* [cherry-pick 2.3] fix compile bug of Windows CUDA 11.5 (#41464)

cherry-pick

fix compile bug of Windows CUDA 11.5 #41433

* fix bug of missing boost when compile cache.cc (#41449)

[cherry-pick #41430] fix bug of random compile failure due to incorrect compile order of dependencies

* Fix eager try catch (#41438) (#41477)

[Cherry-Pick]Fix eager try catch (#41438)

* Cherry-pick-PR41407, fix device_id bug for final_state op in multiprocess testcase (#41407) (#41475)

Cherry-pick PR #41407

* [BugFix] Add error hint for one_hot gpu version (#41335) (#41495)

* add one_hot gpu hint

* move allow_out_of_range judgement

* delete useless unittest

* fix bugs of reshape double grad infermeta (#41459) (#41493)

* [cherrypick-2.3] modify infer gpu memory strategy (#41427), remove cudnn_deterministic=True (#41341)  (#41491)
Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>

* [Cherry-pick][ROCm] fix dcu error in device event base, test=develop (#41523)

Cherry-pick of #41521

* [Cherry-Pick]Cherry pick PR41200, PR41474, PR41382 (#41509)

* Use `self` as a parameter of _hash_with_id function to avoid error caused by hash_id reuse (#41200)

* Add fill_constant_batch_size YAML and UT (#41474)

* Switch some dy2st UT to eager mode (#41382)

* Switch some dy2st UT to eager mode

* Fix test_lstm and remove test_transformer

* Run test_resnet_v2 in old dy mode

* Unittest recover (#41431)

* update name

* update name

* fix test

* fix fleet bind

* update name

* update name

* fix test

* fix gpups wrapper

* remove Push/Pull/Load/Save with context in client and wrapper base class

* fix

* fix

* remove some interface

* fix

* remove

* code style

* recover

* fix

* remove code unused

* remove some unused table & accessor & CommonDenseTable => MemoryDenseTable

* fix

* fix

* fix

* recover

* remove unused code

* recover unittest

* fix

* remove

* fix

* remove code unuseful

* remove

* fix

* recover

* remove
Co-authored-by: esythan <esythan@126.com>

* add ssd sparse table

* fix

* add cache shuffle

* fix

* fix

* fix

* fix

* fix

* fix

* add unit test

* fix
Co-authored-by: Zhou Wei <1183042833@qq.com>
Co-authored-by: Sing_chan <51314274+betterpig@users.noreply.github.com>
Co-authored-by: 0x45f <23097963+0x45f@users.noreply.github.com>
Co-authored-by: pangyoki <pangyoki@126.com>
Co-authored-by: Siming Dai <908660116@qq.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: Zhang Jun <ewalker@live.cn>
Co-authored-by: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com>
Co-authored-by: Qi Li <qili93@qq.com>
Co-authored-by: esythan <esythan@126.com>
Parent 4fd190d5
......@@ -357,10 +357,8 @@ if (WITH_PSCORE)
include(external/libmct) # download, build, install libmct
list(APPEND third_party_deps extern_libmct)
if (WITH_HETERPS)
include(external/rocksdb) # download, build, install libmct
list(APPEND third_party_deps extern_rocksdb)
endif()
include(external/rocksdb) # download, build, install rocksdb
list(APPEND third_party_deps extern_rocksdb)
endif()
if(WITH_XBYAK)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <unordered_map>
namespace paddle {
namespace distributed {
class TopkCalculator {
public:
TopkCalculator(int shard_num, size_t k)
: _shard_num(shard_num), _total_max_size(k) {
_shard_max_size = _total_max_size / shard_num;
_shard_max_size = _shard_max_size > 1 ? _shard_max_size : 1;
for (int i = 0; i < shard_num; ++i) {
_mpq.emplace(i, std::priority_queue<double, std::vector<double>,
std::greater<double>>());
}
}
~TopkCalculator() {}
bool push(int shard_id, double value) {
if (_mpq.find(shard_id) == _mpq.end()) {
return false;
}
auto &pq = _mpq[shard_id];
if (pq.size() < _shard_max_size) {
pq.push(value);
} else {
if (pq.top() < value) {
pq.pop();
pq.push(value);
}
}
return true;
}
// TODO: run one more heap sort to merge the results of all shards
int top() {
double total = 0;
for (const auto &item : _mpq) {
auto &pq = item.second;
if (!pq.empty()) {
total += pq.top();
}
}
return total / _shard_num;
}
private:
std::unordered_map<int, std::priority_queue<double, std::vector<double>,
std::greater<double>>>
_mpq;
int _shard_num;
size_t _total_max_size;
size_t _shard_max_size;
};
} // namespace distributed
} // namespace paddle
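The helper above keeps one fixed-size min-heap of "show" values per shard and, in top(), averages the smallest retained value of every shard; SSDSparseTable::Save later in this commit uses it to derive _local_show_threshold. A minimal, self-contained usage sketch follows; the shard count and values are hypothetical, not taken from this commit:

// Hypothetical usage of TopkCalculator: keep roughly the top-4 show values
// across 2 shards and read back the averaged per-shard cut-off.
#include <iostream>
#include "paddle/fluid/distributed/common/topk_calculator.h"

int main() {
  paddle::distributed::TopkCalculator tk(/*shard_num=*/2, /*k=*/4);  // 2 values kept per shard
  tk.push(0, 10.0);
  tk.push(0, 50.0);
  tk.push(0, 30.0);  // shard 0's heap is full, 10.0 is evicted; it keeps {30, 50}
  tk.push(1, 5.0);
  tk.push(1, 7.0);   // shard 1 keeps {5, 7}
  // top() averages each shard's smallest kept value: (30 + 5) / 2 = 17.5,
  // truncated to 17 because top() is declared to return int.
  std::cout << "local cache threshold: " << tk.top() << std::endl;
  return 0;
}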
set(BRPC_SRCS ps_client.cc server.cc)
set_source_files_properties(${BRPC_SRCS})
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
if(WITH_HETERPS)
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context rocksdb)
else()
set(BRPC_DEPS brpc ssl crypto protobuf gflags glog zlib leveldb snappy gflags glog device_context)
endif()
brpc_library(sendrecv_rpc SRCS
${BRPC_SRCS}
......
......@@ -429,6 +429,82 @@ std::future<int32_t> BrpcPsClient::Save(uint32_t table_id,
return SendSaveCmd(table_id, PS_SAVE_ONE_TABLE, {epoch, mode});
}
std::future<int32_t> BrpcPsClient::CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(1) << "BrpcPsClient send cmd for cache shuffle";
return SendSaveCmd(table_id, PS_CACHE_SHUFFLE, {path, mode, cache_threshold});
}
std::future<int32_t> BrpcPsClient::CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(1) << "BrpcPsClient send cmd for cache shuffle multi table one path";
std::vector<std::string> param;
param.push_back(path);
param.push_back(mode);
param.push_back(cache_threshold);
for (size_t i = 0; i < tables.size(); i++) {
param.push_back(std::to_string(tables[i]));
}
return SendSaveCmd(0, PS_CACHE_SHUFFLE, param);
}
std::future<int32_t> BrpcPsClient::SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) {
return SendSaveCmd(table_id, PS_SAVE_ONE_CACHE_TABLE, {path, mode});
}
std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
double &cache_threshold) {
int cmd_id = PS_GET_CACHE_THRESHOLD;
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num,
[request_call_num, cmd_id, &cache_threshold](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
std::vector<double> cache_thresholds(request_call_num, 0);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
ret = -1;
break;
}
std::string cur_res = closure->get_response(i, cmd_id);
cache_thresholds[i] = std::stod(cur_res);
}
double sum_threshold = 0.0;
int count = 0;
for (auto t : cache_thresholds) {
if (t >= 0) {
sum_threshold += t;
++count;
}
}
if (count == 0) {
cache_threshold = 0;
} else {
cache_threshold = sum_threshold / count;
}
VLOG(1) << "client get cache threshold: " << cache_threshold;
closure->set_promise_value(ret);
});
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
std::future<int> fut = promise->get_future();
for (size_t i = 0; i < request_call_num; ++i) {
closure->request(i)->set_cmd_id(cmd_id);
closure->request(i)->set_table_id(table_id);
closure->request(i)->set_client_id(_client_id);
PsService_Stub rpc_stub(GetCmdChannel(i));
closure->cntl(i)->set_timeout_ms(10800000);
rpc_stub.service(closure->cntl(i), closure->request(i),
closure->response(i), closure);
}
return fut;
}
std::future<int32_t> BrpcPsClient::Clear() {
return SendCmd(-1, PS_CLEAR_ALL_TABLE, {});
}
......
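One plausible way to drive the three client calls added above when dumping a cache model (a sketch only, not necessarily the fleet runtime's actual call path; `client`, the table id, and the path/mode strings are placeholders):

// Assumes `client` is an already-initialized BrpcPsClient and table 0 has
// sparse table cache enabled.
uint32_t table_id = 0;
double cache_threshold = 0.0;
// 1. Query every pserver's local show threshold; the closure above averages
//    the returned values into cache_threshold.
client->GetCacheThreshold(table_id, cache_threshold).get();
// 2. Shuffle cache candidates between pservers, passing the threshold as a string.
client->CacheShuffle(table_id, "/path/to/model", "1",
                     std::to_string(cache_threshold)).get();
// 3. Write the shuffled cache table out.
client->SaveCache(table_id, "/path/to/model", "1").get();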
......@@ -219,6 +219,20 @@ class BrpcPsClient : public PSClient {
virtual int32_t RecvAndSaveTable(const uint64_t table_id,
const std::string &path);
std::future<int32_t> CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) override;
std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold);
std::future<int32_t> SaveCache(uint32_t table_id, const std::string &path,
const std::string &mode) override;
std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) override;
void PrintQueueSize();
void PrintQueueSizeThread();
......
......@@ -28,6 +28,13 @@ class RpcController;
} // namespace protobuf
} // namespace google
DEFINE_int32(pserver_timeout_ms_s2s, 10000,
"pserver request server timeout_ms");
DEFINE_int32(pserver_connect_timeout_ms_s2s, 10000,
"pserver connect server timeout_ms");
DEFINE_string(pserver_connection_type_s2s, "pooled",
"pserver connection_type[pooled:single]");
namespace paddle {
namespace distributed {
......@@ -93,6 +100,84 @@ uint64_t BrpcPsServer::Start(const std::string &ip, uint32_t port) {
return host.rank;
}
int32_t BrpcPsServer::StartS2S() {
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.timeout_ms = FLAGS_pserver_timeout_ms_s2s;
options.connection_type = FLAGS_pserver_connection_type_s2s;
options.connect_timeout_ms = FLAGS_pserver_connect_timeout_ms_s2s;
options.max_retry = 3;
std::vector<PSHost> pserver_list = _environment->GetPsServers();
_pserver_channels.resize(pserver_list.size());
VLOG(2) << "pserver start s2s server_list size: " << _pserver_channels.size();
std::ostringstream os;
std::string server_ip_port;
for (size_t i = 0; i < pserver_list.size(); ++i) {
server_ip_port.assign(pserver_list[i].ip.c_str());
server_ip_port.append(":");
server_ip_port.append(std::to_string(pserver_list[i].port));
_pserver_channels[i].reset(new brpc::Channel());
if (_pserver_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "pserver connect to pserver:" << server_ip_port
<< " Failed!";
}
os << server_ip_port << ",";
}
LOG(INFO) << "pserver connect success: " << os.str();
return 0;
}
std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) {
auto promise = std::make_shared<std::promise<int32_t>>();
std::future<int> fut = promise->get_future();
if (to_pserver_id >= _pserver_channels.size()) {
LOG(FATAL) << "to_pserver_id is out of range pservers, which size is "
<< _pserver_channels.size();
promise->set_value(-1);
return fut;
}
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourPServerBrpcClosure *)done;
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
closure->add_promise(promise);
closure->request(0)->set_cmd_id(101);
closure->request(0)->set_client_id(_rank);
closure->request(0)->set_table_id(0);
closure->request(0)->set_data(msg);
PsService_Stub rpc_stub(_pserver_channels[to_pserver_id].get());
rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
closure);
return fut;
}
int32_t BrpcPsServer::ReceiveFromPServer(int msg_type, int pserver_id,
const std::string &msg) {
if (msg.length() == 0) {
LOG(WARNING) << "SERVER>>RESPONSE>>msg = 0 Finish S2S Response";
return 0;
}
paddle::framework::BinaryArchive ar;
ar.SetReadBuffer(const_cast<char *>(msg.c_str()), msg.length(), nullptr);
if (ar.Cursor() == ar.Finish()) {
LOG(WARNING) << "SERVER>>RESPONSE ar = 0>> Finish S2S Response";
return 0;
}
std::vector<std::pair<uint64_t, std::string>> data;
while (ar.Cursor() < ar.Finish()) {
data.push_back(ar.Get<std::pair<uint64_t, std::string>>());
}
CHECK(ar.Cursor() == ar.Finish());
this->_shuffled_ins->Write(std::move(data));
return 0;
}
int32_t BrpcPsServer::Port() { return _server.listen_address().port; }
int32_t BrpcPsService::Initialize() {
......@@ -117,6 +202,14 @@ int32_t BrpcPsService::Initialize() {
_service_handler_map[PS_START_PROFILER] = &BrpcPsService::StartProfiler;
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
// for save cache
_service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
&BrpcPsService::SaveCacheTable;
_service_handler_map[PS_GET_CACHE_THRESHOLD] =
&BrpcPsService::GetCacheThreshold;
_service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense");
......@@ -168,19 +261,29 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
response->set_err_msg("");
auto *table = _server->GetTable(request->table_id());
brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
set_response_code(*response, -1, err_msg.c_str());
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
if (request->cmd_id() < 100) {
auto itr = _service_handler_map.find(request->cmd_id());
if (itr == _service_handler_map.end()) {
std::string err_msg(
"undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
err_msg.append(std::to_string(request->cmd_id()));
set_response_code(*response, -1, err_msg.c_str());
return;
}
serviceHandlerFunc handler_func = itr->second;
int service_ret = (this->*handler_func)(table, *request, *response, cntl);
if (service_ret != 0) {
response->set_err_code(service_ret);
response->set_err_msg("server internal error");
}
} else {
int service_ret = _server->HandlePServer2PServerMsg(
request->cmd_id(), request->client_id(), request->data());
if (service_ret != 0) {
response->set_err_code(-1);
response->set_err_msg("handle_pserver2pserver_msg failed");
}
}
}
......@@ -561,6 +664,90 @@ int32_t BrpcPsService::SaveAllTable(Table *table,
return 0;
}
int32_t BrpcPsService::SaveCacheTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
response, -1,
"PsRequestMessage.datas is requeired at least 3, path&mode");
return -1;
}
table->Flush();
int32_t feasign_size = 0;
// if (_server->_shuffled_ins->size() <= 0) {
// LOG(WARNING) << "shuffled ins size <= 0";
//}
feasign_size = table->SaveCache(request.params(0), request.params(1),
_server->_shuffled_ins);
if (feasign_size < 0) {
set_response_code(response, -1, "table save failed");
return -1;
}
return feasign_size;
}
int32_t BrpcPsService::CacheShuffle(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
// start cache shuffle
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 3) {
set_response_code(response, -1,
"PsRequestMessage.datas is requeired at least 3, "
"path&mode&cache_threshold");
return -1;
}
table->Flush();
double cache_threshold = std::stod(request.params(2));
LOG(INFO) << "cache threshold for cache shuffle: " << cache_threshold;
// auto shuffled_ins = paddle::ps::make_channel<std::pair<uint64_t,
// std::string>>();
// shuffled_ins->set_block_size(80000);
_server->StartS2S();
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
const std::string &msg)>
send_msg_func = [this](int msg_type, int to_pserver_id,
const std::string &msg) -> std::future<int32_t> {
return this->_server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
};
std::vector<Table *> table_ptrs;
for (size_t i = 3; i < request.params_size(); ++i) {
int table_id = std::stoi(request.params(i));
Table *table_ptr = _server->GetTable(table_id);
table_ptrs.push_back(table_ptr);
}
if (table_ptrs.empty()) {
table_ptrs.push_back(table);
}
table->CacheShuffle(request.params(0), request.params(1), cache_threshold,
send_msg_func, _server->_shuffled_ins, table_ptrs);
return 0;
}
int32_t BrpcPsService::GetCacheThreshold(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->Flush();
double cache_threshold = 0.0;
cache_threshold = table->GetCacheThreshold();
if (cache_threshold < 0) {
LOG(WARNING) << "wrong threshold: " << cache_threshold;
}
std::stringstream ss;
ss << std::setprecision(15) << cache_threshold;
std::string cache_threshold_str = ss.str();
response.set_data(cache_threshold_str);
return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
......
......@@ -53,6 +53,12 @@ class BrpcPsServer : public PSServer {
}
int32_t Port();
virtual int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id,
const std::string &msg) override;
private:
virtual int32_t Initialize();
mutable std::mutex mutex_;
......@@ -123,6 +129,16 @@ class BrpcPsService : public PsBaseService {
int32_t PushGlobalStep(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t CacheShuffle(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
......
......@@ -198,6 +198,7 @@ class PSClient {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int HandleClient2ClientMsg(int msg_type, int from_client_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
......@@ -239,6 +240,46 @@ class PSClient {
const float **update_values,
size_t num) = 0;
// for save cache
virtual std::future<int32_t> CacheShuffle(
uint32_t table_id, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CacheShuffleMultiTable(
std::vector<int> tables, const std::string &path, const std::string &mode,
const std::string &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> SaveCache(uint32_t table_id,
const std::string &path,
const std::string &mode) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
protected:
virtual int32_t Initialize() = 0;
size_t _client_id;
......
......@@ -65,6 +65,8 @@ enum PsCmdID {
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
// pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
}
message PsRequestMessage {
......
......@@ -67,6 +67,8 @@ int32_t PSServer::Configure(
_config = config.server_param();
_rank = server_rank;
_environment = &env;
_shuffled_ins =
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>();
size_t shard_num = env.GetPsServers().size();
const auto &downpour_param = _config.downpour_server_param();
......
......@@ -89,6 +89,45 @@ class PSServer {
return &_table_map;
}
// for cache
virtual int32_t StartS2S() { return 0; }
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) {
LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
typedef std::function<int32_t(int, int, const std::string &)> MsgHandlerFunc;
virtual int RegistePServer2PServerMsgHandler(int msg_type,
MsgHandlerFunc handler) {
_msg_handler_map[msg_type] = handler;
return 0;
}
virtual int HandlePServer2PServerMsg(int msg_type, int from_pserver_id,
const std::string &msg) {
auto itr = _msg_handler_map.find(msg_type);
if (itr == _msg_handler_map.end()) {
if (msg_type == 101) {
return ReceiveFromPServer(msg_type, from_pserver_id, msg);
} else {
LOG(WARNING) << "unknown pserver2pserver_msg type:" << msg_type;
return -1;
}
}
return itr->second(msg_type, from_pserver_id, msg);
}
virtual int32_t ReceiveFromPServer(int msg_type, int pserver_id,
const std::string &msg) {
LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer";
return -1;
}
paddle::framework::Channel<std::pair<uint64_t, std::string>> _shuffled_ins;
protected:
virtual int32_t Initialize() = 0;
......@@ -97,6 +136,7 @@ class PSServer {
ServerParameter _config;
PSEnvironment *_environment;
std::unordered_map<uint32_t, std::shared_ptr<Table>> _table_map;
std::unordered_map<int32_t, MsgHandlerFunc> _msg_handler_map;
protected:
std::shared_ptr<framework::Scope> scope_;
......
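In the base class above, an incoming pserver-to-pserver message whose msg_type has no registered handler falls back to ReceiveFromPServer when msg_type is 101 and is rejected otherwise; additional handlers are installed through RegistePServer2PServerMsgHandler. A hedged sketch of registering one (msg_type 102 and the handler body are hypothetical):

// Sketch: install a custom pserver-to-pserver handler on a PSServer* `server`.
server->RegistePServer2PServerMsgHandler(
    102, [](int msg_type, int from_pserver_id,
            const std::string &msg) -> int32_t {
      LOG(INFO) << "s2s msg type " << msg_type << " from pserver "
                << from_pserver_id << ", bytes: " << msg.size();
      return 0;
    });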
......@@ -18,17 +18,12 @@ include_directories(${PADDLE_LIB_THIRD_PARTY_PATH}libmct/src/extern_libmct/libmc
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
set(EXTERN_DEP "")
if(WITH_HETERPS)
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
set(EXTERN_DEP rocksdb)
else()
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
endif()
set(TABLE_SRC memory_dense_table.cc barrier_table.cc common_graph_table.cc)
#set(EXTERN_DEP rocksdb)
cc_library(common_table SRCS ${TABLE_SRC} DEPS ${TABLE_DEPS}
${RPC_DEPS} graph_edge graph_node device_context string_helper
simple_threadpool xxhash generator ${EXTERN_DEP})
simple_threadpool xxhash generator)
set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......@@ -41,13 +36,13 @@ set_source_files_properties(ctr_double_accessor.cc PROPERTIES COMPILE_FLAGS ${DI
set_source_files_properties(ctr_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(sparse_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ssd_sparse_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(sparse_sgd_rule SRCS sparse_sgd_rule.cc DEPS ${TABLE_DEPS} ps_framework_proto)
cc_library(ctr_accessor SRCS ctr_accessor.cc ctr_double_accessor.cc sparse_accessor.cc DEPS ${TABLE_DEPS} ps_framework_proto sparse_sgd_rule)
cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table)
set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
cc_library(sparse_table SRCS memory_sparse_table.cc ssd_sparse_table.cc memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} fs afs_wrapper ctr_accessor common_table rocksdb)
cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
cc_library(table SRCS table.cc DEPS sparse_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)
target_link_libraries(table -fopenmp)
......@@ -117,6 +117,11 @@ class ValueAccessor {
virtual bool Save(float* value, int param) = 0;
// update delta_score and unseen_days after save
virtual void UpdateStatAfterSave(float* value, int param) {}
// decide whether this value should be saved to SSD
virtual bool SaveSSD(float* value) = 0;
//
virtual bool SaveCache(float* value, int param,
double global_cache_threshold) = 0;
// when a key does not exist, generate random values for it
virtual int32_t Create(float** value, size_t num) = 0;
......
......@@ -38,13 +38,13 @@
#include <vector>
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/common_table.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/graph/class_macro.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/core/utils/rw_lock.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h"
#endif
namespace paddle {
......
......@@ -34,6 +34,8 @@ int CtrCommonAccessor::Initialize() {
common_feature_value.embedx_dim = _config.embedx_dim();
common_feature_value.embedx_sgd_dim = _embedx_sgd_rule->Dim();
_show_click_decay_rate = _config.ctr_accessor_param().show_click_decay_rate();
_ssd_unseenday_threshold =
_config.ctr_accessor_param().ssd_unseenday_threshold();
if (_config.ctr_accessor_param().show_scale()) {
_show_scale = true;
......@@ -77,6 +79,25 @@ bool CtrCommonAccessor::Shrink(float* value) {
return false;
}
bool CtrCommonAccessor::SaveCache(float* value, int param,
double global_cache_threshold) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (ShowClickScore(common_feature_value.Show(value),
common_feature_value.Click(value)) >= base_threshold &&
common_feature_value.UnseenDays(value) <= delta_keep_days) {
return common_feature_value.Show(value) > global_cache_threshold;
}
return false;
}
bool CtrCommonAccessor::SaveSSD(float* value) {
if (common_feature_value.UnseenDays(value) > _ssd_unseenday_threshold) {
return true;
}
return false;
}
bool CtrCommonAccessor::Save(float* value, int param) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......
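Restated, SaveCache above admits a feature into the cache only if its show/click score reaches base_threshold, it was seen within delta_keep_days, and its show count exceeds the cluster-wide global_cache_threshold, while SaveSSD evicts a feature to RocksDB once its unseen_days passes _ssd_unseenday_threshold. A standalone restatement of the cache rule (a hypothetical helper, not part of the accessor):

// Hypothetical helper mirroring CtrCommonAccessor::SaveCache's admission rule.
bool IsCacheCandidate(float show_click_score, float unseen_days, float show,
                      float base_threshold, float delta_keep_days,
                      double global_cache_threshold) {
  return show_click_score >= base_threshold &&
         unseen_days <= delta_keep_days &&
         show > global_cache_threshold;
}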
......@@ -148,6 +148,9 @@ class CtrCommonAccessor : public ValueAccessor {
// param = 1, save delta feature
// param = 2, save xbox base feature
bool Save(float* value, int param) override;
bool SaveCache(float* value, int param,
double global_cache_threshold) override;
bool SaveSSD(float* value) override;
// update delta_score and unseen_days after save
void UpdateStatAfterSave(float* value, int param) override;
// when a key does not exist, generate random values for it
......
......@@ -74,25 +74,26 @@ bool CtrDoubleAccessor::Shrink(float* value) {
}
return false;
}
bool CtrDoubleAccessor::SaveSSD(float* value) {
if (CtrDoubleFeatureValue::UnseenDays(value) > _ssd_unseenday_threshold) {
return true;
}
return false;
}
// bool CtrDoubleAccessor::save_cache(
// float* value, int param, double global_cache_threshold) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
// if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
// CtrDoubleFeatureValue::Click(value)) >= base_threshold
// && CtrDoubleFeatureValue::UnseenDays(value) <=
// delta_keep_days) {
// return CtrDoubleFeatureValue::Show(value) >
// global_cache_threshold;
// }
// return false;
// }
bool CtrDoubleAccessor::SaveCache(float* value, int param,
double global_cache_threshold) {
auto base_threshold = _config.ctr_accessor_param().base_threshold();
auto delta_keep_days = _config.ctr_accessor_param().delta_keep_days();
if (ShowClickScore(CtrDoubleFeatureValue::Show(value),
CtrDoubleFeatureValue::Click(value)) >= base_threshold &&
CtrDoubleFeatureValue::UnseenDays(value) <= delta_keep_days) {
return CtrDoubleFeatureValue::Show(value) > global_cache_threshold;
}
return false;
}
bool CtrDoubleAccessor::Save(float* value, int param) {
// auto base_threshold = _config.ctr_accessor_param().base_threshold();
// auto delta_threshold = _config.ctr_accessor_param().delta_threshold();
......
......@@ -167,6 +167,8 @@ class CtrDoubleAccessor : public ValueAccessor {
// param = 1, save delta feature
// param = 3, save all feature with time decay
virtual bool Save(float* value, int param) override;
bool SaveCache(float* value, int param,
double global_cache_threshold) override;
// update delta_score and unseen_days after save
virtual void UpdateStatAfterSave(float* value, int param) override;
// decide whether this value should be saved to SSD
......
......@@ -11,9 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_HETERPS
#include <glog/logging.h>
#include <rocksdb/db.h>
#include <rocksdb/filter_policy.h>
......@@ -154,6 +153,5 @@ class RocksDBHandler {
std::vector<rocksdb::ColumnFamilyHandle*> _handles;
rocksdb::DB* _db;
};
}
}
#endif
} // distributed
} // paddle
......@@ -23,14 +23,17 @@
#include "glog/logging.h"
#include "paddle/fluid/platform/enforce.h"
DEFINE_bool(pserver_print_missed_key_num_every_push, false,
"pserver_print_missed_key_num_every_push");
DEFINE_bool(pserver_create_value_when_push, true,
"pserver create value when push");
DEFINE_bool(pserver_enable_create_feasign_randomly, false,
"pserver_enable_create_feasign_randomly");
DEFINE_int32(pserver_table_save_max_retry, 3, "pserver_table_save_max_retry");
namespace paddle {
namespace distributed {
// TODO(zhaocaibei123): configure
bool FLAGS_pserver_create_value_when_push = true;
int FLAGS_pserver_table_save_max_retry = 3;
bool FLAGS_pserver_enable_create_feasign_randomly = false;
int32_t MemorySparseTable::Initialize() {
_shards_task_pool.resize(_task_pool_size);
for (int i = 0; i < _shards_task_pool.size(); ++i) {
......@@ -142,7 +145,7 @@ int32_t MemorySparseTable::Load(const std::string& path,
LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
<< channel_config.path << " , retry_num=" << retry_num;
}
if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
exit(-1);
}
......@@ -213,7 +216,7 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
<< file_list[file_start_idx + i]
<< " , retry_num=" << retry_num;
}
if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
exit(-1);
}
......@@ -293,7 +296,7 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
if (retry_num > paddle::distributed::FLAGS_pserver_table_save_max_retry) {
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable save prefix failed reach max limit!";
exit(-1);
}
......
......@@ -62,9 +62,11 @@ class MemorySparseTable : public Table {
int32_t InitializeShard() override { return 0; }
int32_t InitializeValue();
int32_t Load(const std::string& path, const std::string& param) override;
virtual int32_t Load(const std::string& path,
const std::string& param) override;
int32_t Save(const std::string& path, const std::string& param) override;
virtual int32_t Save(const std::string& path,
const std::string& param) override;
int32_t LoadLocalFS(const std::string& path, const std::string& param);
int32_t SaveLocalFS(const std::string& path, const std::string& param,
......@@ -83,7 +85,7 @@ class MemorySparseTable : public Table {
int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
int32_t Flush() override;
int32_t Shrink(const std::string& param) override;
virtual int32_t Shrink(const std::string& param) override;
void Clear() override;
void* GetShard(size_t shard_idx) override {
......
......@@ -135,6 +135,11 @@ class SparseAccessor : public ValueAccessor {
// param = 1, save delta feature
// param = 2, save xbox base feature
bool Save(float* value, int param) override;
bool SaveCache(float* value, int param, double global_cache_threshold) {
return false;
}
bool SaveSSD(float* value) { return false; }
// update delta_score and unseen_days after save
void UpdateStatAfterSave(float* value, int param) override;
// keys不存在时,为values生成随机值
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/common/local_random.h"
#include "paddle/fluid/distributed/common/topk_calculator.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/utils/string/string_helper.h"
DECLARE_bool(pserver_print_missed_key_num_every_push);
DECLARE_bool(pserver_create_value_when_push);
DECLARE_bool(pserver_enable_create_feasign_randomly);
DEFINE_bool(pserver_open_strict_check, false, "pserver_open_strict_check");
DEFINE_string(rocksdb_path, "database", "path of sparse table rocksdb file");
DEFINE_int32(pserver_load_batch_size, 5000, "load batch size for ssd");
namespace paddle {
namespace distributed {
int32_t SSDSparseTable::Initialize() {
MemorySparseTable::Initialize();
_db = paddle::distributed::RocksDBHandler::GetInstance();
_db->initialize(FLAGS_rocksdb_path, _real_local_shard_num);
return 0;
}
int32_t SSDSparseTable::InitializeShard() { return 0; }
int32_t SSDSparseTable::PullSparse(float* pull_values, const uint64_t* keys,
size_t num) {
CostTimer timer("pserver_downpour_sparse_select_all");
size_t value_size = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_size =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t select_value_size =
_value_accesor->GetAccessorInfo().select_size / sizeof(float);
{  // read the value from the table, or create it
std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
_real_local_shard_num);
for (size_t i = 0; i < num; ++i) {
int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
task_keys[shard_id].push_back({keys[i], i});
}
std::atomic<uint32_t> missed_keys{0};
for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] =
_shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
[this, shard_id, &task_keys, value_size, mf_value_size,
select_value_size, pull_values, keys, &missed_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id];
float data_buffer[value_size];
float* data_buffer_ptr = data_buffer;
for (int i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first;
auto itr = local_shard.find(key);
size_t data_size = value_size - mf_value_size;
if (itr == local_shard.end()) {
// pull rocksdb
std::string tmp_string("");
if (_db->get(shard_id, (char*)&key, sizeof(uint64_t),
tmp_string) > 0) {
++missed_keys;
if (FLAGS_pserver_create_value_when_push) {
memset(data_buffer, 0, sizeof(float) * data_size);
} else {
auto& feature_value = local_shard[key];
feature_value.resize(data_size);
float* data_ptr =
const_cast<float*>(feature_value.data());
_value_accesor->Create(&data_buffer_ptr, 1);
memcpy(data_ptr, data_buffer_ptr,
data_size * sizeof(float));
}
} else {
data_size = tmp_string.size() / sizeof(float);
memcpy(data_buffer_ptr,
paddle::string::str_to_float(tmp_string),
data_size * sizeof(float));
// from rocksdb to mem
auto& feature_value = local_shard[key];
feature_value.resize(data_size);
memcpy(const_cast<float*>(feature_value.data()),
data_buffer_ptr, data_size * sizeof(float));
_db->del_data(shard_id, (char*)&key, sizeof(uint64_t));
}
} else {
data_size = itr.value().size();
memcpy(data_buffer_ptr, itr.value().data(),
data_size * sizeof(float));
}
for (int mf_idx = data_size; mf_idx < value_size; ++mf_idx) {
data_buffer[mf_idx] = 0.0;
}
int pull_data_idx = keys[i].second;
float* select_data =
pull_values + pull_data_idx * select_value_size;
_value_accesor->Select(&select_data,
(const float**)&data_buffer_ptr, 1);
}
return 0;
});
}
for (size_t i = 0; i < _real_local_shard_num; ++i) {
tasks[i].wait();
}
if (FLAGS_pserver_print_missed_key_num_every_push) {
LOG(WARNING) << "total pull keys:" << num
<< " missed_keys:" << missed_keys.load();
}
}
return 0;
}
int32_t SSDSparseTable::PushSparse(const uint64_t* keys, const float* values,
size_t num) {
CostTimer timer("pserver_downpour_sparse_update_all");
// build the data pointers for value and push_value
size_t value_col = _value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_col =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
size_t update_value_col =
_value_accesor->GetAccessorInfo().update_size / sizeof(float);
{
std::vector<std::future<int>> tasks(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, int>>> task_keys(
_real_local_shard_num);
for (size_t i = 0; i < num; ++i) {
int shard_id = (keys[i] % _sparse_table_shard_num) % _avg_local_shard_num;
task_keys[shard_id].push_back({keys[i], i});
}
for (size_t shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) {
tasks[shard_id] =
_shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue(
[this, shard_id, value_col, mf_value_col, update_value_col,
values, &task_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id];
float data_buffer[value_col];
float* data_buffer_ptr = data_buffer;
for (int i = 0; i < keys.size(); ++i) {
uint64_t key = keys[i].first;
uint64_t push_data_idx = keys[i].second;
const float* update_data =
values + push_data_idx * update_value_col;
auto itr = local_shard.find(key);
if (itr == local_shard.end()) {
if (FLAGS_pserver_enable_create_feasign_randomly &&
!_value_accesor->CreateValue(1, update_data)) {
continue;
}
auto value_size = value_col - mf_value_col;
auto& feature_value = local_shard[key];
feature_value.resize(value_size);
_value_accesor->Create(&data_buffer_ptr, 1);
memcpy(const_cast<float*>(feature_value.data()),
data_buffer_ptr, value_size * sizeof(float));
itr = local_shard.find(key);
}
auto& feature_value = itr.value();
float* value_data = const_cast<float*>(feature_value.data());
size_t value_size = feature_value.size();
if (value_size ==
value_col) {  // already extended to the maximum size, update in place
_value_accesor->Update(&value_data, &update_data, 1);
} else {  // copy into the buffer to update, then write back; unneeded mf fields are dropped on write-back
memcpy(data_buffer_ptr, value_data,
value_size * sizeof(float));
_value_accesor->Update(&data_buffer_ptr, &update_data, 1);
if (_value_accesor->NeedExtendMF(data_buffer)) {
feature_value.resize(value_col);
value_data = const_cast<float*>(feature_value.data());
_value_accesor->Create(&value_data, 1);
}
memcpy(value_data, data_buffer_ptr,
value_size * sizeof(float));
}
}
return 0;
});
}
for (size_t i = 0; i < _real_local_shard_num; ++i) {
tasks[i].wait();
}
}
/*
// transpose of update && value
thread_local Eigen::MatrixXf update_matrix;
float* transposed_update_data[update_value_col];
make_matrix_with_eigen(num, update_value_col, update_matrix,
transposed_update_data);
copy_array_to_eigen(values, update_matrix);
thread_local Eigen::MatrixXf value_matrix;
float* transposed_value_data[value_col];
make_matrix_with_eigen(num, value_col, value_matrix, transposed_value_data);
copy_matrix_to_eigen((const float**)(value_ptrs->data()), value_matrix);
// batch update
{
CostTimer accessor_timer("pslib_downpour_sparse_update_accessor");
_value_accesor->update(transposed_value_data, (const
float**)transposed_update_data, num);
}
copy_eigen_to_matrix(value_matrix, value_ptrs->data());
*/
return 0;
}
int32_t SSDSparseTable::Shrink(const std::string& param) {
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _real_local_shard_num; ++i) {
uint64_t mem_count = 0;
uint64_t ssd_count = 0;
LOG(INFO) << "SSDSparseTable begin shrink shard:" << i;
auto& shard = _local_shards[i];
for (auto it = shard.begin(); it != shard.end();) {
if (_value_accesor->Shrink(it.value().data())) {
it = shard.erase(it);
mem_count++;
} else {
++it;
}
}
auto* it = _db->get_iterator(i);
for (it->SeekToFirst(); it->Valid(); it->Next()) {
if (_value_accesor->Shrink(
paddle::string::str_to_float(it->value().data()))) {
_db->del_data(i, it->key().data(), it->key().size());
ssd_count++;
} else {
_db->put(i, it->key().data(), it->key().size(), it->value().data(),
it->value().size());
}
}
delete it;
LOG(INFO) << "SSDSparseTable shrink success. shard:" << i << " delete MEM["
<< mem_count << "] SSD[" << ssd_count << "]";
//_db->flush(i);
}
return 0;
}
int32_t SSDSparseTable::UpdateTable() {
// TODO: implement with multiple threads
int count = 0;
for (size_t i = 0; i < _real_local_shard_num; ++i) {
auto& shard = _local_shards[i];
// from mem to ssd
for (auto it = shard.begin(); it != shard.end();) {
if (_value_accesor->SaveSSD(it.value().data())) {
_db->put(i, (char*)&it.key(), sizeof(uint64_t),
(char*)it.value().data(), it.value().size() * sizeof(float));
count++;
it = shard.erase(it);
} else {
++it;
}
}
_db->flush(i);
}
LOG(INFO) << "Table>> update count: " << count;
return 0;
}
int64_t SSDSparseTable::LocalSize() {
int64_t local_size = 0;
for (size_t i = 0; i < _real_local_shard_num; ++i) {
local_size += _local_shards[i].size();
}
// TODO rocksdb size
uint64_t ssd_size = 0;
// _db->get_estimate_key_num(ssd_size);
// return local_size + ssd_size;
return local_size;
}
int32_t SSDSparseTable::Save(const std::string& path,
const std::string& param) {
if (_real_local_shard_num == 0) {
_local_show_threshold = -1;
return 0;
}
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
// if (save_param == 5) {
// return save_patch(path, save_param);
// }
// LOG(INFO) << "table cache rate is: " << _config.sparse_table_cache_rate();
LOG(INFO) << "table cache rate is: " << _config.sparse_table_cache_rate();
LOG(INFO) << "enable_sparse_table_cache: "
<< _config.enable_sparse_table_cache();
LOG(INFO) << "LocalSize: " << LocalSize();
if (_config.enable_sparse_table_cache()) {
LOG(INFO) << "Enable sparse table cache, top n:" << _cache_tk_size;
}
_cache_tk_size = LocalSize() * _config.sparse_table_cache_rate();
TopkCalculator tk(_real_local_shard_num, _cache_tk_size);
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
std::string table_path = TableDir(path);
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx));
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
// std::atomic<uint32_t> feasign_size;
std::atomic<uint32_t> feasign_size_all{0};
// feasign_size = 0;
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _real_local_shard_num; ++i) {
FsChannelConfig channel_config;
if (_config.compress_in_save() && (save_param == 0 || save_param == 3)) {
channel_config.path = paddle::string::format_string(
"%s/part-%03d-%05d.gz", table_path.c_str(), _shard_idx,
file_start_idx + i);
} else {
channel_config.path =
paddle::string::format_string("%s/part-%03d-%05d", table_path.c_str(),
_shard_idx, file_start_idx + i);
}
channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter =
_value_accesor->Converter(save_param).deconverter;
int err_no = 0;
int retry_num = 0;
bool is_write_failed = false;
int feasign_size = 0;
auto& shard = _local_shards[i];
do {
err_no = 0;
feasign_size = 0;
is_write_failed = false;
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_config.enable_sparse_table_cache() &&
(save_param == 1 || save_param == 2) &&
_value_accesor->Save(it.value().data(), 4)) {
// tk.push(i, it.value().data()[2]);
tk.push(i, _value_accesor->GetField(it.value().data(), "show"));
}
if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->ParseToString(
it.value().data(), it.value().size());
if (0 !=
write_channel->write_line(paddle::string::format_string(
"%lu %s", it.key(), format_value.c_str()))) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "SSDSparseTable save failed, retry it! path:"
<< channel_config.path << ", retry_num=" << retry_num;
break;
}
++feasign_size;
}
}
if (err_no == -1 && !is_write_failed) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "SSDSparseTable save failed after write, retry it! "
<< "path:" << channel_config.path
<< " , retry_num=" << retry_num;
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
continue;
}
// delta and cache and revert is all in mem, base in rocksdb
if (save_param != 1) {
auto* it = _db->get_iterator(i);
for (it->SeekToFirst(); it->Valid(); it->Next()) {
bool need_save = _value_accesor->Save(
paddle::string::str_to_float(it->value().data()), save_param);
_value_accesor->UpdateStatAfterSave(
paddle::string::str_to_float(it->value().data()), save_param);
if (need_save) {
std::string format_value = _value_accesor->ParseToString(
paddle::string::str_to_float(it->value().data()),
it->value().size() / sizeof(float));
if (0 !=
write_channel->write_line(paddle::string::format_string(
"%lu %s", *((uint64_t*)const_cast<char*>(it->key().data())),
format_value.c_str()))) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "SSDSparseTable save failed, retry it! path:"
<< channel_config.path << ", retry_num=" << retry_num;
break;
}
if (save_param == 3) {
_db->put(i, it->key().data(), it->key().size(),
it->value().data(), it->value().size());
}
++feasign_size;
}
}
delete it;
}
write_channel->close();
if (err_no == -1) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "SSDSparseTable save failed after write, retry it! "
<< "path:" << channel_config.path
<< " , retry_num=" << retry_num;
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
} while (is_write_failed);
feasign_size_all += feasign_size;
for (auto it = shard.begin(); it != shard.end(); ++it) {
_value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
}
}
if (save_param == 3) {
UpdateTable();
_cache_tk_size = LocalSize() * _config.sparse_table_cache_rate();
LOG(INFO) << "SSDSparseTable update success.";
}
LOG(INFO) << "SSDSparseTable save success, path:"
<< paddle::string::format_string("%s/%03d/part-%03d-", path.c_str(),
_config.table_id(), _shard_idx)
<< " from " << file_start_idx << " to "
<< file_start_idx + _real_local_shard_num - 1;
// return feasign_size_all;
_local_show_threshold = tk.top();
LOG(INFO) << "local cache threshold: " << _local_show_threshold;
// int32 may overflow need to change return value
return 0;
}
int64_t SSDSparseTable::CacheShuffle(
const std::string& path, const std::string& param, double cache_threshold,
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
std::string& msg)>
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) {
LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold
<< " param:" << param;
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
if (!_config.enable_sparse_table_cache() || cache_threshold < 0) {
LOG(WARNING)
<< "cache shuffle failed not enable table cache or cache threshold < 0 "
<< _config.enable_sparse_table_cache() << " or " << cache_threshold;
// return -1;
}
int shuffle_node_num = _config.sparse_table_cache_file_num();
LOG(INFO) << "Table>> shuffle node num is: " << shuffle_node_num;
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
std::vector<
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>>
writers(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, std::string>>> datas(
_real_local_shard_num);
int feasign_size = 0;
std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>>
tmp_channels;
for (size_t i = 0; i < _real_local_shard_num; ++i) {
tmp_channels.push_back(
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
}
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _real_local_shard_num; ++i) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[i];
// std::shared_ptr<paddle::framework::ChannelObject<std::pair<uint64_t,
// std::string>>> tmp_chan =
// paddle::framework::MakeChannel<std::pair<uint64_t,
// std::string>>();
writer.Reset(tmp_channels[i].get());
auto& shard = _local_shards[i];
for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_value_accesor->SaveCache(it.value().data(), save_param,
cache_threshold)) {
std::string format_value =
_value_accesor->ParseToString(it.value().data(), it.value().size());
std::pair<uint64_t, std::string> pkv(it.key(), format_value.c_str());
writer << pkv;
++feasign_size;
}
}
writer.Flush();
writer.channel()->Close();
}
LOG(INFO) << "SSDSparseTable cache KV save success to Channel feasigh size: "
<< feasign_size
<< " and start sparse cache data shuffle real local shard num: "
<< _real_local_shard_num;
std::vector<std::pair<uint64_t, std::string>> local_datas;
for (size_t idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[idx_shard];
auto channel = writer.channel();
std::vector<std::pair<uint64_t, std::string>>& data = datas[idx_shard];
std::vector<paddle::framework::BinaryArchive> ars(shuffle_node_num);
while (channel->Read(data)) {
for (auto& t : data) {
auto pserver_id =
paddle::distributed::local_random_engine()() % shuffle_node_num;
if (pserver_id != _shard_idx) {
ars[pserver_id] << t;
} else {
local_datas.emplace_back(std::move(t));
}
}
std::vector<std::future<int32_t>> total_status;
std::vector<uint32_t> send_data_size(shuffle_node_num, 0);
std::vector<int> send_index(shuffle_node_num);
for (int i = 0; i < shuffle_node_num; ++i) {
send_index[i] = i;
}
std::random_shuffle(send_index.begin(), send_index.end());
for (auto index = 0u; index < shuffle_node_num; ++index) {
int i = send_index[index];
if (i == _shard_idx) {
continue;
}
if (ars[i].Length() == 0) {
continue;
}
std::string msg(ars[i].Buffer(), ars[i].Length());
auto ret = send_msg_func(101, i, msg);
total_status.push_back(std::move(ret));
send_data_size[i] += ars[i].Length();
}
for (auto& t : total_status) {
t.wait();
}
ars.clear();
ars = std::vector<paddle::framework::BinaryArchive>(shuffle_node_num);
data = std::vector<std::pair<uint64_t, std::string>>();
}
}
shuffled_channel->Write(std::move(local_datas));
LOG(INFO) << "cache shuffle finished";
return 0;
}
int32_t SSDSparseTable::SaveCache(
const std::string& path, const std::string& param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) {
if (_shard_idx >= _config.sparse_table_cache_file_num()) {
return 0;
}
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
std::string table_path = paddle::string::format_string(
"%s/%03d_cache/", path.c_str(), _config.table_id());
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx));
uint32_t feasign_size = 0;
FsChannelConfig channel_config;
// not compress cache model
channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx);
channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter =
_value_accesor->Converter(save_param).deconverter;
auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40);
std::vector<std::pair<uint64_t, std::string>> data;
bool is_write_failed = false;
shuffled_channel->Close();
while (shuffled_channel->Read(data)) {
for (auto& t : data) {
++feasign_size;
if (0 !=
write_channel->write_line(paddle::string::format_string(
"%lu %s", t.first, t.second.c_str()))) {
LOG(ERROR) << "Cache Table save failed, "
"path:"
<< channel_config.path << ", retry it!";
is_write_failed = true;
break;
}
}
data = std::vector<std::pair<uint64_t, std::string>>();
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
write_channel->close();
LOG(INFO) << "SSDSparseTable cache save success, feasign: " << feasign_size
<< ", path: " << channel_config.path;
shuffled_channel->Open();
return feasign_size;
}
int32_t SSDSparseTable::Load(const std::string& path,
const std::string& param) {
return MemorySparseTable::Load(path, param);
}
// load the data files in [start_idx, end_idx) under the path directory
int32_t SSDSparseTable::Load(size_t start_idx, size_t end_idx,
const std::vector<std::string>& file_list,
const std::string& param) {
if (start_idx >= file_list.size()) {
return 0;
}
int load_param = atoi(param.c_str());
size_t feature_value_size =
_value_accesor->GetAccessorInfo().size / sizeof(float);
size_t mf_value_size =
_value_accesor->GetAccessorInfo().mf_size / sizeof(float);
end_idx =
end_idx < _sparse_table_shard_num ? end_idx : _sparse_table_shard_num;
int thread_num = (end_idx - start_idx) < 20 ? (end_idx - start_idx) : 20;
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (size_t i = start_idx; i < end_idx; ++i) {
FsChannelConfig channel_config;
channel_config.path = file_list[i];
channel_config.converter = _value_accesor->Converter(load_param).converter;
channel_config.deconverter =
_value_accesor->Converter(load_param).deconverter;
int retry_num = 0;
int err_no = 0;
bool is_read_failed = false;
std::vector<std::pair<char*, int>> ssd_keys;
std::vector<std::pair<char*, int>> ssd_values;
std::vector<uint64_t> tmp_key;
ssd_keys.reserve(FLAGS_pserver_load_batch_size);
ssd_values.reserve(FLAGS_pserver_load_batch_size);
tmp_key.reserve(FLAGS_pserver_load_batch_size);
do {
ssd_keys.clear();
ssd_values.clear();
tmp_key.clear();
err_no = 0;
is_read_failed = false;
std::string line_data;
auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
char* end = NULL;
int local_shard_id = i % _avg_local_shard_num;
auto& shard = _local_shards[local_shard_id];
float data_buffer[FLAGS_pserver_load_batch_size * feature_value_size];
float* data_buffer_ptr = data_buffer;
uint64_t mem_count = 0;
uint64_t ssd_count = 0;
uint64_t mem_mf_count = 0;
uint64_t ssd_mf_count = 0;
try {
while (read_channel->read_line(line_data) == 0 &&
line_data.size() > 1) {
uint64_t key = std::strtoul(line_data.data(), &end, 10);
if (FLAGS_pserver_open_strict_check) {
if (key % _sparse_table_shard_num != i) {
LOG(WARNING) << "SSDSparseTable key:" << key
<< " not match shard,"
<< " file_idx:" << i
<< " shard num:" << _sparse_table_shard_num
<< " file:" << channel_config.path;
continue;
}
}
int value_size =
_value_accesor->ParseFromString(++end, data_buffer_ptr);
// ssd or mem
if (_value_accesor->SaveSSD(data_buffer_ptr)) {
tmp_key.emplace_back(key);
ssd_keys.emplace_back(
std::make_pair((char*)&tmp_key.back(), sizeof(uint64_t)));
ssd_values.emplace_back(std::make_pair((char*)data_buffer_ptr,
value_size * sizeof(float)));
data_buffer_ptr += feature_value_size;
if (ssd_keys.size() == FLAGS_pserver_load_batch_size) {
_db->put_batch(local_shard_id, ssd_keys, ssd_values,
ssd_keys.size());
ssd_keys.clear();
ssd_values.clear();
tmp_key.clear();
data_buffer_ptr = data_buffer;
}
ssd_count++;
if (value_size > feature_value_size - mf_value_size) {
ssd_mf_count++;
}
} else {
auto& value = shard[key];
value.resize(value_size);
_value_accesor->ParseFromString(end, value.data());
mem_count++;
if (value_size > feature_value_size - mf_value_size) {
mem_mf_count++;
}
}
}
// last batch
if (ssd_keys.size() > 0) {
_db->put_batch(local_shard_id, ssd_keys, ssd_values, ssd_keys.size());
}
read_channel->close();
if (err_no == -1) {
++retry_num;
is_read_failed = true;
LOG(ERROR) << "SSDSparseTable load failed after read, retry it! path:"
<< channel_config.path << " , retry_num=" << retry_num;
continue;
}
_db->flush(local_shard_id);
LOG(INFO) << "Table>> load done. ALL[" << mem_count + ssd_count
<< "] MEM[" << mem_count << "] MEM_MF[" << mem_mf_count
<< "] SSD[" << ssd_count << "] SSD_MF[" << ssd_mf_count
<< "].";
} catch (...) {
++retry_num;
is_read_failed = true;
LOG(ERROR) << "SSDSparseTable load failed after read, retry it! path:"
<< channel_config.path << " , retry_num=" << retry_num;
}
} while (is_read_failed);
}
LOG(INFO) << "load num:" << LocalSize();
LOG(INFO) << "SSDSparseTable load success, path from " << file_list[start_idx]
<< " to " << file_list[end_idx - 1];
_cache_tk_size = LocalSize() * _config.sparse_table_cache_rate();
return 0;
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
namespace paddle {
namespace distributed {
class SSDSparseTable : public MemorySparseTable {
public:
typedef SparseTableShard<uint64_t, FixedFeatureValue> shard_type;
SSDSparseTable() {}
virtual ~SSDSparseTable() {}
int32_t Initialize() override;
int32_t InitializeShard() override;
// exchange data
int32_t UpdateTable();
int32_t Pull(TableContext& context) override {
CHECK(context.value_type == Sparse);
float* pull_values = context.pull_context.values;
const PullSparseValue& pull_value = context.pull_context.pull_value;
return PullSparse(pull_values, pull_value.feasigns_, pull_value.numel_);
}
int32_t Push(TableContext& context) override {
const uint64_t* keys = context.push_context.keys;
const float* values = context.push_context.values;
size_t num = context.num;
return PushSparse(keys, values, num);
}
virtual int32_t PullSparse(float* pull_values, const uint64_t* keys,
size_t num);
virtual int32_t PushSparse(const uint64_t* keys, const float* values,
size_t num);
int32_t Flush() override { return 0; }
virtual int32_t Shrink(const std::string& param) override;
virtual void Clear() override {
for (size_t i = 0; i < _real_local_shard_num; ++i) {
_local_shards[i].clear();
}
}
virtual int32_t Save(const std::string& path,
const std::string& param) override;
virtual int32_t SaveCache(
const std::string& path, const std::string& param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) override;
virtual double GetCacheThreshold() override { return _local_show_threshold; }
virtual int64_t CacheShuffle(
const std::string& path, const std::string& param, double cache_threshold,
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
std::string& msg)>
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) override;
  // load the data under the path directory
virtual int32_t Load(const std::string& path,
const std::string& param) override;
  // load the data files in [start_idx, end_idx) under the path directory
virtual int32_t Load(size_t start_idx, size_t end_idx,
const std::vector<std::string>& file_list,
const std::string& param);
int64_t LocalSize();
private:
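  // RocksDB handle that stores the cold (SSD-resident) feature values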
RocksDBHandler* _db;
int64_t _cache_tk_size;
double _local_show_threshold{0.0};
};
} // namespace distributed
} // namespace paddle
......@@ -25,6 +25,7 @@
#include "paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/sparse_accessor.h"
#include "paddle/fluid/distributed/ps/table/ssd_sparse_table.h"
#include "paddle/fluid/distributed/ps/table/tensor_accessor.h"
#include "paddle/fluid/distributed/ps/table/tensor_table.h"
......@@ -37,6 +38,7 @@ REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseTable);
REGISTER_PSCORE_CLASS(Table, SSDSparseTable);
REGISTER_PSCORE_CLASS(Table, MemorySparseGeoTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(ValueAccessor, CtrCommonAccessor);
......
......@@ -24,6 +24,7 @@
#include "paddle/fluid/distributed/ps/table/accessor.h"
#include "paddle/fluid/distributed/ps/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -107,6 +108,26 @@ class Table {
// specify the save path
virtual int32_t Save(const std::string &path,
const std::string &converter) = 0;
// for cache
virtual int32_t SaveCache(
const std::string &path, const std::string &param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel) {
return 0;
}
virtual int64_t CacheShuffle(
const std::string &path, const std::string &param, double cache_threshold,
std::function<std::future<int32_t>(int msg_type, int to_pserver_id,
std::string &msg)>
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel,
const std::vector<Table *> &table_ptrs) {
return 0;
}
virtual double GetCacheThreshold() { return 0.0; }
virtual int32_t SetShard(size_t shard_idx, size_t shard_num) {
_shard_idx = shard_idx;
......
......@@ -38,6 +38,12 @@ class CommMergeAccessor : public ValueAccessor {
// param identifies the save stage, e.g. downpour's xbox vs. batch_model
virtual bool Save(float * /*value*/, int /*param*/);
bool SaveCache(float *value, int param, double global_cache_threshold) {
return false;
}
bool SaveSSD(float *value) { return false; }
// when a key does not exist, fill its values with random numbers
virtual int32_t Create(float **value, size_t num);
// select entries from values into select_values
......
......@@ -754,6 +754,46 @@ std::future<int32_t> FleetWrapper::SendClientToClientMsg(
return worker_ptr_->SendClient2ClientMsg(msg_type, to_client_id, msg);
}
double FleetWrapper::GetCacheThreshold(int table_id) {
double cache_threshold = 0.0;
auto ret = worker_ptr_->Flush();
ret.wait();
ret = worker_ptr_->GetCacheThreshold(table_id, cache_threshold);
ret.wait();
if (cache_threshold < 0) {
LOG(ERROR) << "get cache threshold failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return cache_threshold;
}
void FleetWrapper::CacheShuffle(int table_id, const std::string& path,
const int mode, const double cache_threshold) {
auto ret = worker_ptr_->CacheShuffle(table_id, path, std::to_string(mode),
std::to_string(cache_threshold));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "cache shuffle failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
}
int32_t FleetWrapper::SaveCache(int table_id, const std::string& path,
const int mode) {
auto ret = worker_ptr_->SaveCache(table_id, path, std::to_string(mode));
ret.wait();
int32_t feasign_cnt = ret.get();
if (feasign_cnt == -1) {
LOG(ERROR) << "table save cache failed";
sleep(sleep_seconds_before_fail_exit_);
exit(-1);
}
return feasign_cnt;
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
......
......@@ -259,6 +259,11 @@ class FleetWrapper {
// for init worker
void InitGFlag(const std::string& gflags);
double GetCacheThreshold(int table_id);
void CacheShuffle(int table_id, const std::string& path, const int mode,
const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode);
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
......
......@@ -116,6 +116,10 @@ message TableParameter {
optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ];
optional GraphParameter graph_parameter = 9;
// for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
}
message TableAccessorParameter {
......
......@@ -78,7 +78,11 @@ void BindDistFleetWrapper(py::module* m) {
.def("set_clients", &FleetWrapper::SetClients)
.def("get_client_info", &FleetWrapper::GetClientsInfo)
.def("create_client2client_connection",
&FleetWrapper::CreateClient2ClientConnection);
&FleetWrapper::CreateClient2ClientConnection)
.def("client_flush", &FleetWrapper::ClientFlush)
.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &FleetWrapper::CacheShuffle)
.def("save_cache", &FleetWrapper::SaveCache);
}
void BindPSHost(py::module* m) {
......
......@@ -100,6 +100,14 @@ inline int str_to_float(const char* str, float* v) {
return index;
}
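// reinterpret the underlying string buffer as a float array without copying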
inline float* str_to_float(std::string& str) {
return (float*)const_cast<char*>(str.c_str());
}
inline float* str_to_float(const char* str) {
return (float*)const_cast<char*>(str);
}
// checks whether the test string is a suffix of the input string.
bool ends_with(std::string const& input, std::string const& test);
......
......@@ -77,6 +77,7 @@ stop_worker = fleet.stop_worker
distributed_optimizer = fleet.distributed_optimizer
save_inference_model = fleet.save_inference_model
save_persistables = fleet.save_persistables
save_cache_model = fleet.save_cache_model
load_model = fleet.load_model
minimize = fleet.minimize
distributed_model = fleet.distributed_model
......
......@@ -869,6 +869,11 @@ class Fleet(object):
self._runtime_handle._save_persistables(executor, dirname, main_program,
mode)
@is_non_distributed_check
@inited_runtime_handler
def save_cache_model(self, dirname, **configs):
return self._runtime_handle._save_cache_model(dirname, **configs)
def shrink(self, threshold=None):
self._runtime_handle._shrink(threshold)
......
......@@ -1315,6 +1315,30 @@ class TheOnePSRuntime(RuntimeBase):
def _save_persistables(self, *args, **kwargs):
self._ps_inference_save_persistables(*args, **kwargs)
def _save_cache_model(self, dirname, **kwargs):
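# cache-save flow: flush pending pushes, have worker 0 fetch the cache
# threshold, shuffle the cached features, then persist them; the barriers
# keep all trainers in step.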
mode = kwargs.get("mode", 0)
table_id = kwargs.get("table_id", 0)
self._worker.client_flush()
fleet.util.barrier()
cache_threshold = 0.0
if self.role_maker._is_first_worker():
cache_threshold = self._worker.get_cache_threshold(table_id)
# check that the cache threshold fetched by worker 0 is valid
fleet.util.barrier()
if self.role_maker._is_first_worker():
self._worker.cache_shuffle(table_id, dirname, mode, cache_threshold)
fleet.util.barrier()
feasign_num = -1
if self.role_maker._is_first_worker():
feasign_num = self._worker.save_cache(table_id, dirname, mode)
fleet.util.barrier()
return feasign_num
def _load_sparse_params(self, dirname, context, main_program, mode):
distributed_varnames = get_sparse_tablenames(self.origin_main_programs,
True)
......
......@@ -339,5 +339,9 @@ class TestDistCTR2x2(FleetDistRunnerBase):
if dirname:
fleet.save_persistables(exe, dirname=dirname)
cache_dirname = os.getenv("SAVE_CACHE_DIRNAME", None)
if cache_dirname:
fleet.save_cache_model(cache_dirname)
if __name__ == "__main__":
runtime_main(TestDistCTR2x2)
......@@ -39,6 +39,8 @@ class TestDistMnistAsyncInMemoryDataset2x2(TestFleetBase):
"http_proxy": "",
"CPU_NUM": "2",
"LOG_DIRNAME": "/tmp",
"SAVE_CACHE_DIRNAME":
"/tmp/TestDistMnistAsyncInMemoryDataset2x2/cache_model",
"LOG_PREFIX": self.__class__.__name__,
}
......