未验证 提交 123255cf 编写于 作者: H hutuxian 提交者: GitHub

Change InitializeGPU to InitializeGPUAndLoadModel (#24377)

* Add InitializeGPUAndLoadModel to fix a random hang when downloading sparse parameters.
* Update SaveBase to fix a test problem.
上级 43625bda
......@@ -364,8 +364,10 @@ class BoxWrapper {
uint64_t* total_keys, const int64_t* gpu_len, int slot_num,
int total_len);
boxps::PSAgentBase* GetAgent() { return p_agent_; }
void InitializeGPU(const char* conf_file, const std::vector<int>& slot_vector,
const std::vector<std::string>& slot_omit_in_feedpass) {
void InitializeGPUAndLoadModel(
const char* conf_file, const std::vector<int>& slot_vector,
const std::vector<std::string>& slot_omit_in_feedpass,
const std::string& model_path) {
if (nullptr != s_instance_) {
VLOG(3) << "Begin InitializeGPU";
std::vector<cudaStream_t*> stream_list;
......@@ -380,7 +382,8 @@ class BoxWrapper {
}
VLOG(2) << "Begin call InitializeGPU in BoxPS";
// the second parameter is useless
s_instance_->boxps_ptr_->InitializeGPU(conf_file, -1, stream_list);
s_instance_->boxps_ptr_->InitializeGPUAndLoadModel(
conf_file, -1, stream_list, slot_vector, model_path);
p_agent_ = boxps::PSAgentBase::GetIns(feedpass_thread_num_);
p_agent_->Init();
for (const auto& slot_name : slot_omit_in_feedpass) {
......@@ -401,10 +404,27 @@ class BoxWrapper {
}
const std::string SaveBase(const char* batch_model_path,
const char* xbox_model_path) {
const char* xbox_model_path,
const std::string& date) {
VLOG(3) << "Begin SaveBase";
PADDLE_ENFORCE_EQ(
date.length(), 8,
platform::errors::PreconditionNotMet(
"date[%s] is invalid, correct example is 20190817", date.c_str()));
int year = std::stoi(date.substr(0, 4));
int month = std::stoi(date.substr(4, 2));
int day = std::stoi(date.substr(6, 2));
struct std::tm b;
b.tm_year = year - 1900;
b.tm_mon = month - 1;
b.tm_mday = day;
b.tm_min = b.tm_hour = b.tm_sec = 0;
std::time_t seconds_from_1970 = std::mktime(&b);
std::string ret_str;
int ret = boxps_ptr_->SaveBase(batch_model_path, xbox_model_path, ret_str);
int ret = boxps_ptr_->SaveBase(batch_model_path, xbox_model_path, ret_str,
seconds_from_1970 / 86400);
PADDLE_ENFORCE_EQ(ret, 0, platform::errors::PreconditionNotMet(
"SaveBase failed in BoxPS."));
return ret_str;
......
......@@ -73,7 +73,8 @@ void BindBoxWrapper(py::module* m) {
py::call_guard<py::gil_scoped_release>())
.def("save_delta", &framework::BoxWrapper::SaveDelta,
py::call_guard<py::gil_scoped_release>())
.def("initialize_gpu", &framework::BoxWrapper::InitializeGPU,
.def("initialize_gpu_and_load_model",
&framework::BoxWrapper::InitializeGPUAndLoadModel,
py::call_guard<py::gil_scoped_release>())
.def("init_metric", &framework::BoxWrapper::InitMetric,
py::call_guard<py::gil_scoped_release>())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册