提交 f5af6905 编写于 作者: X xiexionghang

add sparse cache

上级 229964e4
...@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch') ...@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch') CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag') CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch') CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/feed-mlarch/hopscotch-map@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable') CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL') CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/') HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
......
...@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) { ...@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
//load trainer config //load trainer config
auto trainer_context_ptr = std::make_shared<TrainerContext>(); auto trainer_context_ptr = std::make_shared<TrainerContext>();
trainer_context_ptr->cache_dict.reset(new SignCacheDict);
trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path); trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path);
//environment //environment
......
...@@ -16,6 +16,8 @@ namespace feed { ...@@ -16,6 +16,8 @@ namespace feed {
int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) { int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
int ret = Process::initialize(context_ptr); int ret = Process::initialize(context_ptr);
auto& config = _context_ptr->trainer_config; auto& config = _context_ptr->trainer_config;
_is_dump_cache_model = config["dump_cache_model"].as<bool>(false);
_cache_load_converter = config["load_cache_converter"].as<std::string>("");
_startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false); _startup_dump_inference_base = config["startup_dump_inference_base"].as<bool>(false);
if (config["executor"]) { if (config["executor"]) {
_executors.resize(config["executor"].size()); _executors.resize(config["executor"].size());
...@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) { ...@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return 0; return 0;
} }
// 更新各节点存储的CacheModel
int LearnerProcess::update_cache_model(uint64_t epoch_id, ModelSaveWay way) {
auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client();
auto* environment = _context_ptr->environment.get();
auto* epoch_accessor = _context_ptr->epoch_accessor.get();
if (!epoch_accessor->need_save_model(epoch_id, way)) {
return 0;
}
auto* ps_param = _context_ptr->pslib->get_param();
if (_is_dump_cache_model && way == ModelSaveWay::ModelSaveInferenceBase) {
auto model_dir = epoch_accessor->model_save_path(epoch_id, way);
auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
for (auto& param : table_param) {
if (param.type() != paddle::PS_SPARSE_TABLE) {
continue;
}
auto cache_model_path = fs->path_join(
model_dir, string::format_string("%03d_cache/", param.table_id()));
if (!fs->exists(cache_model_path)) {
continue;
}
auto& cache_dict = *(_context_ptr->cache_dict.get());
cache_dict.clear();
cache_dict.reserve(_cache_sign_max_num);
auto cache_file_list = fs->list(fs->path_join(cache_model_path, "part*"));
for (auto& cache_path : cache_file_list) {
auto cache_file = fs->open_read(cache_path, _cache_load_converter);
char *buffer = nullptr;
size_t buffer_size = 0;
ssize_t line_len = 0;
while ((line_len = getline(&buffer, &buffer_size, cache_file.get())) != -1) {
if (line_len <= 1) {
continue;
}
char* data_ptr;
cache_dict.append(strtoul(buffer, &data_ptr, 10));
}
if (buffer != nullptr) {
free(buffer);
}
}
break;
}
}
return 0;
}
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) { int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump) {
auto fs = _context_ptr->file_system; auto fs = _context_ptr->file_system;
auto* ps_client = _context_ptr->pslib->ps_client(); auto* ps_client = _context_ptr->pslib->ps_client();
...@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is ...@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
} }
timer.Pause(); timer.Pause();
VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec(); VLOG(2) << "Save Model Cost(s):" << timer.ElapsedSec();
// save cache model, 只有inference需要cache_model
auto* ps_param = _context_ptr->pslib->get_param();
if (_is_dump_cache_model && (way == ModelSaveWay::ModelSaveInferenceBase ||
way == ModelSaveWay::ModelSaveInferenceDelta)) {
auto& table_param = ps_param->server_param().downpour_server_param().downpour_table_param();
for (auto& param : table_param) {
if (param.type() != paddle::PS_SPARSE_TABLE) {
continue;
}
double cache_threshold = 0.0;
auto status = ps_client->get_cache_threshold(param.table_id(), cache_threshold);
CHECK(status.get() == 0) << "CacheThreshold Get failed!";
status = ps_client->cache_shuffle(param.table_id(), model_dir, std::to_string((int)way),
std::to_string(cache_threshold));
CHECK(status.get() == 0) << "Cache Shuffler Failed";
status = ps_client->save_cache(param.table_id(), model_dir, std::to_string((int)way));
auto feature_size = status.get();
CHECK(feature_size >= 0) << "Cache Save Failed";
auto cache_model_path = fs->path_join(
model_dir, string::format_string("%03d_cache/sparse_cache.meta", param.table_id()));
auto cache_meta_file = fs->open_write(cache_model_path, "");
auto meta = string::format_string("file_prefix:part\npart_num:%d\nkey_num:%d\n",
param.sparse_table_cache_file_num(), feature_size);
CHECK(fwrite(meta.c_str(), meta.size(), 1, cache_meta_file.get()) == 1) << "Cache Meta Failed";
if (feature_size > _cache_sign_max_num) {
_cache_sign_max_num = feature_size;
}
}
}
_context_ptr->epoch_accessor->update_model_donefile(epoch_id, way); _context_ptr->epoch_accessor->update_model_donefile(epoch_id, way);
return all_ret; return all_ret;
} }
...@@ -176,6 +256,9 @@ int LearnerProcess::run() { ...@@ -176,6 +256,9 @@ int LearnerProcess::run() {
{ {
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
update_cache_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
environment->barrier(EnvironmentRole::WORKER);
if (epoch_accessor->is_last_epoch(epoch_id)) { if (epoch_accessor->is_last_epoch(epoch_id)) {
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase); wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpointBase);
} else { } else {
......
...@@ -22,9 +22,13 @@ protected: ...@@ -22,9 +22,13 @@ protected:
virtual int load_model(uint64_t epoch_id); virtual int load_model(uint64_t epoch_id);
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型 // 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false); virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false);
virtual int update_cache_model(uint64_t epoch_id, ModelSaveWay way);
private: private:
bool _startup_dump_inference_base; //启动立即dump base bool _is_dump_cache_model; // 是否进行cache dump
uint32_t _cache_sign_max_num = 0; // cache sign最大个数
std::string _cache_load_converter; // cache加载的前置转换脚本
bool _startup_dump_inference_base; // 启动立即dump base
std::vector<std::shared_ptr<MultiThreadExecutor>> _executors; std::vector<std::shared_ptr<MultiThreadExecutor>> _executors;
}; };
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <sstream> #include <sstream>
#include <tsl/bhopscotch_map.h>
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h" #include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h" #include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
...@@ -35,15 +36,61 @@ enum class TrainerStatus { ...@@ -35,15 +36,61 @@ enum class TrainerStatus {
Saving = 1 // 模型存储状态 Saving = 1 // 模型存储状态
}; };
const uint32_t SignCacheMaxValueNum = 13;
struct SignCacheData {
SignCacheData() {
memset(cache_value, 0, sizeof(float) * SignCacheMaxValueNum);
}
uint32_t idx;
float cache_value[SignCacheMaxValueNum];
};
class SignCacheDict { class SignCacheDict {
public: public:
int32_t sign2index(uint64_t sign) { inline int32_t sign2index(uint64_t sign) {
return -1; auto itr = _sign2data_map.find(sign);
if (itr == _sign2data_map.end()) {
return -1;
}
return itr->second.idx;
}
inline uint64_t index2sign(int32_t index) {
if (index >= _sign_list.size()) {
return 0;
}
return _sign_list[index];
}
inline void reserve(uint32_t size) {
_sign_list.reserve(size);
_sign2data_map.reserve(size);
}
inline void clear() {
_sign_list.clear();
_sign2data_map.clear();
}
inline void append(uint64_t sign) {
if (_sign2data_map.find(sign) != _sign2data_map.end()) {
return;
}
SignCacheData data;
data.idx = _sign_list.size();
_sign_list.push_back(sign);
_sign2data_map.emplace(sign, std::move(data));
} }
uint64_t index2sign(int32_t index) { inline SignCacheData* data(uint64_t sign) {
return 0; tsl::bhopscotch_pg_map<uint64_t, SignCacheData>::iterator itr = _sign2data_map.find(sign);
if (itr == _sign2data_map.end()) {
return nullptr;
}
return const_cast<SignCacheData*>(&(itr->second));
} }
private:
std::vector<uint64_t> _sign_list;
tsl::bhopscotch_pg_map<uint64_t, SignCacheData> _sign2data_map;
}; };
class TrainerContext { class TrainerContext {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册