未验证 提交 f382eb06 编写于 作者: Z zhaocaibei123 提交者: GitHub

add save_cache/patch (#44420)

* add save_cache/patch

* add pybind

* remove pybind

* remove const_cast

* add fleet
上级 2792b8de
...@@ -482,7 +482,7 @@ std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id, ...@@ -482,7 +482,7 @@ std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
request_call_num, request_call_num,
[request_call_num, cmd_id, &cache_threshold](void *done) { [request_call_num, cmd_id, &cache_threshold](void *done) {
int ret = 0; int ret = 0;
auto *closure = (DownpourBrpcClosure *)done; auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
std::vector<double> cache_thresholds(request_call_num, 0); std::vector<double> cache_thresholds(request_call_num, 0);
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) { if (closure->check_response(i, cmd_id) != 0) {
...@@ -530,6 +530,14 @@ std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) { ...@@ -530,6 +530,14 @@ std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {}); return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
} }
std::future<int32_t> BrpcPsClient::Revert() {
return SendCmd(-1, PS_REVERT, {});
}
std::future<int32_t> BrpcPsClient::CheckSavePrePatchDone() {
return SendCmd(-1, PS_CHECK_SAVE_PRE_PATCH_DONE, {});
}
std::future<int32_t> BrpcPsClient::Flush() { std::future<int32_t> BrpcPsClient::Flush() {
VLOG(0) << "BrpcPsClient::flush begin"; VLOG(0) << "BrpcPsClient::flush begin";
_flushing = true; _flushing = true;
...@@ -1170,6 +1178,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values, ...@@ -1170,6 +1178,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
} }
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
closure->add_timer(timer); closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>(); auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise); closure->add_promise(promise);
......
...@@ -178,6 +178,9 @@ class BrpcPsClient : public PSClient { ...@@ -178,6 +178,9 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> Clear(uint32_t table_id) override; std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> Revert() override;
std::future<int32_t> CheckSavePrePatchDone() override;
std::future<int32_t> StopServer() override; std::future<int32_t> StopServer() override;
std::future<int32_t> StartProfiler() override; std::future<int32_t> StartProfiler() override;
...@@ -298,16 +301,16 @@ class BrpcPsClient : public PSClient { ...@@ -298,16 +301,16 @@ class BrpcPsClient : public PSClient {
int PushSparseAsyncShardMerge( int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, std::vector<int> &request_kv_num, // NOLINT
int table_id, int table_id,
int shard_idx, // NOLINT int shard_idx,
ValueAccessor *accessor); ValueAccessor *accessor);
int PushSparseAsyncShardPush( int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num, std::vector<int> &request_kv_num, // NOLINT
int table_id, int table_id,
int shard_idx, // NOLINT int shard_idx,
DownpourBrpcClosure *closure, DownpourBrpcClosure *closure,
ValueAccessor *accessor); ValueAccessor *accessor);
......
...@@ -146,7 +146,7 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg( ...@@ -146,7 +146,7 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
return fut; return fut;
} }
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) { auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourPServerBrpcClosure *)done; auto *closure = reinterpret_cast<DownpourPServerBrpcClosure *>(done);
int32_t ret = closure->check_response(0, msg_type + 1000); int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret); closure->set_promise_value(ret);
}); });
...@@ -209,13 +209,16 @@ int32_t BrpcPsService::Initialize() { ...@@ -209,13 +209,16 @@ int32_t BrpcPsService::Initialize() {
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler; _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep; _service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
// for save cache // for save cache
_service_handler_map[PS_SAVE_ONE_CACHE_TABLE] = _service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
&BrpcPsService::SaveCacheTable; &BrpcPsService::SaveCacheTable;
_service_handler_map[PS_GET_CACHE_THRESHOLD] = _service_handler_map[PS_GET_CACHE_THRESHOLD] =
&BrpcPsService::GetCacheThreshold; &BrpcPsService::GetCacheThreshold;
_service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle; _service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
_service_handler_map[PS_REVERT] = &BrpcPsService::Revert;
_service_handler_map[PS_CHECK_SAVE_PRE_PATCH_DONE] =
&BrpcPsService::CheckSavePrePatchDone;
auto &profiler = CostProfiler::instance(); auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense"); profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense"); profiler.register_profiler("pserver_server_push_dense");
...@@ -319,9 +322,8 @@ int32_t BrpcPsService::PullDense(Table *table, ...@@ -319,9 +322,8 @@ int32_t BrpcPsService::PullDense(Table *table,
table_context.pull_context.values = res_data->data(); table_context.pull_context.values = res_data->data();
table_context.num = num; table_context.num = num;
table->Pull(table_context); table->Pull(table_context);
// table->PullDense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
butil::return_object(res_data); butil::return_object(res_data);
...@@ -356,7 +358,6 @@ int32_t BrpcPsService::PushDenseParam(Table *table, ...@@ -356,7 +358,6 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
table_context.push_context.is_param = true; table_context.push_context.is_param = true;
table_context.num = num; table_context.num = num;
// if (table->PushDenseParam(values, num) != 0) {
if (table->Push(table_context) != 0) { if (table->Push(table_context) != 0) {
set_response_code(response, -1, "PushDenseParam failed"); set_response_code(response, -1, "PushDenseParam failed");
} }
...@@ -438,7 +439,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table, ...@@ -438,7 +439,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
"least 1 for num of sparse_key"); "least 1 for num of sparse_key");
return 0; return 0;
} }
uint32_t num = *(uint32_t *)(request.params(0).c_str()); const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/* /*
Push Content: Push Content:
|---keysData---|---valuesData---| |---keysData---|---valuesData---|
...@@ -484,10 +486,11 @@ int32_t BrpcPsService::PullGeoParam(Table *table, ...@@ -484,10 +486,11 @@ int32_t BrpcPsService::PullGeoParam(Table *table,
// table->PullGeoParam(trainer_id, &values, &ids); // table->PullGeoParam(trainer_id, &values, &ids);
uint32_t num = ids.size(); uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t)); cntl->response_attachment().append(reinterpret_cast<char *>(&num),
cntl->response_attachment().append((char *)ids.data(), sizeof(uint32_t));
cntl->response_attachment().append(reinterpret_cast<char *>(ids.data()),
ids.size() * sizeof(uint64_t)); ids.size() * sizeof(uint64_t));
cntl->response_attachment().append((char *)values.data(), cntl->response_attachment().append(reinterpret_cast<char *>(values.data()),
values.size() * sizeof(float)); values.size() * sizeof(float));
return 0; return 0;
} }
...@@ -517,7 +520,8 @@ int32_t BrpcPsService::PullSparse(Table *table, ...@@ -517,7 +520,8 @@ int32_t BrpcPsService::PullSparse(Table *table,
} }
CostTimer timer("pserver_server_pull_sparse"); CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim; auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
thread_local std::string req_buffer; thread_local std::string req_buffer;
...@@ -539,7 +543,7 @@ int32_t BrpcPsService::PullSparse(Table *table, ...@@ -539,7 +543,7 @@ int32_t BrpcPsService::PullSparse(Table *table,
table->Pull(table_context); table->Pull(table_context);
// table->PullSparse(res_data->data(), value); // table->PullSparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()), cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float)); res_data->size() * sizeof(float));
butil::return_object(res_data); butil::return_object(res_data);
return 0; return 0;
...@@ -565,7 +569,8 @@ int32_t BrpcPsService::PushSparse(Table *table, ...@@ -565,7 +569,8 @@ int32_t BrpcPsService::PushSparse(Table *table,
return 0; return 0;
} }
CostTimer timer("pserver_server_push_sparse"); CostTimer timer("pserver_server_push_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str()); const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/* /*
Push Content: Push Content:
|---keysData---|---valuesData---| |---keysData---|---valuesData---|
...@@ -767,6 +772,29 @@ int32_t BrpcPsService::GetCacheThreshold(Table *table, ...@@ -767,6 +772,29 @@ int32_t BrpcPsService::GetCacheThreshold(Table *table,
return 0; return 0;
} }
int32_t BrpcPsService::Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->Flush();
itr.second->Revert();
}
return 0;
}
int32_t BrpcPsService::CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->CheckSavePrePatchDone();
}
return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table, int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
......
...@@ -53,10 +53,10 @@ class BrpcPsServer : public PSServer { ...@@ -53,10 +53,10 @@ class BrpcPsServer : public PSServer {
} }
int32_t Port(); int32_t Port();
virtual int32_t StartS2S() override; int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg( ::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override; int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type, int32_t ReceiveFromPServer(int msg_type,
int pserver_id, int pserver_id,
const std::string &msg) override; const std::string &msg) override;
...@@ -75,14 +75,14 @@ class BrpcPsService; ...@@ -75,14 +75,14 @@ class BrpcPsService;
typedef int32_t (BrpcPsService::*serviceHandlerFunc)( typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
class BrpcPsService : public PsBaseService { class BrpcPsService : public PsBaseService {
public: public:
virtual int32_t Initialize() override; int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller, void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request, const PsRequestMessage *request,
PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
...@@ -91,100 +91,110 @@ class BrpcPsService : public PsBaseService { ...@@ -91,100 +91,110 @@ class BrpcPsService : public PsBaseService {
int32_t InitializeShardInfo(); int32_t InitializeShardInfo();
int32_t PullDense(Table *table, int32_t PullDense(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushDense(Table *table, int32_t PushDense(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushDenseParam(Table *table, int32_t PushDenseParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushSparseParam(Table *table, int32_t PushSparseParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PullSparse(Table *table, int32_t PullSparse(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PullGeoParam(Table *table, int32_t PullGeoParam(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t Barrier(Table *table, int32_t Barrier(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushSparse(Table *table, int32_t PushSparse(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t LoadOneTable(Table *table, int32_t LoadOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t LoadAllTable(Table *table, int32_t LoadAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t SaveOneTable(Table *table, int32_t SaveOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t SaveAllTable(Table *table, int32_t SaveAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ShrinkTable(Table *table, int32_t ShrinkTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ClearOneTable(Table *table, int32_t ClearOneTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t ClearAllTable(Table *table, int32_t ClearAllTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t StopServer(Table *table, int32_t StopServer(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t StartProfiler(Table *table, int32_t StartProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t StopProfiler(Table *table, int32_t StopProfiler(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PrintTableStat(Table *table, int32_t PrintTableStat(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table, int32_t PushGlobalStep(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t CacheShuffle(Table *table, int32_t CacheShuffle(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table, int32_t SaveCacheTable(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table, int32_t GetCacheThreshold(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl); brpc::Controller *cntl);
bool _is_initialize_shard_info; bool _is_initialize_shard_info;
...@@ -208,7 +218,7 @@ class DownpourPServerBrpcClosure : public PServerClosure { ...@@ -208,7 +218,7 @@ class DownpourPServerBrpcClosure : public PServerClosure {
} }
virtual ~DownpourPServerBrpcClosure() {} virtual ~DownpourPServerBrpcClosure() {}
virtual void Run() override { void Run() override {
if (_waiting_num.fetch_sub(1) == 1) { if (_waiting_num.fetch_sub(1) == 1) {
_callback(this); _callback(this);
delete this; delete this;
......
...@@ -67,12 +67,12 @@ class PSClient { ...@@ -67,12 +67,12 @@ class PSClient {
PSClient(PSClient &&) = delete; PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete; PSClient(const PSClient &) = delete;
virtual int32_t Configure( // NOLINT virtual int32_t Configure(
const PSParameter &config, const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>> const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions, &regions,
PSEnvironment &_env, PSEnvironment &_env, // NOLINT
size_t client_id) final; // NOLINT size_t client_id) final;
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms, virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms, int pserver_connect_timeout_ms,
...@@ -293,8 +293,25 @@ class PSClient { ...@@ -293,8 +293,25 @@ class PSClient {
return fut; return fut;
} }
virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id, virtual std::future<int32_t> GetCacheThreshold(
double &cache_threshold) { uint32_t table_id,
double &cache_threshold) { // NOLINT
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> Revert() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CheckSavePrePatchDone() {
VLOG(0) << "Did not implement"; VLOG(0) << "Did not implement";
std::promise<int32_t> promise; std::promise<int32_t> promise;
std::future<int> fut = promise.get_future(); std::future<int> fut = promise.get_future();
......
...@@ -65,6 +65,8 @@ enum PsCmdID { ...@@ -65,6 +65,8 @@ enum PsCmdID {
PS_SAVE_WITH_SHARD = 44; PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45; PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46; PS_QUERY_WITH_SHARD = 46;
PS_REVERT = 47;
PS_CHECK_SAVE_PRE_PATCH_DONE = 48;
// pserver2pserver cmd start from 100 // pserver2pserver cmd start from 100
PS_S2S_MSG = 101; PS_S2S_MSG = 101;
} }
......
...@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer { ...@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer {
} }
float* summary_decay_rate; float* summary_decay_rate;
double summary_decay_rate_d = 0.999999; double summary_decay_rate_d = 0.9999999;
float* param; float* param;
}; };
......
...@@ -339,34 +339,37 @@ int32_t MemoryDenseTable::Save(const std::string& path, ...@@ -339,34 +339,37 @@ int32_t MemoryDenseTable::Save(const std::string& path,
_value_accesor->Converter(save_param).deconverter; _value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false; bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param( std::vector<std::string> result_buffer_param;
param_dim_, std::vector<std::string>()); result_buffer_param.reserve(param_dim_);
std::vector<std::string> result_buffer_fixed_len;
result_buffer_fixed_len.reserve(fixed_len_params_dim_);
auto common = _config.common(); auto common = _config.common();
int size = static_cast<int>(common.params().size()); int size = static_cast<int>(common.params().size());
if (_config.common().name() == "summary") { if (_config.common().name() == "summary") {
for (int x = 0; x < param_dim_; ++x) { for (int x = 0; x < param_dim_; ++x) {
result_buffer_param[x].emplace_back( result_buffer_param.emplace_back(std::to_string(values_[param_idx_][x]));
std::to_string(values_[param_idx_][x]));
} }
} else if (_config.common().name() == "adam_d2sum") {
} else {
std::ostringstream os; std::ostringstream os;
for (int x = 0; x < size; ++x) { for (int y = 0; y < param_dim_; ++y) {
int dim = common.dims()[x];
VLOG(3) << "MemoryDenseTable::save dim " << x << " size: " << dim;
for (int y = 0; y < dim; ++y) {
os.clear(); os.clear();
os.str(""); os.str("");
os << values_[x][y]; os << values_[param_col_ids_[0]][y] << " 0";
if (dim == param_dim_) { for (int x = 2; x < param_col_ids_.size(); ++x) {
result_buffer_param[y].emplace_back(std::move(os.str())); os << " ";
} else { os << values_[param_col_ids_[x]][y];
result_buffer_fixed_len.emplace_back(std::move(os.str()));
} }
result_buffer_param.emplace_back(std::move(os.str()));
} }
} else {
std::ostringstream os;
for (int y = 0; y < param_dim_; ++y) {
os.clear();
os.str("");
os << values_[param_col_ids_[0]][y];
for (int x = 1; x < param_col_ids_.size(); ++x) {
os << " ";
os << values_[param_col_ids_[x]][y];
}
result_buffer_param.emplace_back(std::move(os.str()));
} }
} }
...@@ -379,12 +382,9 @@ int32_t MemoryDenseTable::Save(const std::string& path, ...@@ -379,12 +382,9 @@ int32_t MemoryDenseTable::Save(const std::string& path,
// 40M // 40M
auto write_channel = auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no); _afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto& t : result_buffer_param) { for (auto& t : result_buffer_param) {
if (_config.common().name() == "adam_d2sum") { if (0 != write_channel->write_line(t)) {
t.insert(t.begin() + 1, "0"); // avg_w
}
if (0 !=
write_channel->write_line(paddle::string::join_strings(t, ' '))) {
++retry_num; ++retry_num;
is_write_failed = true; is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed, retry it! " LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
...@@ -395,6 +395,7 @@ int32_t MemoryDenseTable::Save(const std::string& path, ...@@ -395,6 +395,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
} }
++feasign_size; ++feasign_size;
VLOG(3) << "save begin close " << channel_config.path;
write_channel->close(); write_channel->close();
if (err_no == -1) { if (err_no == -1) {
++retry_num; ++retry_num;
......
...@@ -65,17 +65,25 @@ class MemorySparseTable : public Table { ...@@ -65,17 +65,25 @@ class MemorySparseTable : public Table {
int32_t InitializeShard() override { return 0; } int32_t InitializeShard() override { return 0; }
int32_t InitializeValue(); int32_t InitializeValue();
virtual int32_t Load(const std::string& path, int32_t Load(const std::string& path, const std::string& param) override;
const std::string& param) override;
virtual int32_t Save(const std::string& path, int32_t Save(const std::string& path, const std::string& param) override;
const std::string& param) override;
int32_t LoadLocalFS(const std::string& path, const std::string& param); int32_t SaveCache(
int32_t SaveLocalFS(const std::string& path, const std::string& path,
const std::string& param, const std::string& param,
const std::string& prefix); paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) override;
virtual double GetCacheThreshold() { return _local_show_threshold; }
int64_t CacheShuffle(
const std::string& path,
const std::string& param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) override;
int64_t LocalSize(); int64_t LocalSize();
int64_t LocalMFSize(); int64_t LocalMFSize();
...@@ -89,20 +97,38 @@ class MemorySparseTable : public Table { ...@@ -89,20 +97,38 @@ class MemorySparseTable : public Table {
int32_t PushSparse(const uint64_t* keys, const float** values, size_t num); int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
int32_t Flush() override; int32_t Flush() override;
virtual int32_t Shrink(const std::string& param) override; int32_t Shrink(const std::string& param) override;
void Clear() override; void Clear() override;
void* GetShard(size_t shard_idx) override { void* GetShard(size_t shard_idx) override {
return &_local_shards[shard_idx]; return &_local_shards[shard_idx];
} }
virtual void Revert();
virtual void CheckSavePrePatchDone();
protected: protected:
virtual int32_t SavePatch(const std::string& path, int save_param);
virtual int32_t LoadPatch(const std::vector<std::string>& file_list,
int save_param);
const int _task_pool_size = 24; const int _task_pool_size = 24;
int _avg_local_shard_num; int _avg_local_shard_num;
int _real_local_shard_num; int _real_local_shard_num;
int _sparse_table_shard_num; int _sparse_table_shard_num;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool; std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::unique_ptr<shard_type[]> _local_shards; std::unique_ptr<shard_type[]> _local_shards;
// for patch model
int _m_avg_local_shard_num;
int _m_real_local_shard_num;
int _m_sparse_table_shard_num;
float _shard_merge_rate{1.0f};
double _local_show_threshold{0.0};
std::unique_ptr<shard_type[]> _local_shards_new;
std::unique_ptr<shard_type[]> _local_shards_patch_model;
std::thread _save_patch_model_thread;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -71,8 +71,8 @@ class Table { ...@@ -71,8 +71,8 @@ class Table {
virtual int32_t Initialize(const TableParameter &config, virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config); const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0; virtual int32_t Pull(TableContext &context) = 0; // NOLINT
virtual int32_t Push(TableContext &context) = 0; virtual int32_t Push(TableContext &context) = 0; // NOLINT
// only for barrier // only for barrier
virtual int32_t Barrier(const uint32_t trainer_id, virtual int32_t Barrier(const uint32_t trainer_id,
...@@ -125,7 +125,8 @@ class Table { ...@@ -125,7 +125,8 @@ class Table {
const std::string &param, const std::string &param,
double cache_threshold, double cache_threshold,
std::function<std::future<int32_t>( std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string &msg)> send_msg_func, int msg_type, int to_pserver_id, std::string &msg)> // NOLINT
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>> paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel, &shuffled_channel,
const std::vector<Table *> &table_ptrs) { const std::vector<Table *> &table_ptrs) {
...@@ -147,6 +148,10 @@ class Table { ...@@ -147,6 +148,10 @@ class Table {
virtual void *GetShard(size_t shard_idx) = 0; virtual void *GetShard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; } virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; }
// for patch model
virtual void Revert() {}
virtual void CheckSavePrePatchDone() {}
protected: protected:
virtual int32_t Initialize() = 0; virtual int32_t Initialize() = 0;
virtual int32_t InitializeAccessor(); virtual int32_t InitializeAccessor();
......
...@@ -853,6 +853,24 @@ int32_t FleetWrapper::SaveCache(int table_id, ...@@ -853,6 +853,24 @@ int32_t FleetWrapper::SaveCache(int table_id,
return feasign_cnt; return feasign_cnt;
} }
void FleetWrapper::Revert() {
auto ret = worker_ptr_->Revert();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
void FleetWrapper::CheckSavePrePatchDone() {
auto ret = worker_ptr_->CheckSavePrePatchDone();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() { std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t { struct engine_wrapper_t {
std::default_random_engine engine; std::default_random_engine engine;
......
...@@ -300,6 +300,8 @@ class FleetWrapper { ...@@ -300,6 +300,8 @@ class FleetWrapper {
const int mode, const int mode,
const double cache_threshold); const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode); int32_t SaveCache(int table_id, const std::string& path, const int mode);
void Revert();
void CheckSavePrePatchDone();
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_; static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_; static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
......
...@@ -150,9 +150,6 @@ TEST(MemorySparseTable, SGD) { ...@@ -150,9 +150,6 @@ TEST(MemorySparseTable, SGD) {
VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j]; VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j];
} }
} }
MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table);
ctr_table->SaveLocalFS("./work/table.save", "0", "test");
} }
} // namespace distributed } // namespace distributed
......
...@@ -114,12 +114,14 @@ message TableParameter { ...@@ -114,12 +114,14 @@ message TableParameter {
optional TensorAccessorParameter tensor = 5; optional TensorAccessorParameter tensor = 5;
optional CommonAccessorParameter common = 6; optional CommonAccessorParameter common = 6;
optional TableType type = 7; optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ]; optional bool compress_in_save = 8 [ default = true ];
optional GraphParameter graph_parameter = 9; optional GraphParameter graph_parameter = 9;
// for cache model // for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ]; optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ]; optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ]; optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
optional bool enable_revert = 13 [ default = true ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
} }
message TableAccessorParameter { message TableAccessorParameter {
......
...@@ -75,7 +75,9 @@ void BindDistFleetWrapper(py::module* m) { ...@@ -75,7 +75,9 @@ void BindDistFleetWrapper(py::module* m) {
.def("client_flush", &FleetWrapper::ClientFlush) .def("client_flush", &FleetWrapper::ClientFlush)
.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold) .def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &FleetWrapper::CacheShuffle) .def("cache_shuffle", &FleetWrapper::CacheShuffle)
.def("save_cache", &FleetWrapper::SaveCache); .def("save_cache", &FleetWrapper::SaveCache)
.def("revert", &FleetWrapper::Revert)
.def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
} }
void BindPSHost(py::module* m) { void BindPSHost(py::module* m) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册