提交 46be5864 编写于 作者: R rensilin

Merge branch 'master' of ssh://icode.baidu.com:8235/baidu/feed-mlarch/paddle-trainer

Change-Id: Icd7a4bd99682fd78137198591d04eda1ddea7385
...@@ -87,7 +87,7 @@ public: ...@@ -87,7 +87,7 @@ public:
std::string tags; std::string tags;
std::map<std::string, std::vector<float>> vec_feas; std::map<std::string, std::vector<float>> vec_feas;
int sample_type; int sample_type;
std::map<std::string, std::vector<int>> auc_category_info_map; //为细维度计算auc准备的数据 std::map<std::string, std::vector<std::string>> auc_category_info_map; //为细维度计算auc准备的数据
std::vector<FeatureItem> hot_feas, cold_feas; //冷(int32_t)热(uint64_t)feasign std::vector<FeatureItem> hot_feas, cold_feas; //冷(int32_t)热(uint64_t)feasign
void clear() { void clear() {
...@@ -149,7 +149,7 @@ public: ...@@ -149,7 +149,7 @@ public:
return cursor; return cursor;
} }
void parse_feas(char* buffer) const { void serialize_to_compress_feas(char* buffer) const {
if (buffer == nullptr) { if (buffer == nullptr) {
return ; return ;
} }
...@@ -293,16 +293,21 @@ public: ...@@ -293,16 +293,21 @@ public:
std::string& single_category_str = all_category_vec[i]; std::string& single_category_str = all_category_vec[i];
std::vector<std::string> str_vec = paddle::string::split_string(single_category_str, "="); std::vector<std::string> str_vec = paddle::string::split_string(single_category_str, "=");
CHECK(str_vec.size() == 2); CHECK(str_vec.size() == 2);
/*std::string category_name = str_vec[0]; std::string category_name = str_vec[0];
std::vector<int> category_info_vec = paddle::string::split_string<int>(str_vec[1], ","); std::vector<std::string> category_info_vec = paddle::string::split_string<std::string>(str_vec[1], ",");
CHECK(category_info_vec.size() > 0); CHECK(category_info_vec.size() > 0);
CHECK(rec.auc_category_info_map.insert({category_name, category_info_vec}).second);*/ CHECK(rec.auc_category_info_map.insert({category_name, category_info_vec}).second);
} }
} else { } else {
uint64_t sign = 0; uint64_t sign = 0;
int slot = -1; int slot = -1;
CHECK((sign = (uint64_t) strtoull(str, &cursor, 10), cursor != str)); sign = (uint64_t)strtoull(str, &cursor, 10);
if (cursor == str) { //FIXME abacus没有这种情况
str++;
continue;
}
//CHECK((sign = (uint64_t)strtoull(str, &cursor, 10), cursor != str));
str = cursor; str = cursor;
CHECK(*str++ == ':'); CHECK(*str++ == ':');
CHECK(!isspace(*str)); CHECK(!isspace(*str));
...@@ -320,17 +325,12 @@ public: ...@@ -320,17 +325,12 @@ public:
} }
paddle::framework::BinaryArchive bar; paddle::framework::BinaryArchive bar;
bar << rec.show; bar << rec.show << rec.clk << rec.tags << rec.vec_feas << rec.sample_type << rec.auc_category_info_map;
bar << rec.clk; uint32_t feas_len = rec.calc_compress_feas_lens(); //事先计算好压缩后feasign的空间
bar << rec.tags;
bar << rec.vec_feas;
bar << rec.sample_type;
bar << rec.auc_category_info_map;
uint32_t feas_len = rec.calc_compress_feas_lens();
bar << feas_len; bar << feas_len;
bar.Resize(bar.Length() + feas_len); bar.Resize(bar.Length() + feas_len);
rec.parse_feas(bar.Cursor()); rec.serialize_to_compress_feas(bar.Finish() - feas_len); //直接在archive内部buffer进行压缩,避免不必要的拷贝
data.data.assign(bar.Buffer(), bar.Length()); data.data.assign(bar.Buffer(), bar.Length()); //TODO 这一步拷贝是否也能避免
return 0; return 0;
} }
...@@ -355,7 +355,6 @@ public: ...@@ -355,7 +355,6 @@ public:
paddle::framework::BinaryArchive bar; paddle::framework::BinaryArchive bar;
bar.SetReadBuffer(const_cast<char*>(&data.data[0]), data.data.size(), nullptr); bar.SetReadBuffer(const_cast<char*>(&data.data[0]), data.data.size(), nullptr);
bar >> show; bar >> show;
bar >> clk; bar >> clk;
bar >> tags; bar >> tags;
...@@ -363,13 +362,14 @@ public: ...@@ -363,13 +362,14 @@ public:
bar >> sample_type; bar >> sample_type;
bar >> auc_category_info_map; bar >> auc_category_info_map;
bar >> feas_len; bar >> feas_len;
parse_feas_to_ins(bar.Cursor(), feas_len, instance.features); CHECK((bar.Finish() - bar.Cursor()) == feas_len);
deserialize_feas_to_ins(bar.Cursor(), feas_len, instance.features);
return 0; return 0;
} }
private: private:
void parse_feas_to_ins(char* buffer, uint32_t len, std::vector<FeatureItem>& ins) const { void deserialize_feas_to_ins(char* buffer, uint32_t len, std::vector<FeatureItem>& ins) const {
if (buffer == nullptr) { if (buffer == nullptr) {
return ; return ;
} }
......
...@@ -76,7 +76,7 @@ int LearnerProcess::run() { ...@@ -76,7 +76,7 @@ int LearnerProcess::run() {
uint64_t epoch_id = epoch_accessor->current_epoch_id(); uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Resume training with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str()); "Resume train with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
//判断是否先dump出base //判断是否先dump出base
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase); wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册