Unverified commit f382eb06 authored by zhaocaibei123, committed by GitHub

add save_cache/patch (#44420)

* add save_cache/patch

* add pybind

* remove pybind

* remove const_cast

* add fleet
Parent 2792b8de
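The hunks below wire a new patch-model pipeline through the stack: `MemorySparseTable` gains `SavePatch`/`LoadPatch`/`Revert`/`CheckSavePrePatchDone`, the brpc client and server gain matching `PS_REVERT`/`PS_CHECK_SAVE_PRE_PATCH_DONE` commands, and `FleetWrapper` plus the pybind layer expose them upward. A minimal sketch (not code from this PR) of how the new entry points compose, assuming an initialized wrapper:

```cpp
// Hedged sketch: finish one patch cycle with the new FleetWrapper methods.
// `fleet` is assumed to be a configured wrapper; the patch save itself is
// triggered elsewhere by saving a table with mode 5 (see Save below).
void FinishPatchCycle(paddle::distributed::FleetWrapper& fleet) {
  fleet.CheckSavePrePatchDone();  // join an in-flight patch-save thread
  fleet.Revert();                 // discard deltas not yet snapshotted
}
```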
......@@ -482,7 +482,7 @@ std::future<int32_t> BrpcPsClient::GetCacheThreshold(uint32_t table_id,
request_call_num,
[request_call_num, cmd_id, &cache_threshold](void *done) {
int ret = 0;
auto *closure = (DownpourBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourBrpcClosure *>(done);
std::vector<double> cache_thresholds(request_call_num, 0);
for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) {
......@@ -530,6 +530,14 @@ std::future<int32_t> BrpcPsClient::Clear(uint32_t table_id) {
return SendCmd(table_id, PS_CLEAR_ONE_TABLE, {});
}
std::future<int32_t> BrpcPsClient::Revert() {
return SendCmd(-1, PS_REVERT, {});
}
std::future<int32_t> BrpcPsClient::CheckSavePrePatchDone() {
return SendCmd(-1, PS_CHECK_SAVE_PRE_PATCH_DONE, {});
}
std::future<int32_t> BrpcPsClient::Flush() {
VLOG(0) << "BrpcPsClient::flush begin";
_flushing = true;
......@@ -1170,6 +1178,7 @@ std::future<int32_t> BrpcPsClient::PullSparseParam(float **select_values,
}
closure->set_promise_value(ret);
});
closure->add_timer(timer);
auto promise = std::make_shared<std::promise<int32_t>>();
closure->add_promise(promise);
......
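The new client calls above are thin `SendCmd` broadcasts with a table-agnostic id of -1; as the server handlers later in this diff show, `BrpcPsService::Revert` and `CheckSavePrePatchDone` iterate every table themselves. A hedged usage sketch, assuming `client` is a connected `BrpcPsClient`:

```cpp
std::future<int32_t> fut = client.Revert();  // broadcasts PS_REVERT
if (fut.get() != 0) {
  LOG(ERROR) << "revert failed on at least one pserver";
}
client.CheckSavePrePatchDone().wait();  // block until patch saves finish
```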
......@@ -178,6 +178,9 @@ class BrpcPsClient : public PSClient {
std::future<int32_t> Clear(uint32_t table_id) override;
std::future<int32_t> Revert() override;
std::future<int32_t> CheckSavePrePatchDone() override;
std::future<int32_t> StopServer() override;
std::future<int32_t> StartProfiler() override;
......@@ -298,16 +301,16 @@ class BrpcPsClient : public PSClient {
int PushSparseAsyncShardMerge(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
std::vector<int> &request_kv_num, // NOLINT
int table_id,
int shard_idx, // NOLINT
int shard_idx,
ValueAccessor *accessor);
int PushSparseAsyncShardPush(
std::vector<std::shared_ptr<SparseAsyncTask>> &task_list, // NOLINT
std::vector<int> &request_kv_num,
std::vector<int> &request_kv_num, // NOLINT
int table_id,
int shard_idx, // NOLINT
int shard_idx,
DownpourBrpcClosure *closure,
ValueAccessor *accessor);
......
......@@ -146,7 +146,7 @@ std::future<int32_t> BrpcPsServer::SendPServer2PServerMsg(
return fut;
}
auto *closure = new DownpourPServerBrpcClosure(1, [msg_type](void *done) {
auto *closure = (DownpourPServerBrpcClosure *)done;
auto *closure = reinterpret_cast<DownpourPServerBrpcClosure *>(done);
int32_t ret = closure->check_response(0, msg_type + 1000);
closure->set_promise_value(ret);
});
......@@ -209,13 +209,16 @@ int32_t BrpcPsService::Initialize() {
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::StopProfiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &BrpcPsService::PushGlobalStep;
// for save cache
_service_handler_map[PS_SAVE_ONE_CACHE_TABLE] =
&BrpcPsService::SaveCacheTable;
_service_handler_map[PS_GET_CACHE_THRESHOLD] =
&BrpcPsService::GetCacheThreshold;
_service_handler_map[PS_CACHE_SHUFFLE] = &BrpcPsService::CacheShuffle;
_service_handler_map[PS_REVERT] = &BrpcPsService::Revert;
_service_handler_map[PS_CHECK_SAVE_PRE_PATCH_DONE] =
&BrpcPsService::CheckSavePrePatchDone;
auto &profiler = CostProfiler::instance();
profiler.register_profiler("pserver_server_pull_dense");
profiler.register_profiler("pserver_server_push_dense");
......@@ -319,9 +322,8 @@ int32_t BrpcPsService::PullDense(Table *table,
table_context.pull_context.values = res_data->data();
table_context.num = num;
table->Pull(table_context);
// table->PullDense(res_data->data(), num);
cntl->response_attachment().append((char *)(res_data->data()),
cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float));
butil::return_object(res_data);
......@@ -356,7 +358,6 @@ int32_t BrpcPsService::PushDenseParam(Table *table,
table_context.push_context.is_param = true;
table_context.num = num;
// if (table->PushDenseParam(values, num) != 0) {
if (table->Push(table_context) != 0) {
set_response_code(response, -1, "PushDenseParam failed");
}
......@@ -438,7 +439,8 @@ int32_t BrpcPsService::PushSparseParam(Table *table,
"least 1 for num of sparse_key");
return 0;
}
uint32_t num = *(uint32_t *)(request.params(0).c_str());
const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/*
Push Content:
|---keysData---|---valuesData---|
......@@ -484,10 +486,11 @@ int32_t BrpcPsService::PullGeoParam(Table *table,
// table->PullGeoParam(trainer_id, &values, &ids);
uint32_t num = ids.size();
cntl->response_attachment().append((char *)(&num), sizeof(uint32_t));
cntl->response_attachment().append((char *)ids.data(),
cntl->response_attachment().append(reinterpret_cast<char *>(&num),
sizeof(uint32_t));
cntl->response_attachment().append(reinterpret_cast<char *>(ids.data()),
ids.size() * sizeof(uint64_t));
cntl->response_attachment().append((char *)values.data(),
cntl->response_attachment().append(reinterpret_cast<char *>(values.data()),
values.size() * sizeof(float));
return 0;
}
......@@ -517,7 +520,8 @@ int32_t BrpcPsService::PullSparse(Table *table,
}
CostTimer timer("pserver_server_pull_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str());
const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
auto dim = table->ValueAccesor()->GetAccessorInfo().select_dim;
thread_local std::string req_buffer;
......@@ -539,7 +543,7 @@ int32_t BrpcPsService::PullSparse(Table *table,
table->Pull(table_context);
// table->PullSparse(res_data->data(), value);
cntl->response_attachment().append((char *)(res_data->data()),
cntl->response_attachment().append(reinterpret_cast<char *>(res_data->data()),
res_data->size() * sizeof(float));
butil::return_object(res_data);
return 0;
......@@ -565,7 +569,8 @@ int32_t BrpcPsService::PushSparse(Table *table,
return 0;
}
CostTimer timer("pserver_server_push_sparse");
uint32_t num = *(uint32_t *)(request.params(0).c_str());
const uint32_t num =
*(reinterpret_cast<const uint32_t *>(request.params(0).c_str()));
/*
Push Content:
|---keysData---|---valuesData---|
......@@ -767,6 +772,29 @@ int32_t BrpcPsService::GetCacheThreshold(Table *table,
return 0;
}
int32_t BrpcPsService::Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->Flush();
itr.second->Revert();
}
return 0;
}
int32_t BrpcPsService::CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->GetTable());
for (auto &itr : table_map) {
itr.second->CheckSavePrePatchDone();
}
return 0;
}
int32_t BrpcPsService::ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
......
......@@ -53,12 +53,12 @@ class BrpcPsServer : public PSServer {
}
int32_t Port();
virtual int32_t StartS2S() override;
virtual ::std::future<int32_t> SendPServer2PServerMsg(
int32_t StartS2S() override;
::std::future<int32_t> SendPServer2PServerMsg(
int msg_type, int to_pserver_id, const std::string &msg) override;
virtual int32_t ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) override;
int32_t ReceiveFromPServer(int msg_type,
int pserver_id,
const std::string &msg) override;
private:
virtual int32_t Initialize();
......@@ -75,118 +75,128 @@ class BrpcPsService;
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
class BrpcPsService : public PsBaseService {
public:
virtual int32_t Initialize() override;
int32_t Initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
void service(::google::protobuf::RpcController *controller,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
private:
int32_t InitializeShardInfo();
int32_t PullDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushDense(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushDenseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushSparseParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PullSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PullGeoParam(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t Barrier(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushSparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t LoadOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t LoadAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t SaveOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t SaveAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t ShrinkTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t ClearOneTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t ClearAllTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t StopServer(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t StartProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t StopProfiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PrintTableStat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t PushGlobalStep(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t CacheShuffle(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t SaveCacheTable(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t GetCacheThreshold(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t Revert(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
int32_t CheckSavePrePatchDone(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response, // NOLINT
brpc::Controller *cntl);
bool _is_initialize_shard_info;
std::mutex _initialize_shard_mutex;
std::unordered_map<int32_t, serviceHandlerFunc> _service_handler_map;
......@@ -208,7 +218,7 @@ class DownpourPServerBrpcClosure : public PServerClosure {
}
virtual ~DownpourPServerBrpcClosure() {}
virtual void Run() override {
void Run() override {
if (_waiting_num.fetch_sub(1) == 1) {
_callback(this);
delete this;
......
......@@ -67,12 +67,12 @@ class PSClient {
PSClient(PSClient &&) = delete;
PSClient(const PSClient &) = delete;
virtual int32_t Configure( // NOLINT
virtual int32_t Configure(
const PSParameter &config,
const std::map<uint64_t, std::vector<paddle::distributed::Region>>
&regions,
PSEnvironment &_env,
size_t client_id) final; // NOLINT
PSEnvironment &_env, // NOLINT
size_t client_id) final;
virtual int32_t CreateClient2ClientConnection(int pserver_timeout_ms,
int pserver_connect_timeout_ms,
......@@ -293,8 +293,25 @@ class PSClient {
return fut;
}
virtual std::future<int32_t> GetCacheThreshold(uint32_t table_id,
double &cache_threshold) {
virtual std::future<int32_t> GetCacheThreshold(
uint32_t table_id,
double &cache_threshold) { // NOLINT
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> Revert() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
promise.set_value(-1);
return fut;
}
virtual std::future<int32_t> CheckSavePrePatchDone() {
VLOG(0) << "Did not implement";
std::promise<int32_t> promise;
std::future<int> fut = promise.get_future();
......
......@@ -65,6 +65,8 @@ enum PsCmdID {
PS_SAVE_WITH_SHARD = 44;
PS_QUERY_WITH_SCOPE = 45;
PS_QUERY_WITH_SHARD = 46;
PS_REVERT = 47;
PS_CHECK_SAVE_PRE_PATCH_DONE = 48;
// pserver2pserver cmd start from 100
PS_S2S_MSG = 101;
}
......
......@@ -299,7 +299,7 @@ class DSummary : public DenseOptimizer {
}
float* summary_decay_rate;
double summary_decay_rate_d = 0.999999;
double summary_decay_rate_d = 0.9999999;
float* param;
};
......
......@@ -339,34 +339,37 @@ int32_t MemoryDenseTable::Save(const std::string& path,
_value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false;
std::vector<std::vector<std::string>> result_buffer_param(
param_dim_, std::vector<std::string>());
std::vector<std::string> result_buffer_fixed_len;
result_buffer_fixed_len.reserve(fixed_len_params_dim_);
std::vector<std::string> result_buffer_param;
result_buffer_param.reserve(param_dim_);
auto common = _config.common();
int size = static_cast<int>(common.params().size());
if (_config.common().name() == "summary") {
for (int x = 0; x < param_dim_; ++x) {
result_buffer_param[x].emplace_back(
std::to_string(values_[param_idx_][x]));
result_buffer_param.emplace_back(std::to_string(values_[param_idx_][x]));
}
} else if (_config.common().name() == "adam_d2sum") {
std::ostringstream os;
for (int y = 0; y < param_dim_; ++y) {
os.clear();
os.str("");
os << values_[param_col_ids_[0]][y] << " 0";
for (int x = 2; x < param_col_ids_.size(); ++x) {
os << " ";
os << values_[param_col_ids_[x]][y];
}
result_buffer_param.emplace_back(std::move(os.str()));
}
} else {
std::ostringstream os;
for (int x = 0; x < size; ++x) {
int dim = common.dims()[x];
VLOG(3) << "MemoryDenseTable::save dim " << x << " size: " << dim;
for (int y = 0; y < dim; ++y) {
os.clear();
os.str("");
os << values_[x][y];
if (dim == param_dim_) {
result_buffer_param[y].emplace_back(std::move(os.str()));
} else {
result_buffer_fixed_len.emplace_back(std::move(os.str()));
}
for (int y = 0; y < param_dim_; ++y) {
os.clear();
os.str("");
os << values_[param_col_ids_[0]][y];
for (int x = 1; x < param_col_ids_.size(); ++x) {
os << " ";
os << values_[param_col_ids_[x]][y];
}
result_buffer_param.emplace_back(std::move(os.str()));
}
}
......@@ -379,12 +382,9 @@ int32_t MemoryDenseTable::Save(const std::string& path,
// 40M
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto& t : result_buffer_param) {
if (_config.common().name() == "adam_d2sum") {
t.insert(t.begin() + 1, "0"); // avg_w
}
if (0 !=
write_channel->write_line(paddle::string::join_strings(t, ' '))) {
if (0 != write_channel->write_line(t)) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "DownpourDenseTable save failed, retry it! "
......@@ -395,6 +395,7 @@ int32_t MemoryDenseTable::Save(const std::string& path,
}
++feasign_size;
VLOG(3) << "save begin close " << channel_config.path;
write_channel->close();
if (err_no == -1) {
++retry_num;
......
......@@ -12,15 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include <omp.h>
#include <sstream>
#include "glog/logging.h"
#include "paddle/fluid/distributed/common/cost_timer.h"
#include "paddle/fluid/distributed/common/local_random.h"
#include "paddle/fluid/distributed/common/topk_calculator.h"
#include "paddle/fluid/distributed/ps/table/memory_sparse_table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/framework/io/fs.h"
// #include "boost/lexical_cast.hpp"
#include "paddle/fluid/platform/enforce.h"
DEFINE_bool(pserver_print_missed_key_num_every_push,
......@@ -68,6 +71,30 @@ int32_t MemorySparseTable::InitializeValue() {
_local_shards.reset(new shard_type[_real_local_shard_num]);
if (_config.enable_revert()) {
// calculate merged shard number based on config param;
_shard_merge_rate = _config.has_shard_merge_rate()
? _config.shard_merge_rate()
: _shard_merge_rate;
CHECK((_m_avg_local_shard_num = static_cast<int>(
std::ceil(_avg_local_shard_num * _shard_merge_rate)),
_m_avg_local_shard_num <= _avg_local_shard_num));
CHECK((_m_real_local_shard_num = static_cast<int>(
std::ceil(_real_local_shard_num * _shard_merge_rate)),
_m_real_local_shard_num <= _real_local_shard_num));
uint32_t avg_shard_server_num =
_sparse_table_shard_num / _avg_local_shard_num;
uint32_t last_server_shard_num =
_sparse_table_shard_num - avg_shard_server_num * _avg_local_shard_num;
_m_sparse_table_shard_num =
avg_shard_server_num * _m_avg_local_shard_num +
std::ceil(last_server_shard_num * _shard_merge_rate);
LOG(INFO) << "merged shard info: [" << _m_sparse_table_shard_num << "|"
<< _m_avg_local_shard_num << "|" << _m_real_local_shard_num
<< "]";
_local_shards_new.reset(new shard_type[_real_local_shard_num]);
}
return 0;
}
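A worked example of the merged-shard arithmetic above, with hypothetical numbers (1000 global shards, 100 per server, `shard_merge_rate` 0.5):

```cpp
#include <cmath>
#include <iostream>

int main() {
  const int sparse_table_shard_num = 1000;  // assumed global shard count
  const int avg_local_shard_num = 100;      // assumed shards per server
  const float shard_merge_rate = 0.5f;

  const int m_avg_local = static_cast<int>(
      std::ceil(avg_local_shard_num * shard_merge_rate));               // 50
  const int server_num = sparse_table_shard_num / avg_local_shard_num;  // 10
  const int last_server_shards =
      sparse_table_shard_num - server_num * avg_local_shard_num;        // 0
  const int m_total =
      server_num * m_avg_local +
      static_cast<int>(std::ceil(last_server_shards * shard_merge_rate));  // 500
  std::cout << "merged shard info: [" << m_total << "|" << m_avg_local << "]\n";
  return 0;
}
```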
......@@ -93,8 +120,16 @@ int32_t MemorySparseTable::Load(const std::string& path,
return -1;
}
if (load_param == 5) {
return LoadPatch(file_list, load_param);
}
size_t file_start_idx = _shard_idx * _avg_local_shard_num;
if (file_start_idx >= file_list.size()) {
return 0;
}
size_t feature_value_size =
_value_accesor->GetAccessorInfo().size / sizeof(float);
......@@ -161,30 +196,37 @@ int32_t MemorySparseTable::Load(const std::string& path,
return 0;
}
int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
const std::string& param) {
std::string table_path = TableDir(path);
auto file_list = paddle::framework::localfs_list(table_path);
size_t expect_shard_num = _sparse_table_shard_num;
if (file_list.size() != expect_shard_num) {
LOG(WARNING) << "MemorySparseTable file_size:" << file_list.size()
<< " not equal to expect_shard_num:" << expect_shard_num;
return -1;
}
if (file_list.size() == 0) {
LOG(WARNING) << "MemorySparseTable load file is empty, path:" << path;
return -1;
int32_t MemorySparseTable::LoadPatch(const std::vector<std::string>& file_list,
int load_param) {
if (!_config.enable_revert()) {
LOG(INFO) << "MemorySparseTable should be enabled revert.";
return 0;
}
// index range of the merged (patch) shard files
int start_idx = _shard_idx * _m_avg_local_shard_num;
int end_idx = start_idx + _m_real_local_shard_num;
// index range of the original local shards
int o_start_idx = _shard_idx * _avg_local_shard_num;
int o_end_idx = o_start_idx + _real_local_shard_num;
size_t file_start_idx = _shard_idx * _avg_local_shard_num;
if (start_idx >= file_list.size()) {
return 0;
}
size_t feature_value_size =
_value_accesor->GetAccessorInfo().size / sizeof(float);
end_idx =
end_idx < _m_sparse_table_shard_num ? end_idx : _m_sparse_table_shard_num;
int thread_num = (end_idx - start_idx) < 15 ? (end_idx - start_idx) : 15;
int thread_num = _real_local_shard_num < 15 ? _real_local_shard_num : 15;
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < _real_local_shard_num; ++i) {
for (size_t i = start_idx; i < end_idx; ++i) {
FsChannelConfig channel_config;
channel_config.path = file_list[i];
channel_config.converter = _value_accesor->Converter(load_param).converter;
channel_config.deconverter =
_value_accesor->Converter(load_param).deconverter;
bool is_read_failed = false;
int retry_num = 0;
int err_no = 0;
......@@ -192,31 +234,55 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
is_read_failed = false;
err_no = 0;
std::string line_data;
std::ifstream file(file_list[file_start_idx + i]);
auto read_channel = _afs_client.open_r(channel_config, 0, &err_no);
char* end = NULL;
auto& shard = _local_shards[i];
int m_local_shard_id = i % _m_avg_local_shard_num;
std::unordered_set<size_t> global_shard_idx;
std::string global_shard_idx_str;
for (size_t j = o_start_idx; j < o_end_idx; ++j) {
if ((j % _avg_local_shard_num) % _m_real_local_shard_num ==
m_local_shard_id) {
global_shard_idx.insert(j);
global_shard_idx_str.append(std::to_string(j)).append(",");
}
}
try {
while (std::getline(file, line_data) && line_data.size() > 1) {
while (read_channel->read_line(line_data) == 0 &&
line_data.size() > 1) {
uint64_t key = std::strtoul(line_data.data(), &end, 10);
auto index_iter =
global_shard_idx.find(key % _sparse_table_shard_num);
if (index_iter == global_shard_idx.end()) {
LOG(WARNING) << "MemorySparseTable key:" << key
<< " not match shard,"
<< " file_idx:" << i
<< " global_shard_idx:" << global_shard_idx_str
<< " shard num:" << _sparse_table_shard_num
<< " file:" << channel_config.path;
continue;
}
size_t local_shard_idx = *index_iter % _avg_local_shard_num;
auto& shard = _local_shards[local_shard_idx];
auto& value = shard[key];
value.resize(feature_value_size);
int parse_size = _value_accesor->ParseFromString(++end, value.data());
value.resize(parse_size);
}
file.close();
read_channel->close();
if (err_no == -1) {
++retry_num;
is_read_failed = true;
LOG(ERROR)
<< "MemorySparseTable load failed after read, retry it! path:"
<< file_list[file_start_idx + i] << " , retry_num=" << retry_num;
<< channel_config.path << " , retry_num=" << retry_num;
}
} catch (...) {
++retry_num;
is_read_failed = true;
LOG(ERROR) << "MemorySparseTable load failed, retry it! path:"
<< file_list[file_start_idx + i]
<< " , retry_num=" << retry_num;
<< channel_config.path << " , retry_num=" << retry_num;
}
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable load failed reach max limit!";
......@@ -225,16 +291,44 @@ int32_t MemorySparseTable::LoadLocalFS(const std::string& path,
} while (is_read_failed);
}
LOG(INFO) << "MemorySparseTable load success, path from "
<< file_list[file_start_idx] << " to "
<< file_list[file_start_idx + _real_local_shard_num - 1];
<< file_list[start_idx] << " to " << file_list[end_idx - 1];
return 0;
}
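The index gymnastics in `LoadPatch` reduce to one mapping: merged file `i` holds keys from every original global shard whose local index folds onto `i`, and each key is routed back by its global shard id. A hedged standalone restatement (parameter names are illustrative):

```cpp
#include <cstddef>
#include <unordered_set>

// Which original global shard ids does merged patch file `i` cover?
std::unordered_set<size_t> OriginalShardsForPatchFile(
    size_t i, size_t shard_idx, size_t avg_local, size_t real_local,
    size_t m_avg_local, size_t m_real_local) {
  const size_t m_local_id = i % m_avg_local;
  const size_t o_start = shard_idx * avg_local;
  std::unordered_set<size_t> ids;
  for (size_t j = o_start; j < o_start + real_local; ++j) {
    if ((j % avg_local) % m_real_local == m_local_id) ids.insert(j);
  }
  return ids;
}
// A key then lands in local shard
//   (key % sparse_table_shard_num) % avg_local,
// mirroring `*index_iter % _avg_local_shard_num` above.
```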
void MemorySparseTable::Revert() {
for (size_t i = 0; i < _real_local_shard_num; ++i) {
_local_shards_new[i].clear();
}
}
void MemorySparseTable::CheckSavePrePatchDone() {
  // join() on a thread that was never launched is undefined; guard it.
  if (_save_patch_model_thread.joinable()) {
    _save_patch_model_thread.join();
  }
}
int32_t MemorySparseTable::Save(const std::string& dirname,
const std::string& param) {
if (_real_local_shard_num == 0) {
_local_show_threshold = -1;
return 0;
}
VLOG(0) << "MemorySparseTable::save dirname: " << dirname;
int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
// patch model
if (save_param == 5) {
_local_shards_patch_model.reset(_local_shards_new.release());
_local_shards_new.reset(new shard_type[_real_local_shard_num]);
_save_patch_model_thread = std::thread(std::bind(
&MemorySparseTable::SavePatch, this, std::string(dirname), save_param));
return 0;
}
// cache model
int64_t tk_size = LocalSize() * _config.sparse_table_cache_rate();
TopkCalculator tk(_real_local_shard_num, tk_size);
std::string table_path = TableDir(dirname);
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx));
......@@ -274,6 +368,13 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_config.enable_sparse_table_cache() &&
(save_param == 1 || save_param == 2) &&
_value_accesor->Save(it.value().data(), 4)) {
CostTimer timer10("sprase table top push");
tk.push(i, _value_accesor->GetField(it.value().data(), "show"));
}
if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->ParseToString(
it.value().data(), it.value().size());
......@@ -310,55 +411,266 @@ int32_t MemorySparseTable::Save(const std::string& dirname,
_value_accesor->UpdateStatAfterSave(it.value().data(), save_param);
}
LOG(INFO) << "MemorySparseTable save prefix success, path: "
<< channel_config.path;
<< channel_config.path << " feasign_size: " << feasign_size;
}
_local_show_threshold = tk.top();
// NOTE: int32 may overflow; the return value type should be widened
return 0;
}
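For reference, the `save_param` values this function now dispatches on (the meaning of 4 is inferred from its use as a cache filter below and may be incomplete):

```cpp
// Hedged legend, not an enum from the source tree:
enum SaveMode : int {
  kCheckpoint = 0,   // "checkpoint:0" in the comment above
  kXboxDelta = 1,    // "xbox delta:1"
  kXboxBase = 2,     // "xbox base:2"
  kCacheFilter = 4,  // passed to ValueAccessor::Save to pick cache candidates
  kPatch = 5         // new in this PR: asynchronous patch-model save
};
```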
int32_t MemorySparseTable::SaveLocalFS(const std::string& dirname,
const std::string& param,
const std::string& prefix) {
int save_param =
atoi(param.c_str()); // checkpoint:0 xbox delta:1 xbox base:2
std::string table_path = TableDir(dirname);
int feasign_cnt = 0;
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
int32_t MemorySparseTable::SavePatch(const std::string& path, int save_param) {
if (!_config.enable_revert()) {
LOG(INFO) << "MemorySparseTable should be enabled revert.";
return 0;
}
size_t file_start_idx = _m_avg_local_shard_num * _shard_idx;
std::string table_path = TableDir(path);
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d-*", table_path.c_str(), _shard_idx));
int thread_num = _m_real_local_shard_num < 20 ? _m_real_local_shard_num : 20;
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
std::atomic<uint32_t> feasign_size_all{0};
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < _real_local_shard_num; ++i) {
feasign_cnt = 0;
auto& shard = _local_shards[i];
std::string file_name =
paddle::string::format_string("%s/part-%s-%03d-%05d",
table_path.c_str(),
prefix.c_str(),
_shard_idx,
file_start_idx + i);
std::ofstream os;
os.open(file_name);
for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value =
_value_accesor->ParseToString(it.value().data(), it.value().size());
std::string out_line = paddle::string::format_string(
"%lu %s\n", it.key(), format_value.c_str());
// VLOG(2) << out_line.c_str();
os.write(out_line.c_str(), sizeof(char) * out_line.size());
++feasign_cnt;
for (size_t i = 0; i < _m_real_local_shard_num; ++i) {
FsChannelConfig channel_config;
channel_config.path = paddle::string::format_string("%s/part-%03d-%05d",
table_path.c_str(),
_shard_idx,
file_start_idx + i);
channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter =
_value_accesor->Converter(save_param).deconverter;
bool is_write_failed = false;
int feasign_size = 0;
int retry_num = 0;
int err_no = 0;
do {
err_no = 0;
feasign_size = 0;
is_write_failed = false;
auto write_channel =
_afs_client.open_w(channel_config, 1024 * 1024 * 40, &err_no);
for (size_t j = 0; j < _real_local_shard_num; ++j) {
if (j % _m_real_local_shard_num == i) {
auto& shard = _local_shards_patch_model[j];
for (auto it = shard.begin(); it != shard.end(); ++it) {
if (_value_accesor->Save(it.value().data(), save_param)) {
std::string format_value = _value_accesor->ParseToString(
it.value().data(), it.value().size());
if (0 != write_channel->write_line(paddle::string::format_string(
"%lu %s", it.key(), format_value.c_str()))) {
++retry_num;
is_write_failed = true;
LOG(ERROR) << "MemorySparseTable save failed, retry it! path:"
<< channel_config.path
<< " , retry_num=" << retry_num;
break;
}
++feasign_size;
}
}
}
if (is_write_failed) break;
}
write_channel->close();
if (err_no == -1) {
++retry_num;
is_write_failed = true;
LOG(ERROR)
<< "MemorySparseTable save patch failed after write, retry it! "
<< "path:" << channel_config.path << " , retry_num=" << retry_num;
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
if (retry_num > FLAGS_pserver_table_save_max_retry) {
LOG(ERROR) << "MemorySparseTable save patch failed reach max limit!";
exit(-1);
}
} while (is_write_failed);
feasign_size_all += feasign_size;
}
LOG(INFO) << "MemorySparseTable save patch success, path:"
<< paddle::string::format_string("%s/%03d/part-%03d-",
path.c_str(),
_config.table_id(),
_shard_idx)
<< " from " << file_start_idx << " to "
<< file_start_idx + _m_real_local_shard_num - 1
<< ", feasign size: " << feasign_size_all;
return 0;
}
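On the save side, merged output file `i` simply aggregates every original local shard `j` with `j % _m_real_local_shard_num == i`, and patch part files are numbered from `_m_avg_local_shard_num * _shard_idx`. A small sketch of the resulting naming, with hypothetical numbers:

```cpp
#include <cstdio>

int main() {
  const int shard_idx = 2;   // assumed pserver index
  const int m_avg = 50;      // assumed _m_avg_local_shard_num
  const int m_real = 50;     // assumed _m_real_local_shard_num
  const int file_start_idx = m_avg * shard_idx;  // 100
  char name[64];
  for (int i = 0; i < m_real; ++i) {
    std::snprintf(name, sizeof(name), "part-%03d-%05d", shard_idx,
                  file_start_idx + i);  // part-002-00100 .. part-002-00149
  }
  return 0;
}
```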
int64_t MemorySparseTable::CacheShuffle(
const std::string& path,
const std::string& param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) {
LOG(INFO) << "cache shuffle with cache threshold: " << cache_threshold;
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
if (!_config.enable_sparse_table_cache() || cache_threshold < 0) {
LOG(WARNING)
<< "cache shuffle failed: table cache not enabled or cache threshold < 0, "
<< "enable_sparse_table_cache=" << _config.enable_sparse_table_cache()
<< " cache_threshold=" << cache_threshold;
// return -1;
}
int shuffle_node_num = _config.sparse_table_cache_file_num();
LOG(INFO) << "Table>> shuffle node num is: " << shuffle_node_num;
// TODO(zhaocaibei123): check shuffle_node_num <= server_node_num
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
int thread_num = _real_local_shard_num < 20 ? _real_local_shard_num : 20;
std::vector<
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>>
writers(_real_local_shard_num);
std::vector<std::vector<std::pair<uint64_t, std::string>>> datas(
_real_local_shard_num);
int feasign_size = 0;
std::vector<paddle::framework::Channel<std::pair<uint64_t, std::string>>>
tmp_channels;
for (size_t i = 0; i < _real_local_shard_num; ++i) {
tmp_channels.push_back(
paddle::framework::MakeChannel<std::pair<uint64_t, std::string>>());
}
omp_set_num_threads(thread_num);
#pragma omp parallel for schedule(dynamic)
for (size_t i = 0; i < _real_local_shard_num; ++i) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[i];
writer.Reset(tmp_channels[i].get());
for (size_t idx = 0; idx < table_ptrs.size(); idx++) {
Table* table_ptr = table_ptrs[idx];
auto value_accesor = table_ptr->ValueAccesor();
shard_type* shard_ptr = static_cast<shard_type*>(table_ptr->GetShard(i));
for (auto it = shard_ptr->begin(); it != shard_ptr->end(); ++it) {
if (value_accesor->SaveCache(
it.value().data(), save_param, cache_threshold)) {
std::string format_value = value_accesor->ParseToString(
it.value().data(), it.value().size());
std::pair<uint64_t, std::string> pkv(it.key(), format_value.c_str());
writer << pkv;
++feasign_size;
}
}
}
writer.Flush();
writer.channel()->Close();
}
// LOG(INFO) << "MemorySparseTable cache KV save success to Channel feasigh
// size: " << feasign_size << " and start sparse cache data shuffle real local
// shard num: " << _real_local_shard_num;
std::vector<std::pair<uint64_t, std::string>> local_datas;
for (size_t idx_shard = 0; idx_shard < _real_local_shard_num; ++idx_shard) {
paddle::framework::ChannelWriter<std::pair<uint64_t, std::string>>& writer =
writers[idx_shard];
auto channel = writer.channel();
std::vector<std::pair<uint64_t, std::string>>& data = datas[idx_shard];
std::vector<paddle::framework::BinaryArchive> ars(shuffle_node_num);
while (channel->Read(data)) {
for (auto& t : data) {
auto pserver_id =
paddle::distributed::local_random_engine()() % shuffle_node_num;
if (pserver_id != _shard_idx) {
ars[pserver_id] << t;
} else {
local_datas.emplace_back(std::move(t));
}
}
std::vector<std::future<int32_t>> total_status;
std::vector<uint32_t> send_data_size(shuffle_node_num, 0);
std::vector<int> send_index(shuffle_node_num);
for (int i = 0; i < shuffle_node_num; ++i) {
send_index[i] = i;
}
std::random_shuffle(send_index.begin(), send_index.end());
for (auto index = 0u; index < shuffle_node_num; ++index) {
int i = send_index[index];
if (i == _shard_idx) {
continue;
}
if (ars[i].Length() == 0) {
continue;
}
std::string msg(ars[i].Buffer(), ars[i].Length());
auto ret = send_msg_func(101, i, msg);
total_status.push_back(std::move(ret));
send_data_size[i] += ars[i].Length();
}
for (auto& t : total_status) {
t.wait();
}
ars.clear();
ars = std::vector<paddle::framework::BinaryArchive>(shuffle_node_num);
data = std::vector<std::pair<uint64_t, std::string>>();
}
os.close();
LOG(INFO) << "MemorySparseTable save prefix success, path:" << file_name
<< "feasign_cnt: " << feasign_cnt;
}
shuffled_channel->Write(std::move(local_datas));
return 0;
}
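`CacheShuffle` routes each record to a random pserver via `send_msg_func` with msg_type 101 (`PS_S2S_MSG` in the command enum). A plausible, hedged binding to the server-to-server channel declared earlier in this diff:

```cpp
// `server` is an assumed BrpcPsServer*; the lambda matches the
// send_msg_func signature taken by CacheShuffle above.
auto MakeSendMsgFunc(paddle::distributed::BrpcPsServer* server) {
  return [server](int msg_type, int to_pserver_id,
                  std::string& msg) -> std::future<int32_t> {
    return server->SendPServer2PServerMsg(msg_type, to_pserver_id, msg);
  };
}
```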
int32_t MemorySparseTable::SaveCache(
const std::string& path,
const std::string& param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) {
if (_shard_idx >= _config.sparse_table_cache_file_num()) {
return 0;
}
int save_param = atoi(param.c_str()); // batch_model:0 xbox:1
size_t file_start_idx = _avg_local_shard_num * _shard_idx;
std::string table_path = paddle::string::format_string(
"%s/%03d_cache/", path.c_str(), _config.table_id());
_afs_client.remove(paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx));
uint32_t feasign_size = 0;
FsChannelConfig channel_config;
// not compress cache model
channel_config.path = paddle::string::format_string(
"%s/part-%03d", table_path.c_str(), _shard_idx);
channel_config.converter = _value_accesor->Converter(save_param).converter;
channel_config.deconverter =
_value_accesor->Converter(save_param).deconverter;
auto write_channel = _afs_client.open_w(channel_config, 1024 * 1024 * 40);
std::vector<std::pair<uint64_t, std::string>> data;
bool is_write_failed = false;
shuffled_channel->Close();
while (shuffled_channel->Read(data)) {
for (auto& t : data) {
++feasign_size;
if (0 != write_channel->write_line(paddle::string::format_string(
"%lu %s", t.first, t.second.c_str()))) {
LOG(ERROR) << "Cache Table save failed, "
"path:"
<< channel_config.path << ", retry it!";
is_write_failed = true;
break;
}
}
data = std::vector<std::pair<uint64_t, std::string>>();
}
if (is_write_failed) {
_afs_client.remove(channel_config.path);
}
write_channel->close();
LOG(INFO) << "MemorySparseTable cache save success, feasign: " << feasign_size
<< ", path: " << channel_config.path;
shuffled_channel->Open();
return feasign_size;
}
int64_t MemorySparseTable::LocalSize() {
int64_t local_size = 0;
for (int i = 0; i < _real_local_shard_num; ++i) {
......@@ -548,7 +860,7 @@ int32_t MemorySparseTable::PullSparsePtr(char** pull_values,
ret = itr.value_ptr();
}
int pull_data_idx = keys[i].second;
pull_values[pull_data_idx] = (char*)ret; // NOLINT
pull_values[pull_data_idx] = reinterpret_cast<char*>(ret);
}
return 0;
});
......@@ -589,6 +901,7 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
&task_keys]() -> int {
auto& keys = task_keys[shard_id];
auto& local_shard = _local_shards[shard_id];
auto& local_shard_new = _local_shards_new[shard_id];
float data_buffer[value_col]; // NOLINT
float* data_buffer_ptr = data_buffer;
for (size_t i = 0; i < keys.size(); ++i) {
......@@ -630,6 +943,14 @@ int32_t MemorySparseTable::PushSparse(const uint64_t* keys,
}
memcpy(value_data, data_buffer_ptr, value_size * sizeof(float));
}
if (_config.enable_revert()) {
FixedFeatureValue* feature_value_new = &(local_shard_new[key]);
auto new_size = feature_value.size();
feature_value_new->resize(new_size);
memcpy(feature_value_new->data(),
value_data,
new_size * sizeof(float));
}
}
return 0;
});
......
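With `enable_revert` on, every push now mirrors the updated value into `_local_shards_new`, so the table can either snapshot the deltas as a patch model or discard them. A hedged sketch of the table-level sequence (the "5" mode string follows `MemorySparseTable::Save` above; `table` and `dirname` are illustrative):

```cpp
void PatchSaveThenRevert(paddle::distributed::Table* table,
                         const std::string& dirname) {
  table->Save(dirname, "5");       // snapshot mirror -> patch model (async)
  table->CheckSavePrePatchDone();  // join the SavePatch thread
  table->Revert();                 // clear any remaining mirror entries
}
```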
......@@ -65,17 +65,25 @@ class MemorySparseTable : public Table {
int32_t InitializeShard() override { return 0; }
int32_t InitializeValue();
virtual int32_t Load(const std::string& path,
const std::string& param) override;
virtual int32_t Save(const std::string& path,
const std::string& param) override;
int32_t LoadLocalFS(const std::string& path, const std::string& param);
int32_t SaveLocalFS(const std::string& path,
const std::string& param,
const std::string& prefix);
int32_t Load(const std::string& path, const std::string& param) override;
int32_t Save(const std::string& path, const std::string& param) override;
int32_t SaveCache(
const std::string& path,
const std::string& param,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel) override;
virtual double GetCacheThreshold() { return _local_show_threshold; }
int64_t CacheShuffle(
const std::string& path,
const std::string& param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string& msg)> send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>&
shuffled_channel,
const std::vector<Table*>& table_ptrs) override;
int64_t LocalSize();
int64_t LocalMFSize();
......@@ -89,20 +97,38 @@ class MemorySparseTable : public Table {
int32_t PushSparse(const uint64_t* keys, const float** values, size_t num);
int32_t Flush() override;
virtual int32_t Shrink(const std::string& param) override;
int32_t Shrink(const std::string& param) override;
void Clear() override;
void* GetShard(size_t shard_idx) override {
return &_local_shards[shard_idx];
}
virtual void Revert();
virtual void CheckSavePrePatchDone();
protected:
virtual int32_t SavePatch(const std::string& path, int save_param);
virtual int32_t LoadPatch(const std::vector<std::string>& file_list,
int save_param);
const int _task_pool_size = 24;
int _avg_local_shard_num;
int _real_local_shard_num;
int _sparse_table_shard_num;
std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::unique_ptr<shard_type[]> _local_shards;
// for patch model
int _m_avg_local_shard_num;
int _m_real_local_shard_num;
int _m_sparse_table_shard_num;
float _shard_merge_rate{1.0f};
double _local_show_threshold{0.0};
std::unique_ptr<shard_type[]> _local_shards_new;
std::unique_ptr<shard_type[]> _local_shards_patch_model;
std::thread _save_patch_model_thread;
};
} // namespace distributed
......
......@@ -71,8 +71,8 @@ class Table {
virtual int32_t Initialize(const TableParameter &config,
const FsClientParameter &fs_config);
virtual int32_t Pull(TableContext &context) = 0;
virtual int32_t Push(TableContext &context) = 0;
virtual int32_t Pull(TableContext &context) = 0; // NOLINT
virtual int32_t Push(TableContext &context) = 0; // NOLINT
// only for barrier
virtual int32_t Barrier(const uint32_t trainer_id,
......@@ -125,7 +125,8 @@ class Table {
const std::string &param,
double cache_threshold,
std::function<std::future<int32_t>(
int msg_type, int to_pserver_id, std::string &msg)> send_msg_func,
int msg_type, int to_pserver_id, std::string &msg)> // NOLINT
send_msg_func,
paddle::framework::Channel<std::pair<uint64_t, std::string>>
&shuffled_channel,
const std::vector<Table *> &table_ptrs) {
......@@ -147,6 +148,10 @@ class Table {
virtual void *GetShard(size_t shard_idx) = 0;
virtual std::pair<int64_t, int64_t> PrintTableStat() { return {0, 0}; }
// for patch model
virtual void Revert() {}
virtual void CheckSavePrePatchDone() {}
protected:
virtual int32_t Initialize() = 0;
virtual int32_t InitializeAccessor();
......
......@@ -853,6 +853,24 @@ int32_t FleetWrapper::SaveCache(int table_id,
return feasign_cnt;
}
void FleetWrapper::Revert() {
auto ret = worker_ptr_->Revert();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
void FleetWrapper::CheckSavePrePatchDone() {
auto ret = worker_ptr_->CheckSavePrePatchDone();
ret.wait();
if (ret.get() == -1) {
LOG(ERROR) << "table revert failed";
exit(-1);
}
}
std::default_random_engine& FleetWrapper::LocalRandomEngine() {
struct engine_wrapper_t {
std::default_random_engine engine;
......
......@@ -300,6 +300,8 @@ class FleetWrapper {
const int mode,
const double cache_threshold);
int32_t SaveCache(int table_id, const std::string& path, const int mode);
void Revert();
void CheckSavePrePatchDone();
static std::shared_ptr<paddle::distributed::PSCore> pserver_ptr_;
static std::shared_ptr<paddle::distributed::PSClient> worker_ptr_;
......
......@@ -150,9 +150,6 @@ TEST(MemorySparseTable, SGD) {
VLOG(3) << update_val << ": " << pull_values[i * (emb_dim + 1) + j];
}
}
MemorySparseTable *ctr_table = dynamic_cast<MemorySparseTable *>(table);
ctr_table->SaveLocalFS("./work/table.save", "0", "test");
}
} // namespace distributed
......
......@@ -114,12 +114,14 @@ message TableParameter {
optional TensorAccessorParameter tensor = 5;
optional CommonAccessorParameter common = 6;
optional TableType type = 7;
optional bool compress_in_save = 8 [ default = false ];
optional bool compress_in_save = 8 [ default = true ];
optional GraphParameter graph_parameter = 9;
// for cache model
optional bool enable_sparse_table_cache = 10 [ default = true ];
optional double sparse_table_cache_rate = 11 [ default = 0.00055 ];
optional uint32 sparse_table_cache_file_num = 12 [ default = 16 ];
optional bool enable_revert = 13 [ default = true ];
optional float shard_merge_rate = 14 [ default = 1.0 ];
}
message TableAccessorParameter {
......
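The two new table options can be set through the generated proto API; a hedged sketch assuming conventional protobuf C++ setter names:

```cpp
paddle::distributed::TableParameter table_param;
table_param.set_enable_revert(true);     // default true per the message above
table_param.set_shard_merge_rate(0.5f);  // merge patch shards to half count
```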
......@@ -75,7 +75,9 @@ void BindDistFleetWrapper(py::module* m) {
.def("client_flush", &FleetWrapper::ClientFlush)
.def("get_cache_threshold", &FleetWrapper::GetCacheThreshold)
.def("cache_shuffle", &FleetWrapper::CacheShuffle)
.def("save_cache", &FleetWrapper::SaveCache);
.def("save_cache", &FleetWrapper::SaveCache)
.def("revert", &FleetWrapper::Revert)
.def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
}
void BindPSHost(py::module* m) {
......