未验证 提交 f382eb06 编写于 作者: Z zhaocaibei123 提交者: GitHub

add save_cache/patch (#44420)

* add save_cache/patch

* add pybind

* remove pybind

* remove const_cast

* add fleet
上级 2792b8de
...@@ -482,7 +482,7 @@ std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id, ...@@ -482,7 +482,7 @@ std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
request_call_num, request_call_num,
[request_call_num, cmd_id, &cache_threshold](void *done) { [request_call_num, cmd_id, &cache_threshold](void *done) {
int ret = 0; int ret = 0;
auto *closure = (DownpourBrpcClosure *)done; auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
std::vector<double> cache_thresholds(request_call_num, 0); std::vector<double> cache_thresholds(request_call_num, 0);
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) { if (closure->check_response(i, cmd_id) != 0) {
...@@ -530,6 +530,14 @@ std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) { ...@@ -530,6 +530,14 @@ std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {}); return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
} }
std::future<int32_t> BrpcPsClient::Revert() {
return SendCmd(-1, PS_REVERT, {});
}
std::future<int32_t> BrpcPsClient::CheckSavePrePatchDone() {
return SendCmd(-1, PS_CHECK_SAVE_PRE_PATCH_DONE, {});
}
std::future<int32_t> BrpcPsClient::Flush() { std::future<int32_t> BrpcPsClient::Flush() {
VLOG(0) << "BrpcPsClient::flush begin"; VLOG(0) << "BrpcPsClient::flush begin";
_flushing = true; _flushing = true;
...@@ -1170,6 +1178,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values, ...@@ -1170,6 +1178,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
closure->add_timer(timer); closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
......
...@@ -178,6 +178,9 @@ class BrpcPsClient : public PSClient { ...@@ -178,6 +178,9 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> Clear(uint32_t table_id) override; std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> Revert() override;
std::future<int32_t> CheckSavePrePatchDone() override;
std::future<int32_t> StopServer() override; std::future<int32_t> StopServer() override;
std::future<int32_t> StartProfiler() override; std::future<int32_t> StartProfiler() override;
...@@ -298,16 +301,16 @@ class BrpcPsClient : public PSClient { ...@@ -298,16 +301,16 @@ class BrpcPsClient : public PSClient {
int PushSparseAsyncShardMerge( int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, std::vector<int> &request_kv_num, // NOLINT
int table_id, int table_id,
int shard_idx, // NOLINT int shard_idx,
ValueAccessor *accessor); ValueAccessor *accessor);
int PushSparseAsyncShardPush( int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, std::vector<int> &request_kv_num, // NOLINT
int table_id, int table_id,
int shard_idx, // NOLINT int shard_idx,
DownpourBrpcClosure *closure, DownpourBrpcClosure *closure,
ValueAccessor *accessor); ValueAccessor *accessor);
......
...@@ -146,7 +146,7 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg( ...@@ -146,7 +146,7 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
return fut; return fut;
} }
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) { auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourPServerBrpcClosure *)done; auto *closure = reinterpret_cast<DownpourPServerBrpcClosure *>(done);
int32_t ret = closure->check_response(0, msg_type + 1000); int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
...@@ -209,13 +209,16 @@ int32_t BrpcPsService::Initialize() { ...@@ -209,13 +209,16 @@ int32_t BrpcPsService::Initialize() {
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler; _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep; _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
// for save cache // for save cache
_service_handler_map[PS_SAVE_ONE_CACHE_TABLE] = _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
&BrpcPsService::SaveCacheTable; &BrpcPsService::SaveCacheTable;
_service_handler_map[PS_GET_CACHE_THRESHOLD] = _service_handler_map[PS_GET_CACHE_THRESHOLD] =
&BrpcPsService::GetCacheThreshold; &BrpcPsService::GetCacheThreshold;
_service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle; _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
_service_handler_map[PS_REVERT] = &BrpcPsService::Revert;
_service_handler_map[PS_CHECK_SAVE_PRE_PATCH_DONE] =
&BrpcPsService::CheckSavePrePatchDone;
auto &profiler = CostProfiler::instance(); auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense"); profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense"); profiler.register_profiler("pserver_server_push_dense");
...@@ -319,9 +322,8 @@ int32_t BrpcPsService::PullDense(Table *table, ...@@ -319,9 +322,8 @@ int32_t BrpcPsService::PullDense(Table *table,
table_context.pull_context.values = res_data->data(); table_context.pull_context.values = res_data->data();
table_context.num = num; table_context.num = num;
table->Pull(table_context); table->Pull(table_context);
// table->PullDense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
butil::return_object(res_data); butil::return_object(res_data);
...@@ -356,7 +358,6 @@ int32_t BrpcPsService::PushDenseParam(Table *table, ...@@ -356,7 +358,6 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
table_context.push_context.is_param = true; table_context.push_context.is_param = true;
table_context.num = num; table_context.num = num;
// if (table->PushDenseParam(values, num) != 0) {
if (table->Push(table_context) != 0) { if (table->Push(table_context) != 0) {
set_response_code(response, -1, "PushDenseParam failed"); set_response_code(response, -1, "PushDenseParam failed");
} }
...@@ -438,7 +439,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table, ...@@ -438,7 +439,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
"least 1 for num of sparse_key"); "least 1 for num of sparse_key");
return 0; return 0;
} }
uint32_t num = *(uint32_t *)(request.params(0).c_str()); const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/* /*
Push Content: Push Content:
|---keysData---|---valuesData---| |---keysData---|---valuesData---|
...@@ -484,10 +486,11 @@ int32_t BrpcPsService::PullGeoParam(Table *table, ...@@ -484,10 +486,11 @@ int32_t BrpcPsService::PullGeoParam(Table *table,
// table->PullGeoParam(trainer_id, &values, &ids); // table->PullGeoParam(trainer_id, &values, &ids);
uint32_t num = ids.size(); uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); cntl->response_attachment().append(reinterpret_cast<char *>(&num),
cntl->response_attachment().append((char *)ids.data(), sizeof(uint32_t));
cntl->response_attachment().append(reinterpret_cast<char *>(ids.data()),
ids.size() * sizeof(uint64_t)); ids.size() * sizeof(uint64_t));
cntl->response_attachment().append((char *)values.data(), cntl->response_attachment().append(reinterpret_cast<char *>(values.data()),
values.size() * sizeof(float)); values.size() * sizeof(float));
return 0; return 0;
} }
...@@ -517,7 +520,8 @@ int32_t BrpcPsService::PullSparse(Table *table, ...@@ -517,7 +520,8 @@ int32_t BrpcPsService::PullSparse(Table *table,
} }
CostTimer timer("pserver_server_pull_sparse"); CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim; auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
thread_local std::string req_buffer; thread_local std::string req_buffer;
...@@ -539,7 +543,7 @@ int32_t BrpcPsService::PullSparse(Table *table, ...@@ -539,7 +543,7 @@ int32_t BrpcPsService::PullSparse(Table *table,
table->Pull(table_context); table->Pull(table_context);
// table->PullSparse(res_data->data(), value); // table->PullSparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
butil::return_object(res_data); butil::return_object(res_data);
return 0; return 0;
...@@ -565,7 +569,8 @@ int32_t BrpcPsService::PushSparse(Table *table, ...@@ -565,7 +569,8 @@ int32_t BrpcPsService::PushSparse(Table *table,
return 0; return 0;
} }
CostTimer timer("pserver_server_push_sparse"); CostTimer timer("pserver_server_push_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/* /*
Push Content: Push Content:
|---keysData---|---valuesData---| |---keysData---|---valuesData---|
...@@ -767,6 +772,29 @@ int32_t BrpcPsService::GetCacheThreshold(Table *table, ...@@ -767,6 +772,29 @@ int32_t BrpcPsService::GetCacheThreshold(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->Flush();
itr.second->Revert();
}
return 0;
}
int32_t BrpcPsService::CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->CheckSavePrePatchDone();
}
return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table, int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
......
...@@ -53,10 +53,10 @@ class BrpcPsServer : public PSServer { ...@@ -53,10 +53,10 @@ class BrpcPsServer : public PSServer {
} }
int32_t Port(); int32_t Port();
virtual int32_t StartS2S() override; int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg( ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override; int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type, int32_t ReceiveFromPServer(int msg_type,
int pserver_id, int pserver_id,
const std::string &msg) override; const std::string &msg) override;
...@@ -75,14 +75,14 @@ class BrpcPsService; ...@@ -75,14 +75,14 @@ class BrpcPsService;
typedef int32_t (BrpcPsService::*serviceHandlerFunc)( typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
class BrpcPsService : public PsBaseService { class BrpcPsService : public PsBaseService {
public: public:
virtual int32_t Initialize() override; int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller, void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, const PsRequestMessage *request,
PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
...@@ -91,100 +91,110 @@ class BrpcPsService : public PsBaseService { ...@@ -91,100 +91,110 @@ class BrpcPsService : public PsBaseService {
int32_t InitializeShardInfo(); int32_t InitializeShardInfo();
int32_t PullDense(Table *table, int32_t PullDense(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushDense(Table *table, int32_t PushDense(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushDenseParam(Table *table, int32_t PushDenseParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushSparseParam(Table *table, int32_t PushSparseParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PullSparse(Table *table, int32_t PullSparse(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PullGeoParam(Table *table, int32_t PullGeoParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t Barrier(Table *table, int32_t Barrier(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushSparse(Table *table, int32_t PushSparse(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t LoadOneTable(Table *table, int32_t LoadOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t LoadAllTable(Table *table, int32_t LoadAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t SaveOneTable(Table *table, int32_t SaveOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t SaveAllTable(Table *table, int32_t SaveAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ShrinkTable(Table *table, int32_t ShrinkTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ClearOneTable(Table *table, int32_t ClearOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ClearAllTable(Table *table, int32_t ClearAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t StopServer(Table *table, int32_t StopServer(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t StartProfiler(Table *table, int32_t StartProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t StopProfiler(Table *table, int32_t StopProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PrintTableStat(Table *table, int32_t PrintTableStat(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table, int32_t PushGlobalStep(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t CacheShuffle(Table *table, int32_t CacheShuffle(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table, int32_t SaveCacheTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table, int32_t GetCacheThreshold(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
bool _is_initialize_shard_info; bool _is_initialize_shard_info;
...@@ -208,7 +218,7 @@ class DownpourPServerBrpcClosure : public PServerClosure { ...@@ -208,7 +218,7 @@ class DownpourPServerBrpcClosure : public PServerClosure {
} }
virtual ~DownpourPServerBrpcClosure() {} virtual ~DownpourPServerBrpcClosure() {}
virtual void Run() override { void Run() override {
if (_waiting_num.fetch_sub(1) == 1) { if (_waiting_num.fetch_sub(1) == 1) {
_callback(this); _callback(this);
delete this; delete this;
......
...@@ -67,12 +67,12 @@ class PSClient { ...@@ -67,12 +67,12 @@ class PSClient {
PSClient(PSClient &&) = delete; PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete; PSClient(const PSClient &) = delete;
virtual int32_t Configure( // NOLINT virtual int32_t Configure(
const PSParameter &config, const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions, &regions,
PSEnvironment &_env, PSEnvironment &_env, // NOLINT
size_t client_id) final; // NOLINT size_t client_id) final;
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms, virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
...@@ -293,8 +293,25 @@ class PSClient { ...@@ -293,8 +293,25 @@ class PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id, virtual std::future<int32_t> GetCacheThreshold(
double &cache_threshold) { uint32_t table_id,
double &cache_threshold) { // NOLINT
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> Revert() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CheckSavePrePatchDone() {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
......
...@@ -65,6 +65,8 @@ enum PsCmdID { ...@@ -65,6 +65,8 @@ enum PsCmdID {
PS_SAVE_WITH_SHARD = 44; PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45; PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46; PS_QUERY_WITH_SHARD = 46;
PS_REVERT = 47;
PS_CHECK_SAVE_PRE_PATCH_DONE = 48;
// pserver2pserver cmd start from 100 // pserver2pserver cmd start from 100
PS_S2S_MSG = 101; PS_S2S_MSG = 101;
} }
......
...@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer { ...@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer {
} }
float* summary_decay_rate; float* summary_decay_rate;
double summary_decay_rate_d = 0.999999; double summary_decay_rate_d = 0.9999999;
float* param; float* param;
}; };
......
...@@ -339,34 +339,37 @@ int32_t MemoryDenseTable::Save(const std::string& path, ...@@ -339,34 +339,37 @@ int32_t MemoryDenseTable::Save(const std::string& path,
_value_accesor->Converter(save_param).deconverter; _value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false; bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param( std::vector<std::string> result_buffer_param;
param_dim_, std::vector<std::string>()); result_buffer_param.reserve(param_dim_);
std::vector<std::string> result_buffer_fixed_len;
result_buffer_fixed_len.reserve(fixed_len_params_dim_);
auto common = _config.common(); auto common = _config.common();
int size = static_cast<int>(common.params().size()); int size = static_cast<int>(common.params().size());
if (_config.common().name() == "summary") { if (_config.common().name() == "summary") {
for (int x = 0; x < param_dim_; ++x) { for (int x = 0; x < param_dim_; ++x) {
result_buffer_param[x].emplace_back( result_buffer_param.emplace_back(std::to_string(values_[param_idx_][x]));
std::to_string(values_[param_idx_][x]));
} }
} else if (_config.common().name() == "adam_d2sum") {
} else {
std::ostringstream os; std::ostringstream os;
for (int x = 0; x < size; ++x) { for (int y = 0; y < param_dim_; ++y) {
int dim = common.dims()[x];
VLOG(3) << "MemoryDenseTable::save dim " << x << " size: " << dim;
for (int y = 0; y < dim; ++y) {
os.clear(); os.clear();
os.str(""); os.str("");
os << values_[x][y]; os << values_[param_col_ids_[0]][y] << " 0";
if (dim == param_dim_) { for (int x = 2; x < param_col_ids_.size(); ++x) {
result_buffer_param[y].emplace_back(std::move(os.str())); os << " ";
} else { os << values_[param_col_ids_[x]][y];
result_buffer_fixed_len.emplace_back(std::move(os.str()));
} }
result_buffer_param.emplace_back(std::move(os.str()));
} }
} else {
std::ostringstream os;
for (int y = 0; y < param_dim_; ++y) {
os.clear();
os.str("");
os << values_[param_col_ids_[0]][y];
for (int x = 1; x < param_col_ids_.size(); ++x) {
os << " ";
os << values_[param_col_ids_[x]][y];
}
result_buffer_param.emplace_back(std::move(os.str()));
} }
} }
...@@ -379,12 +382,9 @@ int32_t MemoryDenseTable::Save(const std::string& path, ...@@ -379,12 +382,9 @@ int32_t MemoryDenseTable::Save(const std::string& path,
// 40M // 40M
auto write_channel = auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto& t : result_buffer_param) { for (auto& t : result_buffer_param) {
if (_config.common().name() == "adam_d2sum") { if (0 != write_channel->write_line(t)) {
t.insert(t.begin() + 1, "0"); // avg_w
}
if (0 !=
write_channel->write_line(paddle::string::join_strings(t, ' '))) {
++retry_num; ++retry_num;
is_write_failed = true; is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed, retry it! " LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
...@@ -395,6 +395,7 @@ int32_t MemoryDenseTable::Save(const std::string& path, ...@@ -395,6 +395,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
} }
++feasign_size; ++feasign_size;
VLOG(3) << "save begin close " << channel_config.path;
write_channel->close(); write_channel->close();
if (err_no == -1) { if (err_no == -1) {
++retry_num; ++retry_num;
......
...@@ -12,15 +12,18 @@ ...@@ -12,15 +12,18 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include <omp.h> #include <omp.h>
#include <sstream> #include <sstream>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/distributed/common/cost_timer.h" #include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/common/local_random.h"
#include "paddle/fluid/distributed/common/topk_calculator.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/framework/io/fs.h"
// #include "boost/lexical_cast.hpp"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
DEFINE_bool(pserver_print_missed_key_num_every_push, DEFINE_bool(pserver_print_missed_key_num_every_push,
...@@ -68,6 +71,30 @@ int32_t MemorySparseTable::InitializeValue() { ...@@ -68,6 +71,30 @@ int32_t MemorySparseTable::InitializeValue() {
_local_shards.reset(new shard_type[_real_local_shard_num]); _local_shards.reset(new shard_type[_real_local_shard_num]);
if (_config.enable_revert()) {
// calculate merged shard number based on config param;
_shard_merge_rate = _config.has_shard_merge_rate()
? _config.shard_merge_rate()
: _shard_merge_rate;
CHECK((_m_avg_local_shard_num = static_cast<int>(
std::ceil(_avg_local_shard_num * _shard_merge_rate)),
_m_avg_local_shard_num <= _avg_local_shard_num));
CHECK((_m_real_local_shard_num = static_cast<int>(
std::ceil(_real_local_shard_num * _shard_merge_rate)),
_m_real_local_shard_num <= _real_local_shard_num));
uint32_t avg_shard_server_num =
_sparse_table_shard_num / _avg_local_shard_num;
uint32_t last_server_shard_num =
_sparse_table_shard_num - avg_shard_server_num * _avg_local_shard_num;
_m_sparse_table_shard_num =
avg_shard_server_num * _m_avg_local_shard_num +
std::ceil(last_server_shard_num * _shard_merge_rate);
LOG(INFO) << "merged shard info: [" << _m_sparse_table_shard_num << "|"
<< _m_avg_local_shard_num << "|" << _m_real_local_shard_num
<< "]";
_local_shards_new.reset(new shard_type[_real_local_shard_num]);
}
return 0; return 0;
} }
...@@ -93,8 +120,16 @@ int32_t MemorySparseTable::Load(const std::string& path, ...@@ -93,8 +120,16 @@ int32_t MemorySparseTable::Load(const std::string& path,
return -1; return -1;
} }
if (load_param == 5) {
return LoadPatch(file_list, load_param);
}
size_t file_start_idx = _shard_idx * _avg_local_shard_num; size_t file_start_idx = _shard_idx * _avg_local_shard_num;
if (file_start_idx >= file_list.size()) {
return 0;
}
size_t feature_value_size = size_t feature_value_size =
_value_accesor->GetAccessorInfo().size / sizeof(float); _value_accesor->GetAccessorInfo().size / sizeof(float);
...@@ -161,30 +196,37 @@ int32_t MemorySparseTable::Load(const std::string& path, ...@@ -161,30 +196,37 @@ int32_t MemorySparseTable::Load(const std::string& path,
return 0; return 0;
} }
int32_t MemorySparseTable::LoadLocalFS(const std::string& path, int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
const std::string& param) { int load_param) {
std::string table_path = TableDir(path); if (!_config.enable_revert()) {
auto file_list = paddle::framework::localfs_list(table_path); LOG(INFO) << "MemorySparseTable should be enabled revert.";
size_t expect_shard_num = _sparse_table_shard_num; return 0;
if (file_list.size() != expect_shard_num) {
LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
<< " not equal to expect_shard_num:" << expect_shard_num;
return -1;
}
if (file_list.size() == 0) {
LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
return -1;
} }
// 聚合分片数据索引
int start_idx = _shard_idx * _m_avg_local_shard_num;
int end_idx = start_idx + _m_real_local_shard_num;
// 原始分片数据索引
int o_start_idx = _shard_idx * _avg_local_shard_num;
int o_end_idx = o_start_idx + _real_local_shard_num;
size_t file_start_idx = _shard_idx * _avg_local_shard_num; if (start_idx >= file_list.size()) {
return 0;
}
size_t feature_value_size = size_t feature_value_size =
_value_accesor->GetAccessorInfo().size / sizeof(float); _value_accesor->GetAccessorInfo().size / sizeof(float);
end_idx =
end_idx < _m_sparse_table_shard_num ? end_idx : _m_sparse_table_shard_num;
int thread_num = (end_idx - start_idx) < 15 ? (end_idx - start_idx) : 15;
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < _real_local_shard_num; ++i) { for (size_t i = start_idx; i < end_idx; ++i) {
FsChannelConfig channel_config;
channel_config.path = file_list[i];
channel_config.converter = _value_accesor->Converter(load_param).converter;
channel_config.deconverter =
_value_accesor->Converter(load_param).deconverter;
bool is_read_failed = false; bool is_read_failed = false;
int retry_num = 0; int retry_num = 0;
int err_no = 0; int err_no = 0;
...@@ -192,31 +234,55 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path, ...@@ -192,31 +234,55 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
is_read_failed = false; is_read_failed = false;
err_no = 0; err_no = 0;
std::string line_data; std::string line_data;
std::ifstream file(file_list[file_start_idx + i]); auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
char* end = NULL; char* end = NULL;
auto& shard = _local_shards[i]; int m_local_shard_id = i % _m_avg_local_shard_num;
std::unordered_set<size_t> global_shard_idx;
std::string global_shard_idx_str;
for (size_t j = o_start_idx; j < o_end_idx; ++j) {
if ((j % _avg_local_shard_num) % _m_real_local_shard_num ==
m_local_shard_id) {
global_shard_idx.insert(j);
global_shard_idx_str.append(std::to_string(j)).append(",");
}
}
try { try {
while (std::getline(file, line_data) && line_data.size() > 1) { while (read_channel->read_line(line_data) == 0 &&
line_data.size() > 1) {
uint64_t key = std::strtoul(line_data.data(), &end, 10); uint64_t key = std::strtoul(line_data.data(), &end, 10);
auto index_iter =
global_shard_idx.find(key % _sparse_table_shard_num);
if (index_iter == global_shard_idx.end()) {
LOG(WARNING) << "MemorySparseTable key:" << key
<< " not match shard,"
<< " file_idx:" << i
<< " global_shard_idx:" << global_shard_idx_str
<< " shard num:" << _sparse_table_shard_num
<< " file:" << channel_config.path;
continue;
}
size_t local_shard_idx = *index_iter % _avg_local_shard_num;
auto& shard = _local_shards[local_shard_idx];
auto& value = shard[key]; auto& value = shard[key];
value.resize(feature_value_size); value.resize(feature_value_size);
int parse_size = _value_accesor->ParseFromString(++end, value.data()); int parse_size = _value_accesor->ParseFromString(++end, value.data());
value.resize(parse_size); value.resize(parse_size);
} }
file.close(); read_channel->close();
if (err_no == -1) { if (err_no == -1) {
++retry_num; ++retry_num;
is_read_failed = true; is_read_failed = true;
LOG(ERROR) LOG(ERROR)
<< "MemorySparseTable load failed after read, retry it! path:" << "MemorySparseTable load failed after read, retry it! path:"
<< file_list[file_start_idx + i] << " , retry_num=" << retry_num; << channel_config.path << " , retry_num=" << retry_num;
} }
} catch (...) { } catch (...) {
++retry_num; ++retry_num;
is_read_failed = true; is_read_failed = true;
LOG(ERROR) << "MemorySparseTable load failed, retry it! path:" LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
<< file_list[file_start_idx + i] << channel_config.path << " , retry_num=" << retry_num;
<< " , retry_num=" << retry_num;
} }
if (retry_num > FLAGS_pserver_table_save_max_retry) { if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!"; LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
...@@ -225,16 +291,44 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path, ...@@ -225,16 +291,44 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
} while (is_read_failed); } while (is_read_failed);
} }
LOG(INFO) << "MemorySparseTable load success, path from " LOG(INFO) << "MemorySparseTable load success, path from "
<< file_list[file_start_idx] << " to " << file_list[start_idx] << " to " << file_list[end_idx - 1];
<< file_list[file_start_idx + _real_local_shard_num - 1];
return 0; return 0;
} }
void MemorySparseTable::Revert() {
for (size_t i = 0; i < _real_local_shard_num; ++i) {
_local_shards_new[i].clear();
}
}
void MemorySparseTable::CheckSavePrePatchDone() {
_save_patch_model_thread.join();
}
int32_t MemorySparseTable::Save(const std::string& dirname, int32_t MemorySparseTable::Save(const std::string& dirname,
const std::string& param) { const std::string& param) {
if (_real_local_shard_num == 0) {
_local_show_threshold = -1;
return 0;
}
VLOG(0) << "MemorySparseTable::save dirname: " << dirname; VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
int save_param = int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
// patch model
if (save_param == 5) {
_local_shards_patch_model.reset(_local_shards_new.release());
_local_shards_new.reset(new shard_type[_real_local_shard_num]);
_save_patch_model_thread = std::thread(std::bind(
&MemorySparseTable::SavePatch, this, std::string(dirname), save_param));
return 0;
}
// cache model
int64_t tk_size = LocalSize() * _config.sparse_table_cache_rate();
TopkCalculator tk(_real_local_shard_num, tk_size);
std::string table_path = TableDir(dirname); std::string table_path = TableDir(dirname);
_afs_client.remove(paddle::string::format_string( _afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx)); "%s/part-%03d-*", table_path.c_str(), _shard_idx));
...@@ -274,6 +368,13 @@ int32_t MemorySparseTable::Save(const std::string& dirname, ...@@ -274,6 +368,13 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
auto write_channel = auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto it = shard.begin(); it != shard.end(); ++it) { for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_config.enable_sparse_table_cache() &&
(save_param == 1 || save_param == 2) &&
_value_accesor->Save(it.value().data(), 4)) {
CostTimer timer10("sprase table top push");
tk.push(i, _value_accesor->GetField(it.value().data(), "show"));
}
if (_value_accesor->Save(it.value().data(), save_param)) { if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->ParseToString( std::string format_value = _value_accesor->ParseToString(
it.value().data(), it.value().size()); it.value().data(), it.value().size());
...@@ -310,55 +411,266 @@ int32_t MemorySparseTable::Save(const std::string& dirname, ...@@ -310,55 +411,266 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
_value_accesor->UpdateStatAfterSave(it.value().data(), save_param); _value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
} }
LOG(INFO) << "MemorySparseTable save prefix success, path: " LOG(INFO) << "MemorySparseTable save prefix success, path: "
<< channel_config.path; << channel_config.path << " feasign_size: " << feasign_size;
} }
_local_show_threshold = tk.top();
// int32 may overflow need to change return value // int32 may overflow need to change return value
return 0; return 0;
} }
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname, int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
const std::string& param, if (!_config.enable_revert()) {
const std::string& prefix) { LOG(INFO) << "MemorySparseTable should be enabled revert.";
int save_param = return 0;
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2 }
std::string table_path = TableDir(dirname); size_t file_start_idx = _m_avg_local_shard_num * _shard_idx;
int feasign_cnt = 0; std::string table_path = TableDir(path);
size_t file_start_idx = _avg_local_shard_num * _shard_idx; _afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx));
int thread_num = _m_real_local_shard_num < 20 ? _m_real_local_shard_num : 20;
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
std::atomic<uint32_t> feasign_size_all{0}; std::atomic<uint32_t> feasign_size_all{0};
omp_set_num_threads(thread_num); omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic) #pragma omp parallel for schedule(dynamic)
for (int i = 0; i < _real_local_shard_num; ++i) { for (size_t i = 0; i < _m_real_local_shard_num; ++i) {
feasign_cnt = 0; FsChannelConfig channel_config;
auto& shard = _local_shards[i]; channel_config.path = paddle::string::format_string("%s/part-%03d-%05d",
std::string file_name =
paddle::string::format_string("%s/part-%s-%03d-%05d",
table_path.c_str(), table_path.c_str(),
prefix.c_str(),
_shard_idx, _shard_idx,
file_start_idx + i); file_start_idx + i);
std::ofstream os;
os.open(file_name); channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter =
_value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false;
int feasign_size = 0;
int retry_num = 0;
int err_no = 0;
do {
err_no = 0;
feasign_size = 0;
is_write_failed = false;
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (size_t j = 0; j < _real_local_shard_num; ++j) {
if (j % _m_real_local_shard_num == i) {
auto& shard = _local_shards_patch_model[j];
for (auto it = shard.begin(); it != shard.end(); ++it) { for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_value_accesor->Save(it.value().data(), save_param)) { if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = std::string format_value = _value_accesor->ParseToString(
_value_accesor->ParseToString(it.value().data(), it.value().size()); it.value().data(), it.value().size());
std::string out_line = paddle::string::format_string( if (0 != write_channel->write_line(paddle::string::format_string(
"%lu %s\n", it.key(), format_value.c_str()); "%lu %s", it.key(), format_value.c_str()))) {
// VLOG(2) << out_line.c_str(); ++retry_num;
os.write(out_line.c_str(), sizeof(char) * out_line.size()); is_write_failed = true;
++feasign_cnt; LOG(ERROR) << "MemorySparseTable save failed, retry it! path:"
<< channel_config.path
<< " , retry_num=" << retry_num;
break;
}
++feasign_size;
}
}
}
if (is_write_failed) break;
}
write_channel->close();
if (err_no == -1) {
++retry_num;
is_write_failed = true;
LOG(ERROR)
<< "MemorySparseTable save patch failed after write, retry it! "
<< "path:" << channel_config.path << " , retry_num=" << retry_num;
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable save patch failed reach max limit!";
exit(-1);
}
} while (is_write_failed);
feasign_size_all += feasign_size;
}
LOG(INFO) << "MemorySparseTable save patch success, path:"
<< paddle::string::format_string("%s/%03d/part-%03d-",
path.c_str(),
_config.table_id(),
_shard_idx)
<< " from " << file_start_idx << " to "
<< file_start_idx + _m_real_local_shard_num - 1
<< ", feasign size: " << feasign_size_all;
return 0;
}
int64_t MemorySparseTable::CacheShuffle(
const std::string& path,
const std::string& param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) {
LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold;
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
if (!_config.enable_sparse_table_cache() || cache_threshold < 0) {
LOG(WARNING)
<< "cache shuffle failed not enable table cache or cache threshold < 0 "
<< _config.enable_sparse_table_cache() << " or " << cache_threshold;
// return -1;
}
int shuffle_node_num = _config.sparse_table_cache_file_num();
LOG(INFO) << "Table>> shuffle node num is: " << shuffle_node_num;
// TODO(zhaocaibei123): check shuffle_node_num <= server_node_num
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
std::vector<
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>>
writers(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, std::string>>> datas(
_real_local_shard_num);
int feasign_size = 0;
std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>>
tmp_channels;
for (size_t i = 0; i < _real_local_shard_num; ++i) {
tmp_channels.push_back(
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
}
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _real_local_shard_num; ++i) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[i];
writer.Reset(tmp_channels[i].get());
for (size_t idx = 0; idx < table_ptrs.size(); idx++) {
Table* table_ptr = table_ptrs[idx];
auto value_accesor = table_ptr->ValueAccesor();
shard_type* shard_ptr = static_cast<shard_type*>(table_ptr->GetShard(i));
for (auto it = shard_ptr->begin(); it != shard_ptr->end(); ++it) {
if (value_accesor->SaveCache(
it.value().data(), save_param, cache_threshold)) {
std::string format_value = value_accesor->ParseToString(
it.value().data(), it.value().size());
std::pair<uint64_t, std::string> pkv(it.key(), format_value.c_str());
writer << pkv;
++feasign_size;
} }
} }
os.close();
LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
<< "feasign_cnt: " << feasign_cnt;
} }
writer.Flush();
writer.channel()->Close();
}
// LOG(INFO) << "MemorySparseTable cache KV save success to Channel feasigh
// size: " << feasign_size << " and start sparse cache data shuffle real local
// shard num: " << _real_local_shard_num;
std::vector<std::pair<uint64_t, std::string>> local_datas;
for (size_t idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[idx_shard];
auto channel = writer.channel();
std::vector<std::pair<uint64_t, std::string>>& data = datas[idx_shard];
std::vector<paddle::framework::BinaryArchive> ars(shuffle_node_num);
while (channel->Read(data)) {
for (auto& t : data) {
auto pserver_id =
paddle::distributed::local_random_engine()() % shuffle_node_num;
if (pserver_id != _shard_idx) {
ars[pserver_id] << t;
} else {
local_datas.emplace_back(std::move(t));
}
}
std::vector<std::future<int32_t>> total_status;
std::vector<uint32_t> send_data_size(shuffle_node_num, 0);
std::vector<int> send_index(shuffle_node_num);
for (int i = 0; i < shuffle_node_num; ++i) {
send_index[i] = i;
}
std::random_shuffle(send_index.begin(), send_index.end());
for (auto index = 0u; index < shuffle_node_num; ++index) {
int i = send_index[index];
if (i == _shard_idx) {
continue;
}
if (ars[i].Length() == 0) {
continue;
}
std::string msg(ars[i].Buffer(), ars[i].Length());
auto ret = send_msg_func(101, i, msg);
total_status.push_back(std::move(ret));
send_data_size[i] += ars[i].Length();
}
for (auto& t : total_status) {
t.wait();
}
ars.clear();
ars = std::vector<paddle::framework::BinaryArchive>(shuffle_node_num);
data = std::vector<std::pair<uint64_t, std::string>>();
}
}
shuffled_channel->Write(std::move(local_datas));
return 0; return 0;
} }
int32_t MemorySparseTable::SaveCache(
const std::string& path,
const std::string& param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) {
if (_shard_idx >= _config.sparse_table_cache_file_num()) {
return 0;
}
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
std::string table_path = paddle::string::format_string(
"%s/%03d_cache/", path.c_str(), _config.table_id());
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx));
uint32_t feasign_size = 0;
FsChannelConfig channel_config;
// not compress cache model
channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx);
channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter =
_value_accesor->Converter(save_param).deconverter;
auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40);
std::vector<std::pair<uint64_t, std::string>> data;
bool is_write_failed = false;
shuffled_channel->Close();
while (shuffled_channel->Read(data)) {
for (auto& t : data) {
++feasign_size;
if (0 != write_channel->write_line(paddle::string::format_string(
"%lu %s", t.first, t.second.c_str()))) {
LOG(ERROR) << "Cache Table save failed, "
"path:"
<< channel_config.path << ", retry it!";
is_write_failed = true;
break;
}
}
data = std::vector<std::pair<uint64_t, std::string>>();
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
write_channel->close();
LOG(INFO) << "MemorySparseTable cache save success, feasign: " << feasign_size
<< ", path: " << channel_config.path;
shuffled_channel->Open();
return feasign_size;
}
int64_t MemorySparseTable::LocalSize() { int64_t MemorySparseTable::LocalSize() {
int64_t local_size = 0; int64_t local_size = 0;
for (int i = 0; i < _real_local_shard_num; ++i) { for (int i = 0; i < _real_local_shard_num; ++i) {
...@@ -548,7 +860,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values, ...@@ -548,7 +860,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
ret = itr.value_ptr(); ret = itr.value_ptr();
} }
int pull_data_idx = keys[i].second; int pull_data_idx = keys[i].second;
pull_values[pull_data_idx] = (char*)ret; // NOLINT pull_values[pull_data_idx] = reinterpret_cast<char*>(ret);
} }
return 0; return 0;
}); });
...@@ -589,6 +901,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, ...@@ -589,6 +901,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
&task_keys]() -> int { &task_keys]() -> int {
auto& keys = task_keys[shard_id]; auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id]; auto& local_shard = _local_shards[shard_id];
auto& local_shard_new = _local_shards_new[shard_id];
float data_buffer[value_col]; // NOLINT float data_buffer[value_col]; // NOLINT
float* data_buffer_ptr = data_buffer; float* data_buffer_ptr = data_buffer;
for (size_t i = 0; i < keys.size(); ++i) { for (size_t i = 0; i < keys.size(); ++i) {
...@@ -630,6 +943,14 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys, ...@@ -630,6 +943,14 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
} }
memcpy(value_data, data_buffer_ptr, value_size * sizeof(float)); memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
} }
if (_config.enable_revert()) {
FixedFeatureValue* feature_value_new = &(local_shard_new[key]);
auto new_size = feature_value.size();
feature_value_new->resize(new_size);
memcpy(feature_value_new->data(),
value_data,
new_size * sizeof(float));
}
} }
return 0; return 0;
}); });
......
...@@ -65,17 +65,25 @@ class MemorySparseTable : public Table { ...@@ -65,17 +65,25 @@ class MemorySparseTable : public Table {
int32_t InitializeShard() override { return 0; } int32_t InitializeShard() override { return 0; }
int32_t InitializeValue(); int32_t InitializeValue();
virtual int32_t Load(const std::string& path, int32_t Load(const std::string& path, const std::string& param) override;
const std::string& param) override;
virtual int32_t Save(const std::string& path, int32_t Save(const std::string& path, const std::string& param) override;
const std::string& param) override;
int32_t LoadLocalFS(const std::string& path, const std::string& param); int32_t SaveCache(
int32_t SaveLocalFS(const std::string& path, const std::string& path,
const std::string& param, const std::string& param,
const std::string& prefix); paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) override;
virtual double GetCacheThreshold() { return _local_show_threshold; }
int64_t CacheShuffle(
const std::string& path,
const std::string& param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) override;
int64_t LocalSize(); int64_t LocalSize();
int64_t LocalMFSize(); int64_t LocalMFSize();
...@@ -89,20 +97,38 @@ class MemorySparseTable : public Table { ...@@ -89,20 +97,38 @@ class MemorySparseTable : public Table {
int32_t PushSparse(const uint64_t* keys, const float** values, size_t num); int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
int32_t Flush() override; int32_t Flush() override;
virtual int32_t Shrink(const std::string& param) override; int32_t Shrink(const std::string& param) override;
void Clear() override; void Clear() override;
void* GetShard(size_t shard_idx) override { void* GetShard(size_t shard_idx) override {
return &_local_shards[shard_idx]; return &_local_shards[shard_idx];
} }
virtual void Revert();
virtual void CheckSavePrePatchDone();
protected: protected:
virtual int32_t SavePatch(const std::string& path, int save_param);
virtual int32_t LoadPatch(const std::vector<std::string>& file_list,
int save_param);
const int _task_pool_size = 24; const int _task_pool_size = 24;
int _avg_local_shard_num; int _avg_local_shard_num;
int _real_local_shard_num; int _real_local_shard_num;
int _sparse_table_shard_num; int _sparse_table_shard_num;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::unique_ptr<shard_type[]> _local_shards; std::unique_ptr<shard_type[]> _local_shards;
// for patch model
int _m_avg_local_shard_num;
int _m_real_local_shard_num;
int _m_sparse_table_shard_num;
float _shard_merge_rate{1.0f};
double _local_show_threshold{0.0};
std::unique_ptr<shard_type[]> _local_shards_new;
std::unique_ptr<shard_type[]> _local_shards_patch_model;
std::thread _save_patch_model_thread;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -71,8 +71,8 @@ class Table { ...@@ -71,8 +71,8 @@ class Table {
virtual int32_t Initialize(const TableParameter &config, virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config); const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0; virtual int32_t Pull(TableContext &context) = 0; // NOLINT
virtual int32_t Push(TableContext &context) = 0; virtual int32_t Push(TableContext &context) = 0; // NOLINT
// only for barrier // only for barrier
virtual int32_t Barrier(const uint32_t trainer_id, virtual int32_t Barrier(const uint32_t trainer_id,
...@@ -125,7 +125,8 @@ class Table { ...@@ -125,7 +125,8 @@ class Table {
const std::string &param, const std::string &param,
double cache_threshold, double cache_threshold,
std::function<std::future<int32_t>( std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string &msg)> send_msg_func, int msg_type, int to_pserver_id, std::string &msg)> // NOLINT
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>> paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel, &shuffled_channel,
const std::vector<Table *> &table_ptrs) { const std::vector<Table *> &table_ptrs) {
...@@ -147,6 +148,10 @@ class Table { ...@@ -147,6 +148,10 @@ class Table {
virtual void *GetShard(size_t shard_idx) = 0; virtual void *GetShard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; } virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; }
// for patch model
virtual void Revert() {}
virtual void CheckSavePrePatchDone() {}
protected: protected:
virtual int32_t Initialize() = 0; virtual int32_t Initialize() = 0;
virtual int32_t InitializeAccessor(); virtual int32_t InitializeAccessor();
......
...@@ -853,6 +853,24 @@ int32_t FleetWrapper::SaveCache(int table_id, ...@@ -853,6 +853,24 @@ int32_t FleetWrapper::SaveCache(int table_id,
return feasign_cnt; return feasign_cnt;
} }
void FleetWrapper::Revert() {
auto ret = worker_ptr_->Revert();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
void FleetWrapper::CheckSavePrePatchDone() {
auto ret = worker_ptr_->CheckSavePrePatchDone();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() { std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t { struct engine_wrapper_t {
std::default_random_engine engine; std::default_random_engine engine;
......
...@@ -300,6 +300,8 @@ class FleetWrapper { ...@@ -300,6 +300,8 @@ class FleetWrapper {
const int mode, const int mode,
const double cache_threshold); const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode); int32_t SaveCache(int table_id, const std::string& path, const int mode);
void Revert();
void CheckSavePrePatchDone();
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_; static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_; static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
......
...@@ -150,9 +150,6 @@ TEST(MemorySparseTable, SGD) { ...@@ -150,9 +150,6 @@ TEST(MemorySparseTable, SGD) {
VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j]; VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j];
} }
} }
MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table);
ctr_table->SaveLocalFS("./work/table.save", "0", "test");
} }
} // namespace distributed } // namespace distributed
......
...@@ -114,12 +114,14 @@ message TableParameter { ...@@ -114,12 +114,14 @@ message TableParameter {
optional TensorAccessorParameter tensor = 5; optional TensorAccessorParameter tensor = 5;
optional CommonAccessorParameter common = 6; optional CommonAccessorParameter common = 6;
optional TableType type = 7; optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ]; optional bool compress_in_save = 8 [ default = true ];
optional GraphParameter graph_parameter = 9; optional GraphParameter graph_parameter = 9;
// for cache model // for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ]; optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
optional bool enable_revert = 13 [ default = true ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
} }
message TableAccessorParameter { message TableAccessorParameter {
......
...@@ -75,7 +75,9 @@ void BindDistFleetWrapper(py::module* m) { ...@@ -75,7 +75,9 @@ void BindDistFleetWrapper(py::module* m) {
.def("client_flush", &FleetWrapper::ClientFlush) .def("client_flush", &FleetWrapper::ClientFlush)
.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold) .def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &FleetWrapper::CacheShuffle) .def("cache_shuffle", &FleetWrapper::CacheShuffle)
.def("save_cache", &FleetWrapper::SaveCache); .def("save_cache", &FleetWrapper::SaveCache)
.def("revert", &FleetWrapper::Revert)
.def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
} }
void BindPSHost(py::module* m) { void BindPSHost(py::module* m) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册