未验证 提交 f382eb06 编写于 作者: Z zhaocaibei123 提交者: GitHub

add save_cache/patch (#44420)

* add save_cache/patch

* add pybind

* remove pybind

* remove const_cast

* add fleet
上级 2792b8de
......@@ -482,7 +482,7 @@ std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
request_call_num,
[request_call_num, cmd_id, &cache_threshold](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
std::vector<double> cache_thresholds(request_call_num, 0);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
......@@ -530,6 +530,14 @@ std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
}
std::future<int32_t> BrpcPsClient::Revert() {
return SendCmd(-1, PS_REVERT, {});
}
std::future<int32_t> BrpcPsClient::CheckSavePrePatchDone() {
return SendCmd(-1, PS_CHECK_SAVE_PRE_PATCH_DONE, {});
}
std::future<int32_t> BrpcPsClient::Flush() {
VLOG(0) << "BrpcPsClient::flush begin";
_flushing = true;
......@@ -1170,6 +1178,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
}
closure->set_promise_value(ret);
});
closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
......
......@@ -178,6 +178,9 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> Revert() override;
std::future<int32_t> CheckSavePrePatchDone() override;
std::future<int32_t> StopServer() override;
std::future<int32_t> StartProfiler() override;
......@@ -298,16 +301,16 @@ class BrpcPsClient : public PSClient {
int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
std::vector<int> &request_kv_num, // NOLINT
int table_id,
int shard_idx, // NOLINT
int shard_idx,
ValueAccessor *accessor);
int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
std::vector<int> &request_kv_num, // NOLINT
int table_id,
int shard_idx, // NOLINT
int shard_idx,
DownpourBrpcClosure *closure,
ValueAccessor *accessor);
......
......@@ -146,7 +146,7 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
return fut;
}
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourPServerBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourPServerBrpcClosure *>(done);
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
......@@ -209,13 +209,16 @@ int32_t BrpcPsService::Initialize() {
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
// for save cache
_service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
&BrpcPsService::SaveCacheTable;
_service_handler_map[PS_GET_CACHE_THRESHOLD] =
&BrpcPsService::GetCacheThreshold;
_service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
_service_handler_map[PS_REVERT] = &BrpcPsService::Revert;
_service_handler_map[PS_CHECK_SAVE_PRE_PATCH_DONE] =
&BrpcPsService::CheckSavePrePatchDone;
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense");
......@@ -319,9 +322,8 @@ int32_t BrpcPsService::PullDense(Table *table,
table_context.pull_context.values = res_data->data();
table_context.num = num;
table->Pull(table_context);
// table->PullDense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()),
cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float));
butil::return_object(res_data);
......@@ -356,7 +358,6 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
table_context.push_context.is_param = true;
table_context.num = num;
// if (table->PushDenseParam(values, num) != 0) {
if (table->Push(table_context) != 0) {
set_response_code(response, -1, "PushDenseParam failed");
}
......@@ -438,7 +439,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
"least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/*
Push Content:
|---keysData---|---valuesData---|
......@@ -484,10 +486,11 @@ int32_t BrpcPsService::PullGeoParam(Table *table,
// table->PullGeoParam(trainer_id, &values, &ids);
uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
cntl->response_attachment().append((char *)ids.data(),
cntl->response_attachment().append(reinterpret_cast<char *>(&num),
sizeof(uint32_t));
cntl->response_attachment().append(reinterpret_cast<char *>(ids.data()),
ids.size() * sizeof(uint64_t));
cntl->response_attachment().append((char *)values.data(),
cntl->response_attachment().append(reinterpret_cast<char *>(values.data()),
values.size() * sizeof(float));
return 0;
}
......@@ -517,7 +520,8 @@ int32_t BrpcPsService::PullSparse(Table *table,
}
CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str());
const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
thread_local std::string req_buffer;
......@@ -539,7 +543,7 @@ int32_t BrpcPsService::PullSparse(Table *table,
table->Pull(table_context);
// table->PullSparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()),
cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float));
butil::return_object(res_data);
return 0;
......@@ -565,7 +569,8 @@ int32_t BrpcPsService::PushSparse(Table *table,
return 0;
}
CostTimer timer("pserver_server_push_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str());
const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/*
Push Content:
|---keysData---|---valuesData---|
......@@ -767,6 +772,29 @@ int32_t BrpcPsService::GetCacheThreshold(Table *table,
return 0;
}
int32_t BrpcPsService::Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->Flush();
itr.second->Revert();
}
return 0;
}
int32_t BrpcPsService::CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->CheckSavePrePatchDone();
}
return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
......
......@@ -53,12 +53,12 @@ class BrpcPsServer : public PSServer {
}
int32_t Port();
virtual int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int32_t StartS2S() override;
::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) override;
int32_t ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) override;
private:
virtual int32_t Initialize();
......@@ -75,118 +75,128 @@ class BrpcPsService;
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
class BrpcPsService : public PsBaseService {
public:
virtual int32_t Initialize() override;
int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
private:
int32_t InitializeShardInfo();
int32_t PullDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushDenseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushSparseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PullSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PullGeoParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t Barrier(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t LoadOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t LoadAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t SaveOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t SaveAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t ClearOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t ClearAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t StartProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t StopProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PrintTableStat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t CacheShuffle(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
......@@ -208,7 +218,7 @@ class DownpourPServerBrpcClosure : public PServerClosure {
}
virtual ~DownpourPServerBrpcClosure() {}
virtual void Run() override {
void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
......
......@@ -67,12 +67,12 @@ class PSClient {
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t Configure( // NOLINT
virtual int32_t Configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env,
size_t client_id) final; // NOLINT
PSEnvironment &_env, // NOLINT
size_t client_id) final;
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
......@@ -293,8 +293,25 @@ class PSClient {
return fut;
}
virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) {
virtual std::future<int32_t> GetCacheThreshold(
uint32_t table_id,
double &cache_threshold) { // NOLINT
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> Revert() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CheckSavePrePatchDone() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
......
......@@ -65,6 +65,8 @@ enum PsCmdID {
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
PS_REVERT = 47;
PS_CHECK_SAVE_PRE_PATCH_DONE = 48;
// pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
}
......
......@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer {
}
float* summary_decay_rate;
double summary_decay_rate_d = 0.999999;
double summary_decay_rate_d = 0.9999999;
float* param;
};
......
......@@ -339,34 +339,37 @@ int32_t MemoryDenseTable::Save(const std::string& path,
_value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param(
param_dim_, std::vector<std::string>());
std::vector<std::string> result_buffer_fixed_len;
result_buffer_fixed_len.reserve(fixed_len_params_dim_);
std::vector<std::string> result_buffer_param;
result_buffer_param.reserve(param_dim_);
auto common = _config.common();
int size = static_cast<int>(common.params().size());
if (_config.common().name() == "summary") {
for (int x = 0; x < param_dim_; ++x) {
result_buffer_param[x].emplace_back(
std::to_string(values_[param_idx_][x]));
result_buffer_param.emplace_back(std::to_string(values_[param_idx_][x]));
}
} else if (_config.common().name() == "adam_d2sum") {
std::ostringstream os;
for (int y = 0; y < param_dim_; ++y) {
os.clear();
os.str("");
os << values_[param_col_ids_[0]][y] << " 0";
for (int x = 2; x < param_col_ids_.size(); ++x) {
os << " ";
os << values_[param_col_ids_[x]][y];
}
result_buffer_param.emplace_back(std::move(os.str()));
}
} else {
std::ostringstream os;
for (int x = 0; x < size; ++x) {
int dim = common.dims()[x];
VLOG(3) << "MemoryDenseTable::save dim " << x << " size: " << dim;
for (int y = 0; y < dim; ++y) {
os.clear();
os.str("");
os << values_[x][y];
if (dim == param_dim_) {
result_buffer_param[y].emplace_back(std::move(os.str()));
} else {
result_buffer_fixed_len.emplace_back(std::move(os.str()));
}
for (int y = 0; y < param_dim_; ++y) {
os.clear();
os.str("");
os << values_[param_col_ids_[0]][y];
for (int x = 1; x < param_col_ids_.size(); ++x) {
os << " ";
os << values_[param_col_ids_[x]][y];
}
result_buffer_param.emplace_back(std::move(os.str()));
}
}
......@@ -379,12 +382,9 @@ int32_t MemoryDenseTable::Save(const std::string& path,
// 40M
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto& t : result_buffer_param) {
if (_config.common().name() == "adam_d2sum") {
t.insert(t.begin() + 1, "0"); // avg_w
}
if (0 !=
write_channel->write_line(paddle::string::join_strings(t, ' '))) {
if (0 != write_channel->write_line(t)) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
......@@ -395,6 +395,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
}
++feasign_size;
VLOG(3) << "save begin close " << channel_config.path;
write_channel->close();
if (err_no == -1) {
++retry_num;
......
......@@ -65,17 +65,25 @@ class MemorySparseTable : public Table {
int32_t InitializeShard() override { return 0; }
int32_t InitializeValue();
virtual int32_t Load(const std::string& path,
const std::string& param) override;
virtual int32_t Save(const std::string& path,
const std::string& param) override;
int32_t LoadLocalFS(const std::string& path, const std::string& param);
int32_t SaveLocalFS(const std::string& path,
const std::string& param,
const std::string& prefix);
int32_t Load(const std::string& path, const std::string& param) override;
int32_t Save(const std::string& path, const std::string& param) override;
int32_t SaveCache(
const std::string& path,
const std::string& param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) override;
virtual double GetCacheThreshold() { return _local_show_threshold; }
int64_t CacheShuffle(
const std::string& path,
const std::string& param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) override;
int64_t LocalSize();
int64_t LocalMFSize();
......@@ -89,20 +97,38 @@ class MemorySparseTable : public Table {
int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
int32_t Flush() override;
virtual int32_t Shrink(const std::string& param) override;
int32_t Shrink(const std::string& param) override;
void Clear() override;
void* GetShard(size_t shard_idx) override {
return &_local_shards[shard_idx];
}
virtual void Revert();
virtual void CheckSavePrePatchDone();
protected:
virtual int32_t SavePatch(const std::string& path, int save_param);
virtual int32_t LoadPatch(const std::vector<std::string>& file_list,
int save_param);
const int _task_pool_size = 24;
int _avg_local_shard_num;
int _real_local_shard_num;
int _sparse_table_shard_num;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::unique_ptr<shard_type[]> _local_shards;
// for patch model
int _m_avg_local_shard_num;
int _m_real_local_shard_num;
int _m_sparse_table_shard_num;
float _shard_merge_rate{1.0f};
double _local_show_threshold{0.0};
std::unique_ptr<shard_type[]> _local_shards_new;
std::unique_ptr<shard_type[]> _local_shards_patch_model;
std::thread _save_patch_model_thread;
};
} // namespace distributed
......
......@@ -71,8 +71,8 @@ class Table {
virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0;
virtual int32_t Pull(TableContext &context) = 0; // NOLINT
virtual int32_t Push(TableContext &context) = 0; // NOLINT
// only for barrier
virtual int32_t Barrier(const uint32_t trainer_id,
......@@ -125,7 +125,8 @@ class Table {
const std::string &param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string &msg)> send_msg_func,
int msg_type, int to_pserver_id, std::string &msg)> // NOLINT
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel,
const std::vector<Table *> &table_ptrs) {
......@@ -147,6 +148,10 @@ class Table {
virtual void *GetShard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; }
// for patch model
virtual void Revert() {}
virtual void CheckSavePrePatchDone() {}
protected:
virtual int32_t Initialize() = 0;
virtual int32_t InitializeAccessor();
......
......@@ -853,6 +853,24 @@ int32_t FleetWrapper::SaveCache(int table_id,
return feasign_cnt;
}
void FleetWrapper::Revert() {
auto ret = worker_ptr_->Revert();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
void FleetWrapper::CheckSavePrePatchDone() {
auto ret = worker_ptr_->CheckSavePrePatchDone();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
......
......@@ -300,6 +300,8 @@ class FleetWrapper {
const int mode,
const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode);
void Revert();
void CheckSavePrePatchDone();
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
......
......@@ -150,9 +150,6 @@ TEST(MemorySparseTable, SGD) {
VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j];
}
}
MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table);
ctr_table->SaveLocalFS("./work/table.save", "0", "test");
}
} // namespace distributed
......
......@@ -114,12 +114,14 @@ message TableParameter {
optional TensorAccessorParameter tensor = 5;
optional CommonAccessorParameter common = 6;
optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ];
optional bool compress_in_save = 8 [ default = true ];
optional GraphParameter graph_parameter = 9;
// for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
optional bool enable_revert = 13 [ default = true ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
}
message TableAccessorParameter {
......
......@@ -75,7 +75,9 @@ void BindDistFleetWrapper(py::module* m) {
.def("client_flush", &FleetWrapper::ClientFlush)
.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &FleetWrapper::CacheShuffle)
.def("save_cache", &FleetWrapper::SaveCache);
.def("save_cache", &FleetWrapper::SaveCache)
.def("revert", &FleetWrapper::Revert)
.def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
}
void BindPSHost(py::module* m) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册